Keras(Tensorflow)でMNIST

KerasでMNIST手書き文字分類問題を試した。

実行環境

Python 3.6.0
Keras 1.2.2
Tensorflow 1.0.1
GeForce 780Ti

コード

ライブラリのインポート
import numpy as np
np.random.seed(123)

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
MNISTデータセットのロード
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

データセットの画像を表示

import matplotlib.pyplot as plt
plt.imshow(X_train[1])

f:id:yusuke_ujitoko:20170319194818p:plain

データを事前処理
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

X_train /= 255
X_test /= 255

ラベルのデータ構造変更

Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)
ネットワークを定義
model = Sequential()
model.add(Convolution2D(32, 3, 3, activation='relu', input_shape = (28, 28, 1)))
model.add(Convolution2D(32, 3, 3, activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

誤差関数、勾配法を定義する。

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
訓練
history = model.fit(X_train, Y_train, batch_size=32, nb_epoch=10, verbose=1)
評価
score = model.evaluate(X_test, Y_test, verbose=0)
print(score)

Accuracyの可視化

fig = plt.figure()
plt.plot(history.history['acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.xlim(0, 10)
plt.savefig('acc.png')

f:id:yusuke_ujitoko:20170319214321p:plain

Lossの可視化

fig = plt.figure()
plt.plot(history.history['loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.xlim(0, 10)
plt.savefig('loss.png')

f:id:yusuke_ujitoko:20170319214313p:plain

モデルの可視化

可視化のための事前準備として、graphvizとpydot_ngをインストール

apt-get install graphviz
pip install pydot_ng
from keras.utils.visualize_util import plot
plot(model, to_file='model.png')

f:id:yusuke_ujitoko:20170319202758p:plain

参考