LeNetは1998のLeCunによって発表されたネットワーク。
MNISTに対して、このLeNetの類似ネットワークを適用した時の、パラメータを可視化してみるというのが本記事の主旨。
今回使ったネットワーク構成を示す。
畳み込み層2つに、全結合層が2つ連なったもの。
畳み込み層の直後にMaxプーリング層がある。
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_1 (Conv2D) (None, 28, 28, 20) 520 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 14, 14, 20) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 10, 10, 50) 25050 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 5, 5, 50) 0 _________________________________________________________________ flatten_1 (Flatten) (None, 1250) 0 _________________________________________________________________ dense_1 (Dense) (None, 500) 625500 _________________________________________________________________ dense_2 (Dense) (None, 10) 5010 ================================================================= Total params: 656,080.0 Trainable params: 656,080.0 Non-trainable params: 0.0 _________________________________________________________________ Test loss: 0.0360684088513 Test accuracy: 0.9908
今回は2つの畳込み層のパラメータをそれぞれ可視化する。
第一層の畳み込み層は、カーネルサイズは5×5で、カーネルは全部で20個。
第三層の畳み込み層は、カーネルサイズは5×5で、カーネルは全部で50個。
可視化
予め訓練したモデルとモデルのパラメータをロードしておく。
そして可視化する。
%matplotlib inline import matplotlib.pyplot as plt from sklearn.preprocessing import MinMaxScaler import numpy as np def visualize_filter(model, nb_layer, nb_visualize_kernels): W = model.layers[nb_layer].get_weights()[0] print(W.shape) nb_filters = W.shape[3] nb_kernel = W.shape[0] for i in range(nb_filters): if i >= nb_visualize_kernels: continue im = W[:, :, :, i] im = im[:,:,0] scaler = MinMaxScaler(feature_range=(0,255)) im = scaler.fit_transform(im) plt.subplot(4, 5, i+1) plt.axis('off') plt.imshow(im, cmap='gray') plt.show() visualize_filter(model, 0, 20) visualize_filter(model, 2, 20)
第一層目の重みパラメータは以下のように可視化された。
一方、第三層目の重みパラメータは以下のように可視化された。
第一層目の数に合わせて、先頭の20個のカーネルのみ抽出してある。
この2つを比較しても、あまり差が無いように見える。
若干第三層目の方が、複雑なパラメータ配置となっているか…?