LeNetのパラメータの可視化

LeNetは1998のLeCunによって発表されたネットワーク。
MNISTに対して、このLeNetの類似ネットワークを適用した時の、パラメータを可視化してみるというのが本記事の主旨。

今回使ったネットワーク構成を示す。
畳み込み層2つに、全結合層が2つ連なったもの。
畳み込み層の直後にMaxプーリング層がある。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 20)        520       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 20)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 10, 10, 50)        25050     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 5, 5, 50)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1250)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 500)               625500    
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5010      
=================================================================
Total params: 656,080.0
Trainable params: 656,080.0
Non-trainable params: 0.0
_________________________________________________________________
Test loss: 0.0360684088513
Test accuracy: 0.9908

今回は2つの畳込み層のパラメータをそれぞれ可視化する。
第一層の畳み込み層は、カーネルサイズは5×5で、カーネルは全部で20個。
第三層の畳み込み層は、カーネルサイズは5×5で、カーネルは全部で50個。

可視化

予め訓練したモデルとモデルのパラメータをロードしておく。
そして可視化する。

%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import numpy as np

def visualize_filter(model, nb_layer, nb_visualize_kernels):
    W = model.layers[nb_layer].get_weights()[0]
    print(W.shape)
    
    nb_filters = W.shape[3]
    nb_kernel = W.shape[0]
    
    for i in range(nb_filters):
        if i >= nb_visualize_kernels:
            continue
        im = W[:, :, :, i]
        im = im[:,:,0]
        scaler = MinMaxScaler(feature_range=(0,255))
        im = scaler.fit_transform(im)
        plt.subplot(4, 5, i+1)
        plt.axis('off')
        plt.imshow(im, cmap='gray')
    plt.show()

visualize_filter(model, 0, 20)
visualize_filter(model, 2, 20)

第一層目の重みパラメータは以下のように可視化された。

f:id:yusuke_ujitoko:20170320180015p:plain

一方、第三層目の重みパラメータは以下のように可視化された。
第一層目の数に合わせて、先頭の20個のカーネルのみ抽出してある。

f:id:yusuke_ujitoko:20170320180025p:plain

この2つを比較しても、あまり差が無いように見える。
若干第三層目の方が、複雑なパラメータ配置となっているか…?