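CNTK currently ships only as an x64 build. If a spreadsheet calls a COM component directly, the x64 DLL is loaded in-process, which causes an error, for example when the spreadsheet application runs as a 32-bit process.
So I wrote an out-of-process COM component. Combined with VBA, it lets the spreadsheet call CNTK for AI training.
After debugging, it now works properly.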
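I used the training data from the official example (generated with Excel). To allow repeated training, the trainer's checkpoint file is saved after each training run, and the next run loads the trainer file and continues from there. Training can also be done in batches within a single run.
CNTK can save the trainer and load it again at the start of the next training run: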
var trainer = Trainer.CreateTrainer(classifierOutput, loss, evalError, parameterLearners);
// Number of samples seen in earlier training runs
int trainedNumb1 = 0;
// Running total of samples after the current run
int trainedNumb2 = 0;
// If a trainer checkpoint exists, restore it
if (File.Exists(trainer_filename.ToString()))
{
    // RestoreFromCheckpoint updates the trainer in place; its return value is not needed here
    trainer.RestoreFromCheckpoint(trainer_filename.ToString());
    trainedNumb1 = (int)trainer.TotalNumberOfSamplesSeen();
    result = $"Trainer Restore From Checkpoint:{trainer_filename.ToString()}" + "\r\n" + result;
}
trainedNumb2 = trainedNumb1 + features.Count;
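The snippet above covers only the restore side. The matching save step, run after each training pass, could look like the sketch below. It assumes the `Trainer.SaveCheckpoint(string)` and `Function.Save(string)` calls of the CNTK C# API, as used in the referenced post; `CheckpointHelper`, `SaveTrainingState`, and the parameter names are hypothetical and not taken from the original component.

```csharp
using CNTK;

// Hypothetical helper; the class and method names are illustrative only.
public static class CheckpointHelper
{
    // Persists the trainer state so that a later run can call
    // RestoreFromCheckpoint with the same trainerPath, and saves the
    // model so that it can be loaded on its own for evaluation.
    public static void SaveTrainingState(Trainer trainer, Function model,
                                         string trainerPath, string modelPath)
    {
        trainer.SaveCheckpoint(trainerPath);
        model.Save(modelPath);
    }
}
```

Calling something like `CheckpointHelper.SaveTrainingState(trainer, classifierOutput, trainer_filename.ToString(), model_filename + ".model")` at the end of a run would leave the files that the next run looks for.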
The model can also be saved, so that once training is finished it can be loaded directly for evaluation:
string modelname = model_filename + ".model";
if (File.Exists(modelname))
{
    // Load the saved model for evaluation on the CPU
    Function Model = Function.Load(modelname, DeviceDescriptor.CPUDevice, ModelFormat.CNTKv2);
    var featureVariable = Model.Arguments[0];
    var labelVariable = Model.Output;
    // Convert the input features to a matrix
    var matrix_features = Double2Matrix(o2d(features));
    // Input dimension (number of feature columns)
    int inputDim = matrix_features.ColumnCount;
    int numOutputClasses = labelVariable.Shape.TotalSize;
    // Flatten the matrix into the float array expected by Value.CreateBatch
    float[] feature_input_array = Matrix2Array(matrix_features);
    var featureinput = Value.CreateBatch<float>(new int[] { inputDim }, feature_input_array, DeviceDescriptor.CPUDevice);
    var inputDataMap = new Dictionary<Variable, Value>() { { featureVariable, featureinput } };
    // A null value asks CNTK to allocate the output buffer
    var outputDataMap = new Dictionary<Variable, Value>() { { Model.Output, null } };
    Model.Evaluate(inputDataMap, outputDataMap, DeviceDescriptor.CPUDevice);
    var outputValue = outputDataMap[Model.Output];
    IList<IList<float>> actualLabelSoftMax = outputValue.GetDenseData<float>(Model.Output);
    // Copy the per-sample outputs into a matrix (one row per sample)
    var resultMat = Matrix<double>.Build.Dense(actualLabelSoftMax.Count, actualLabelSoftMax[0].Count);
    for (int ii = 0; ii < actualLabelSoftMax.Count; ii++)
    {
        for (int jj = 0; jj < actualLabelSoftMax[ii].Count; jj++)
        {
            resultMat[ii, jj] = actualLabelSoftMax[ii][jj];
        }
    }
    return resultMat.ToArray();
}
else
{
    return "no model file found";
    //classifierOutput = CreateModel(featureVariable, labelsClasses, DeviceDescriptor.CPUDevice);
}
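The evaluation code above returns the raw output scores, one row per sample. If a single predicted class per sample is wanted instead, each row can be reduced to the index of its largest score. The sketch below is one possible way to do that; `SoftmaxHelper` and `ArgMaxPerRow` are hypothetical names, not part of CNTK or of the original component, and every row is assumed to be non-empty.

```csharp
using System.Collections.Generic;

// Hypothetical helper; not part of CNTK or the original component.
public static class SoftmaxHelper
{
    // For each row of scores, returns the index of the largest value.
    // Assumes every row contains at least one element.
    public static int[] ArgMaxPerRow(IList<IList<float>> rows)
    {
        var result = new int[rows.Count];
        for (int i = 0; i < rows.Count; i++)
        {
            int best = 0;
            for (int j = 1; j < rows[i].Count; j++)
            {
                if (rows[i][j] > rows[i][best])
                {
                    best = j;
                }
            }
            result[i] = best;
        }
        return result;
    }
}
```

Applied to `actualLabelSoftMax`, this would give one class index per input row, which is easier to write back into a single spreadsheet column than the full score matrix.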
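In addition, saving the trainer also produces a .ckp file (checkpoint). These three files (the model, the trainer, and the .ckp file) are written by the out-of-process DLL to the same folder as the Excel workbook.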
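Because the component runs out of process, it is safest not to rely on its working directory matching the workbook's. A minimal sketch of how the three paths might be built, assuming the workbook folder is passed in from VBA (for example as `ThisWorkbook.Path`) and that the checkpoint file is named after the trainer file with `.ckp` appended; `CheckpointFiles`, `ExpectedFiles`, and the file names in the usage line are hypothetical.

```csharp
using System.IO;

// Hypothetical helper; the names are illustrative only.
public static class CheckpointFiles
{
    // Returns the model path, the trainer path, and the trainer's
    // checkpoint path, all located inside workbookFolder.
    public static string[] ExpectedFiles(string workbookFolder,
                                         string modelFileName,
                                         string trainerFileName)
    {
        string trainerPath = Path.Combine(workbookFolder, trainerFileName);
        return new[]
        {
            Path.Combine(workbookFolder, modelFileName),
            trainerPath,
            trainerPath + ".ckp"
        };
    }
}
```

For example, `CheckpointFiles.ExpectedFiles(folder, "demo.model", "demo_trainer")` lists the files to check for before deciding whether to resume training or start from scratch.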
Reference (sample code):
https://hrnjica.net/2017/11/26/how-to-save-cntk-model-to-file-in-c/