Android端实现手写数字识别

  对DeepLearning最初的印象是,大量的训练样本+机器学习,也就是说原来传统的机器学习会遇到的问题,不能解决的问题,换成DeepLearning同样解决不了。比如目标识别中因为光照变化,目标被遮挡,目标的几何变化造成的识别率大幅下降,在DeepLearning中同样也不能很好解决。但是不是说DeepLearning就一无事处,最近几年这么热也决不是因为名字取得好。DeepLearning比较明显的优势就是在特征选择上,想想之前做生物特征识别时,各种找特征,还得考虑什么光照不变,旋转不变,抗尺寸变换,抗遮挡,那叫一个累呀。现在可好啦,一个Convolution Layer,再配Fully Connected Layer,最后来个Softmax,丢一堆带标签的样本进去自动给你找出特征。当然这个只是一个接触DeepLearning不到一个月的小白的肤浅认识,大家听听就好。

  本文算是最近1个月学习DeepLearning的入门小作业,选择的例子也是DeepLearning最流行的HelloWorld程序MNIST手写数字识别,采用Caffe2进行训练,并在Android端实现一个Demo样例。

DeepLearning基础

  这里推荐下台大李宏毅老师的DeepLearning课程,讲解风趣幽默,生动详细,力荐。B站链接
深度学习框架:caffe2

模型训练

整个模型训练过程主要包括数据准备、模型建立、模型训练,参考caffe2官网的tutorial

1. 数据准备

数据准备在MachineLearning类的应用中起到致关重要的作用,相当于煮饭的时候用到的米。再利害的算法,如果没有足够的数据,那也是巧妇难为无米之炊,难道马云会称现在是DT时代。另外DeepLearning作为MachineLearning的一个分支,目前了解到的大部分的DeepLearning算法更多的还是属于SuperviseLearning。SuperviseLearning一个明显的特征是非常依赖数据,而且是人工标注的数据,这也难怪一些DeepLearning大大们说在AI应用中,有多少人工就有多少智能。本文用到的数据链接地址:MNIST手写数字数据集

2. 模型建立

这个例子采用的是DeepLearning中的经典网络LeNet,关于LeNet可以参考这篇文章

3. 模型训练

这里模型训练采用caffe2框架,

模型在Android端的部署

参考caffe2官网给的AICamera例子,建立Android Studio工程(github工程地址:https://github.com/lyapple2008/MNIST_CNN_APP ),其中最主要的代码如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
void loadToNetDef(AAssetManager *mgr, caffe2::NetDef *net, const char *filename) {
AAsset *asset = AAssetManager_open(mgr, filename, AASSET_MODE_BUFFER);
assert(asset != nullptr);
const void *data = AAsset_getBuffer(asset);
assert(data != nullptr);
off_t len = AAsset_getLength(asset);
assert(len != 0);
if (!net->ParseFromArray(data, len)) {
alog("Couldn't parse net from data.\n");
}
AAsset_close(asset);
}

extern "C"
void Java_com_example_beyoung_handwrittendigit_MainActivity_initCaffe2(
JNIEnv *env,
jobject,
jobject assetManager) {
AAssetManager *mgr = AAssetManager_fromJava(env, assetManager);
alog("Attempting to load protobuf netdefs...");
loadToNetDef(mgr, &_initNet, "mnist/init_net.pb");
loadToNetDef(mgr, &_predictNet, "mnist/predict_net.pb");
alog("done.");
alog("Instantiating predictor...");
_predictor = new caffe2::Predictor(_initNet, _predictNet);
if (_predictor) {
alog("done...");
} else {
alog("fail to instantiat predictor...");
}
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_example_beyoung_handwrittendigit_MainActivity_recognitionFromCaffe2(
JNIEnv *env,
jobject,
jint h, jint w, jintArray data) {
if (!_predictor) {
return env->NewStringUTF("Loading...");
}

jsize len = env->GetArrayLength(data);
jint *img_data = env->GetIntArrayElements(data, 0);
jint img_size = h * w;
assert(img_size <= INPUT_DATA_SIZE);

// convert rgb image to grey image and normalize to 0~1
for (auto i = 0; i < h; ++i) {
std::ostringstream stringStream;
for (auto j = 0; j < w; ++j) {
int color = img_data[i * w + j];
//int red = ((color & 0x00FF0000) >> 16);
//int green = ((color & 0x0000FF00) >> 8);
//int blue = color & 0x000000FF;
//float grey = red * 0.3 + green * 0.59 + blue * 0.11;
float grey = 0.0;
if (color != 0) {
grey = 1.0;
}
input_data[i * w + j] = grey;
//alog("%f", grey);
//alog("%d", color);
if (color != 0) {
color = 1;
}
stringStream << color << " ";
}
//alog("\n");
alog("%s", stringStream.str().c_str());
}

caffe2::TensorCPU input;
input.Resize(std::vector<int>({1, IMG_C, IMG_H, IMG_W}));
memcpy(input.mutable_data<float>(), input_data, INPUT_DATA_SIZE * sizeof(float));
caffe2::Predictor::TensorVector input_vec{&input};
caffe2::Predictor::TensorVector output_vec;
_predictor->run(input_vec, &output_vec);

constexpr int k = 3;
float max[k] = {0};
int max_index[k] = {0};
// Find the top-k result manually
if (output_vec.capacity() > 0) {
for (auto output : output_vec) {
for (auto i = 0; i < output->size(); ++i) {
for (auto j = 0; j < k; ++j) {
if (output->template data<float>()[i] > max[j]) {
for (auto _j = k - 1; _j > j; --_j) {
max[_j - 1] = max[_j];
max_index[_j - 1] = max_index[_j];
}
max[j] = output->template data<float>()[i];
max_index[j] = i;
goto skip;
}
}
skip:;
}
}
}

std::ostringstream stringStream;
for (auto j = 0; j < k; ++j) {
stringStream << max_index[j] << ": " << max[j]*100 << "%\n";
}

// if (output_vec.capacity() > 0) {
// for (auto output: output_vec) {
// for (auto i = 0;i<output->size();++i) {
// stringStream << output->template data<float>()[i] << "\n";
// }
// }
// }

return env->NewStringUTF(stringStream.str().c_str());
}

总结

  通过上面的Demo可以看出,通过MNIST数据训练出来的模型在实际运行的准确率还是很堪忧的。所以一个算法从实验室数据到实际应用还有很长的路要走,虽然最近应用于各个领域的深度学习模型层出不穷,测试数据也很好看,但是在实际应用过程中还有很多路要走。虽然DeepLearning已经表现出很强大的黑魔法属性,在实际应用过程中还是有很多工作要做,不然只能停留在Demo阶段。以本文的手写数字识别为例,实际过程的准确率与测试集上的准确率相差甚远,这时候就需要进行大量的优化工作。由于学习深度学习没多久,暂时只能根据以往在机器学习上的经验来进行优化,目前能想到的优化方向有:训练集与实际运行环境要一致、准备更多的训练集、深度另外的模型方法。