对DeepLearning最初的印象是,大量的训练样本+机器学习,也就是说原来传统的机器学习会遇到的问题,不能解决的问题,换成DeepLearning同样解决不了。比如目标识别中因为光照变化,目标被遮挡,目标的几何变化造成的识别率大幅下降,在DeepLearning中同样也不能很好解决。但是不是说DeepLearning就一无事处,最近几年这么热也决不是因为名字取得好。DeepLearning比较明显的优势就是在特征选择上,想想之前做生物特征识别时,各种找特征,还得考虑什么光照不变,旋转不变,抗尺寸变换,抗遮挡,那叫一个累呀。现在可好啦,一个Convolution Layer,再配Fully Connected Layer,最后来个Softmax,丢一堆带标签的样本进去自动给你找出特征。当然这个只是一个接触DeepLearning不到一个月的小白的肤浅认识,大家听听就好。
本文算是最近1个月学习DeepLearning的入门小作业,选择的例子也是DeepLearning最流行的HelloWorld程序MNIST手写数字识别,采用Caffe2进行训练,并在Android端实现一个Demo样例。
DeepLearning基础
这里推荐下台大李宏毅老师的DeepLearning课程,讲解风趣幽默,生动详细,力荐。B站链接
深度学习框架:caffe2
模型训练
整个模型训练过程主要包括数据准备、模型建立、模型训练,参考caffe2官网的tutorial
1. 数据准备
数据准备在MachineLearning类的应用中起到致关重要的作用,相当于煮饭的时候用到的米。再利害的算法,如果没有足够的数据,那也是巧妇难为无米之炊,难道马云会称现在是DT时代。另外DeepLearning作为MachineLearning的一个分支,目前了解到的大部分的DeepLearning算法更多的还是属于SuperviseLearning。SuperviseLearning一个明显的特征是非常依赖数据,而且是人工标注的数据,这也难怪一些DeepLearning大大们说在AI应用中,有多少人工就有多少智能。本文用到的数据链接地址:MNIST手写数字数据集
2. 模型建立
这个例子采用的是DeepLearning中的经典网络LeNet,关于LeNet可以参考这篇文章。
3. 模型训练
这里模型训练采用caffe2框架,
模型在Android端的部署
参考caffe2官网给的AICamera例子,建立Android Studio工程(github工程地址:https://github.com/lyapple2008/MNIST_CNN_APP ),其中最主要的代码如下所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
| void loadToNetDef(AAssetManager *mgr, caffe2::NetDef *net, const char *filename) { AAsset *asset = AAssetManager_open(mgr, filename, AASSET_MODE_BUFFER); assert(asset != nullptr); const void *data = AAsset_getBuffer(asset); assert(data != nullptr); off_t len = AAsset_getLength(asset); assert(len != 0); if (!net->ParseFromArray(data, len)) { alog("Couldn't parse net from data.\n"); } AAsset_close(asset); }
extern "C" void Java_com_example_beyoung_handwrittendigit_MainActivity_initCaffe2( JNIEnv *env, jobject, jobject assetManager) { AAssetManager *mgr = AAssetManager_fromJava(env, assetManager); alog("Attempting to load protobuf netdefs..."); loadToNetDef(mgr, &_initNet, "mnist/init_net.pb"); loadToNetDef(mgr, &_predictNet, "mnist/predict_net.pb"); alog("done."); alog("Instantiating predictor..."); _predictor = new caffe2::Predictor(_initNet, _predictNet); if (_predictor) { alog("done..."); } else { alog("fail to instantiat predictor..."); } }
extern "C" JNIEXPORT jstring JNICALL Java_com_example_beyoung_handwrittendigit_MainActivity_recognitionFromCaffe2( JNIEnv *env, jobject, jint h, jint w, jintArray data) { if (!_predictor) { return env->NewStringUTF("Loading..."); }
jsize len = env->GetArrayLength(data); jint *img_data = env->GetIntArrayElements(data, 0); jint img_size = h * w; assert(img_size <= INPUT_DATA_SIZE);
for (auto i = 0; i < h; ++i) { std::ostringstream stringStream; for (auto j = 0; j < w; ++j) { int color = img_data[i * w + j]; float grey = 0.0; if (color != 0) { grey = 1.0; } input_data[i * w + j] = grey; if (color != 0) { color = 1; } stringStream << color << " "; } alog("%s", stringStream.str().c_str()); }
caffe2::TensorCPU input; input.Resize(std::vector<int>({1, IMG_C, IMG_H, IMG_W})); memcpy(input.mutable_data<float>(), input_data, INPUT_DATA_SIZE * sizeof(float)); caffe2::Predictor::TensorVector input_vec{&input}; caffe2::Predictor::TensorVector output_vec; _predictor->run(input_vec, &output_vec);
constexpr int k = 3; float max[k] = {0}; int max_index[k] = {0}; if (output_vec.capacity() > 0) { for (auto output : output_vec) { for (auto i = 0; i < output->size(); ++i) { for (auto j = 0; j < k; ++j) { if (output->template data<float>()[i] > max[j]) { for (auto _j = k - 1; _j > j; --_j) { max[_j - 1] = max[_j]; max_index[_j - 1] = max_index[_j]; } max[j] = output->template data<float>()[i]; max_index[j] = i; goto skip; } } skip:; } } }
std::ostringstream stringStream; for (auto j = 0; j < k; ++j) { stringStream << max_index[j] << ": " << max[j]*100 << "%\n"; }
return env->NewStringUTF(stringStream.str().c_str()); }
|
总结
通过上面的Demo可以看出,通过MNIST数据训练出来的模型在实际运行的准确率还是很堪忧的。所以一个算法从实验室数据到实际应用还有很长的路要走,虽然最近应用于各个领域的深度学习模型层出不穷,测试数据也很好看,但是在实际应用过程中还有很多路要走。虽然DeepLearning已经表现出很强大的黑魔法属性,在实际应用过程中还是有很多工作要做,不然只能停留在Demo阶段。以本文的手写数字识别为例,实际过程的准确率与测试集上的准确率相差甚远,这时候就需要进行大量的优化工作。由于学习深度学习没多久,暂时只能根据以往在机器学习上的经验来进行优化,目前能想到的优化方向有:训练集与实际运行环境要一致、准备更多的训练集、深度另外的模型方法。