基于 keras-js 快速实现浏览器内的 CNN 手写数字识别

3,929 阅读12分钟

在这篇文章中,我会快速地介绍如何使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。

DEMO 地址:starkwang.github.io/keras-js-de…


零、准备工作

首先需要给你的电脑安装 keras,具体安装的步骤请参考 keras 官方文档


一、快速入门

首先十分推荐阅读 tensorflow 官方文档中的 MNIST For ML Beginners,这里是极客学院的中文翻译

MNIST 是一个很流行的入门级机器学习/计算机视觉数据集,它包含 0 - 9 的各种手写数字图片:

每张图片的尺寸均为 28 * 28,用一个 28 * 28 的二维数组来表示,换句话说,每张图片都是由 784 个像素点组成,每个像素点的值在 0 - 255 之间。

比如下面就是一个 "3" 的数据:

000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000

使用 keras,可以很方便地导入 MNIST 数据集:

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

总体来说,我们的想要得到的网络模型,是有一个固定的输入输出的:

  • 输入为一个 28 * 28 的二维整数数组
  • 输出是一个长度为 10 的数组,依次表示 0-9 的可能性(例如如果有一张图片 80% 概率为 1, 20% 概率为 7的话,那么这个数组就是 [0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0]

二、使用 keras 训练网络

我们想要训练的模型,由以下几层网络组成:

  1. 32 个 3x3 卷积核的卷积层
  2. 64 个 3x3 卷积核的卷积层
  3. 采样因子为 (2, 2) 的池化层
  4. Dropout 层
  5. Flatten 层
  6. ReLu 全连接层
  7. Dropout 层
  8. Softmax 全连接层

用 keras 训练一个识别 MNIST 的 CNN 网络非常方便,下面是一个官方给出的例子(源码在此):

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Save model
model.save('myMnistCNN.h5')

如果已经安装好了 keras,直接运行即可:

python mnist_cnn.py

三、转换输出模型

获得训练好的 .h5 文件之后,模型还不能直接使用,因为我们需要对它进行转编码,keras-js 提供了一个 python 脚本来自动执行:

python ./python/encoder.py -q myMnistCNN.h5

这个脚本会把 .h5 文件转编码为 keras-js 可读的格式,里面包含了训练好的神经网络的所有模型和参数。

四、使用 keras-js 导入模型

首先需要引入 keras-js,可以通过 script 标签直接引入:

<script src="https://unpkg.com/keras-js"></script>

也可以通过 npm 安装后使用 webpack 构建引入,参考这里

接下来就可以直接创建一个 Model,keras-js 会自动加载对应的 bin 文件:

const model = new KerasJS.Model({
    filepath: '/path/to/mnist_cnn.bin',
    gpu: true,
    transferLayerOutputs: true
})

初始化完毕之后,就可以用于 MNIST 识别了,输入是一个长度为 784 的数组(包含 28*28 各个像素点的灰度值),输出是一个长度为 10 的数组(0-9的概率):

(可以使用上文中给的那个 "3" 的数据范例)

model
  .ready()
  .then(() => {
    // data 是一个长度为 784 的数组,每一项都介于 0 - 255 之间
    // 这里我们需要把数组转换为 Float32 类型
    const inputData = new Float32Array(data)
    // 识别
    return model.predict(inputData)
  })
  .then(outputData => {
    // 输出为 0-9 的概率,例如:
    // { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
  })
  .catch(err => {
    // ...
  })

五、Canvas 实现一个手写板

最后一步就是实现一个手写板,具体的代码就不放上来了,主要就是通过 mousedownmousemovemouseup 事件来绘制图形。

绘制完毕之后,调用 ctx.getImageData,就可以得到 canvas 内的像素数据,每个像素对应四个数值,依次是每个点的 rgba 值,处理之后就可以得到长度为 784 的灰度数组了。然后使用上文提到的 model.predict 即可。