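CNTK is currently available only as an x64 build. If a spreadsheet calls the COM component directly, the x64 DLL is loaded in-process, and that causes errors.
So I wrote an out-of-process COM component instead. Combined with VBA, it can be called from inside the spreadsheet, and CNTK can then be used for AI training.
After some debugging, it now works as expected.
The training data comes from the official samples (generated with Excel). To make training repeatable, the trainer's checkpoint file is saved after every training run, and the next run loads that file and continues from there. Training can also be done in a single run, split into batches.
CNTK can save the trainer and load it again at the start of the next training run: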
var trainer = Trainer.CreateTrainer(classifierOutput, loss, evalError, parameterLearners);
// Number of samples trained in previous runs
int trainedNumb1 = 0;
// Number of samples trained once this run is included
int trainedNumb2 = 0;
// If a saved trainer exists, load it
if (File.Exists(trainer_filename.ToString()))
{
    var checkpointState = trainer.RestoreFromCheckpoint(trainer_filename.ToString());
    trainedNumb1 = (int)trainer.TotalNumberOfSamplesSeen();
    //trainedNumb1 = (int)checkpointState.Size();
    result = $"Trainer Restore From Checkpoint:{trainer_filename.ToString()}" + "\r\n" + result;
}
trainedNumb2 = trainedNumb1 + features.Count;

The model can also be saved. Once training is finished, the model is loaded directly for evaluation:
string modelname = model_filename + ".model";
if (File.Exists(modelname))
{
    Function Model = Function.Load(modelname, DeviceDescriptor.CPUDevice, ModelFormat.CNTKv2);
    var featureVariable = Model.Arguments[0];
    var labelVariable = Model.Output;
    // Convert the input into a matrix
    var matrix_features = Double2Matrix(o2d(features));
    // Input dimension
    int inputDim = matrix_features.ColumnCount;
    int numOutputClasses = labelVariable.Shape.TotalSize;
    // Flatten the matrix into the float array that CreateBatch expects
    float[] feature_input_array = Matrix2Array(matrix_features);
    var featureinput = Value.CreateBatch<float>(new int[] { inputDim }, feature_input_array, DeviceDescriptor.CPUDevice);
    var inputDataMap = new Dictionary<Variable, Value>() { { featureVariable, featureinput } };
    var outputDataMap = new Dictionary<Variable, Value>() { { Model.Output, null } };
    Model.Evaluate(inputDataMap, outputDataMap, DeviceDescriptor.CPUDevice);
    var outputValue = outputDataMap[Model.Output];
    IList<IList<float>> actualLabelSoftMax = outputValue.GetDenseData<float>(Model.Output);
    // Copy the per-sample outputs into a matrix, one row per sample
    var resultMat = Matrix<double>.Build.Dense(actualLabelSoftMax.Count, actualLabelSoftMax[0].Count);
    for (int ii = 0; ii < actualLabelSoftMax.Count; ii++)
    {
        for (int jj = 0; jj < actualLabelSoftMax[ii].Count; jj++)
        {
            resultMat[ii, jj] = actualLabelSoftMax[ii][jj];
        }
    }
    return resultMat.ToArray();
}
else
{
    return "no model file found";
    //classifierOutput = CreateModel(featureVariable, labelsClasses, DeviceDescriptor.CPUDevice);
}
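The evaluation code above passes a flat float array to Value.CreateBatch, which treats the array as consecutive samples of inputDim values each. The post does not show its Matrix2Array helper, so the following is only a minimal sketch of what such a flattening step might look like. The class name BatchLayout and the method FlattenRowMajor are hypothetical, and the sketch assumes one sample per row of a plain double[,] rather than the matrix type used in the post.

```csharp
using System;

// Hypothetical stand-in for the post's Matrix2Array helper (not shown in the source).
// Assumption: each row of the input is one sample, each column one feature.
public static class BatchLayout
{
    // Flattens a rows x cols array row by row, so that the values of one
    // sample stay contiguous in the result.
    public static float[] FlattenRowMajor(double[,] samples)
    {
        int rows = samples.GetLength(0);
        int cols = samples.GetLength(1);
        var flat = new float[rows * cols];
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                // Element (r, c) lands at offset r * cols + c.
                flat[r * cols + c] = (float)samples[r, c];
            }
        }
        return flat;
    }
}
```

For two samples with three features each, BatchLayout.FlattenRowMajor(new double[,] { { 1, 2, 3 }, { 4, 5, 6 } }) yields an array of six floats in which the second sample starts at index 3. With inputDim set to 3, such an array would be read as a batch of two samples.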
Saving the trainer also produces a .ckp (checkpoint) file. The out-of-process DLL writes these three files to the same folder as the Excel workbook.
References:
Sample code reference:
https://hrnjica.net/2017/11/26/how-to-save-cntk-model-to-file-in-c/