9

Keras 教学 – Keras MNIST 手写辨识 x 深度学习的 HelloWorld

 3 years ago
source link: https://flashgene.com/archives/164373.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

Keras 实现 MNIST 手写辨识

每一个程式语言都有 HelloWorld,深度学习领域中最经典的 Demo 就是 MNIST 手写辨识,MNIST 资料即是由 28×28 灰阶图片,分别有 0~9 分布 60,000 张训练资料与 10,000 张测试资料 。目前各大演算法对于 MNIST 的準确率已经提昇许多,我们也来用 Keras 来测试看看啰。

Python Keras MNIST 手写辨识

这是一个神经网路的範例,利用了 Python Keras 来训练一个手写辨识分类 Model。

我们要的问题是将手写数字的灰度图像(28×28 Pixel)分类为 10 类(0至9)。使用的数据集是 MNIST 经典数据集,它是由国家标準技术研究所(MNIST 的 NIST)在1980年代组装而成的,包含 60,000 张训练图像和 10,000 张测试图像。您可以将「解决」MNIST 视为深度学习的 “Hello World”。

由于 Keras 已经整理了一些经典的 Play Book Data,因此我们可以很快透过以下方式取得 MNIST 资料集。

from keras.datasets import mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

images 是用来训练与测试的资料,label 则为每一笔影像资料对应的正确答案,每一张手写图片都是 28 x 28 的灰阶 Bit Map,透过以下 Python 来看一下资料集的结构

train_images.shape
len(train_labels)
train_labels
test_images.shape
len(test_labels)
test_labels

输出

建立準备训练的神经网路

开始训练神经网路以前,需要先建构网路,然后才开始训练,如下:

from keras import models
from keras import layers
 
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))

上面这里是神经网路的核心组成方式,我们在全连接层建立了两层,由一个有 512 个神经元的网路架构连接到 10 个神经元的输出层。输出层採用 softmax 表示数字 0~9 的机率分配,这 10 个数字的总和将会是 1。以下将我们建立的网路进行 compile,这里详细的参数以后会介绍。

network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

上面的参数其实就直接对应到上一篇文章介绍的类神经网路基础架构,可以这样理解:

以下将资料正规划成为 0~1 的数值,变成 60000, 28×28 Shape 好送进上面定义的网路输入层。

fix_train_images = train_images.reshape((60000, 28 * 28)).astype('float32') / 255
fix_test_images = test_images.reshape((10000, 28 * 28)).astype('float32') / 255

由于我们使用的 categorical_crossentropy 损失函数,因此将标记资料进行格式转换。如下:

from keras.utils import to_categorical
 
fix_train_labels = to_categorical(train_labels)
fix_test_labels = to_categorical(test_labels)

开始训练 MNIST 类神经网路

进行训练模型,预计训练的正确率应该会在 0.989 左右

result = network.fit(
    fix_train_images,
    fix_train_labels,
    epochs=20,
    batch_size=128,
    validation_data=(fix_test_images, fix_test_labels))

对应神经网路模型概念如下:

执行结果如下:

将训练后的模型输入测试资料进行评比,一般说这样的正确率应该会在 0.977% 左右

test_loss, test_acc = network.evaluate(fix_test_images, fix_test_labels)
print('test_loss:', test_loss)
print('test_acc:', test_acc)

执行结果如下:

为什麽训练时的正确率会高于验证测试呢?在这样数据中,由于模型训练时对训练资料造成些微的过度拟合 (Over Fitting) 。一般来说这样的情况是正常的,未来我们可以透过参数的调整或其他方法提高正确性。

透过 Keras 图表协助分析训练过程

由于训练 Model 时会进行好几次的 Epoch,每一次 Epoch 都是对训练资料集进行一轮完整的训练,妥善观察每一次 Epoch 的数据是很重要地。我们可以透过 matplotlib 函式库绘製图表,帮我们进行分析。

以下方式可以绘製训练过程 Loss Function 对应的损失分数。Validation loss 不一定会跟随 Training loss 一起降低,当 Model Over Fitting Train Data 时,就会发生 Validation loss 上升的情况。

history_dict = result.history
 
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs = range(1, len(loss_values) + 1)
 
import matplotlib.pyplot as plt
plt.plot(epochs, loss_values, 'bo', label='Training loss')
plt.plot(epochs, val_loss_values, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
 
plt.show()

执行后的图表如下:

可以看到差不多在第八次 Epochs 就得到最佳解。此外,我们也可以透过以下程式可以绘製训练过程的正确率变化。训练的过程中,当 Accuracy 后期并有没太大的变化,表示 Model 很快就在假设空间里进行不错的收敛。

plt.clf()
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
 
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
 
plt.show()

结论我们可以看到「算越久不一定越好」一旦 Over Fitting 之后就会得到反效果,市面上有很多 AI Cloud 都会在 Training 的过程把每一个 Epoch 的 Model 都保留储存,好让使用者可以随意取出小果最好的 Model,如果想要这样做,透过 Keras Check Point Callback 也是可以达到的。

今天介绍的 MNIST 训练与相关程式码,可以在 GitHub 取得 ipynb 档案 ,透过免费的 Google Colab 就可以执行测试啰,需要的请自行下载,记得按讚分享加关注喔。Keras Machine Learning 未完待续…….


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK