首页 理论教育MNIST手写数字识别的优化方法

MNIST手写数字识别的优化方法

【摘要】:MNIST数据集是由Yann LeCun所搜集的,MNIST数字文字识别数据集数据量不会太多,而且是单色的图像,比较简单,很适合深度学习的初学者用来练习模型、训练、预测。MNIST数据集共有训练数据60 000项、测试数据10 000项。图8-9手写数字识别数据预处理①导入所需模块。x_train=x_train_image.reshape.astypex_test=x_test_image.reshape.astype④将feature标准化,标准化可以提高模型预测的准确度。x_train_normalize=x_train/255x_test_normalize=x_test/255⑤label以one-hot endoding进行转换。y_trainOneHot=np_utils.to_categoricaly_testOneHot=np_utils.to_categorical建立模型①建立Sequential模型。model.add进行训练①定义训练方式。

MNIST数据集是由Yann LeCun所搜集的,MNIST数字文字识别数据集数据量不会太多,而且是单色的图像,比较简单,很适合深度学习的初学者用来练习模型、训练、预测。

MNIST数据集共有训练数据60 000项、测试数据10 000项。MNIST数据集的每一项数据都由图像(images)与真实的数字(labels)所组成,如图8-9所示。

图8-9 手写数字识别

(1)数据预处理

①导入所需模块。

from keras.utils import np_utils

from keras.datasets import mnist

import matplotlib.pyplot as plt

from keras.models import Sequential

from keras.layers import Dense

from keras.layers import Dropout

import pandas as pd

②读取MNIST数据。

(x_train_image,y_train_label),(x_test_image,y_test_label)=mnist.load_data()

③将数字图像特征值(feature)使用reshape转换,将原本的28×28的数字图像转换成784个float数字。

x_train=x_train_image.reshape(60000,784).astype('float32')

x_test=x_test_image.reshape(10000,784).astype('float32')

④将feature标准化,标准化可以提高模型预测的准确度。

x_train_normalize=x_train/255

x_test_normalize=x_test/255

⑤label以one-hot endoding进行转换。

y_trainOneHot=np_utils.to_categorical(y_train_label)

y_testOneHot=np_utils.to_categorical(y_test_label)

(2)建立模型

①建立Sequential模型。

model=Sequential()

②建立“输入层”和“隐藏层”。

model.add(Dense(units=256,input_dim=784,kernel_initializer='normal',activation='relu'))

③建立“输出层”。

model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))

(3)进行训练

①定义训练方式。

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

②开始训练。

train_history=model.fit(x=x_train_normalize,y=y_trainOneHot,validation_split=0.2,epochs=10,batch_size=200,verbose=2)

(4)以测试数据评估模型准确率

scores=model.evaluate(x_test_normalize,y_testOneHot)

print('accuracy=',scores[1])

(5)进行预测

prediction=model.predict_classes(x_test_normalize)

预测结果会放在prediction中,可以与y_test_label进行比对,查看预测效果。