- Wasserstein GAN,略してWGANと呼ばれる
- Arxiv: https://arxiv.org/abs/1611.02163
- 著者によるコード: https://github.com/martinarjovsky/WassersteinGAN
このWGANを使って以下のようないらすとや画像を生成したというのが本記事の主旨。
本論文のcontribution
- 確率分布間の距離を計算するための,従来使われてきた距離指標とEarth Mover(EM) distanceとの比較を行った
- EM distanceの近似関数を最小化するWasserstein-GANを定義した
- WGANが既存のGANにおける問題を解決することを示した
- WGANではdiscriminatorとgeneratorの学習バランスを気にしなくて良い
- ネットワーク構造も気にしなくて良い
- mode dropping phenomenonも殆ど無い
- discriminatorを学習させることでEM distanceを連続的に推定できる
- learning curveが生成画像の質の指標とみなせる.
既存のGANにおける問題
既存のGANでは,訓練データの分布と生成データの分布のf-divergenceを最小化することを目的としている.
f-divergenceは様々提案されているが、Jensen–Shannon divergenceに近いものが多い.
f-divergenceは密度比の関数になっているが、
問題が起きる場面がある。
例えば問題が起きる場面としてでも2つの分布のsupportがoverlapしていない状況を考える.
overlapしていないところでは,この密度比は無限に大きくなるか,ゼロになってしまう.
また,supportがdisjointな場合,密度比が定数なので,f-divergenceは定数となる.
簡単な例を下に示す.
- 訓練データ群がの点にあるとする
このときである.
つまり訓練データはかつまでの線分上にある. - モデルはパラメータを持っており,点を生成する.
- このとき分布は完全に一致するか,まったくoverlapしないかのどちらかとなる.
指標として
- Total Variation (TV) distance
- Kullback-Leibler (KL) divergence
- Jensen-Shannon (JS) divergence
- Earth-Mover (EM) distance or Wasserstein-1
の4つが論文で挙げられている.
そのうちEM distance(左)とJS divergence(右)を下に示す.
右のJS divergenceは,多くの場所で平らになっている.
これは多くの場所で勾配が0となることを示している.
もしGANでDiscriminatorがJS divergenceを近似するほど学習していた場合には,Generatorは勾配情報をほとんど受け取れないことに(ほぼゼロ)となってしまう.
これはGANにおける勾配消失問題の一例.
そこで今度は左のEM distanceを見てみる.
EM distanceはf-divergenceと違い,密度比の関数ではない.
EM distanceは直観的には,
分布を分布に変化させるために,地点からへ移動させるべき質量を示している.
つまり最適輸送問題におけるコストに相当するもの.
分布のsupportがoverlapしない場合には,EM distanceはどのくらい離れているかを表現できる.
上の左図を見てみると,EM distanceでは最適なパラメータへの勾配を得ることができるのがわかる.
(に収束できる)
EM distanceの定義は以下のようになっている.
$$
W(P_{r}, P_{g}) = inf_{\gamma \in \Pi (P_{r}, P_{g})} \mathbb{E}_{(x,y) \sim \gamma} \mid \mid x - y \mid \mid
$$
(infはminimumと考える)
2つの分布における全ての二点間のセットを考え,そのセットにおける二点間の距離の平均を取って,一番小さいセットにおける平均距離がEM distanceとなる.
ただし、このEM distanceはintractableで直接計算できない.
そこで次のような双対問題があるので上の定義の代わりに利用する. $$K \cdot W(P_{r},P_{g}) = \sup_{||f||_{L} \le K} (\mathbb{E}_{x \sim P_{r}}[f(x)] - \mathbb{E}_{x \sim P_{g}}[f(x)])$$
これはK-リプシッツ写像である.
これは直観的には,訓練データの平均と生成データの平均間の最大のマージンを得る関数となっている.
WGANの効果
- 学習が安定する
- 色んなアーキテクチャで上手くいく.例えば多層パーセプトロンでもOK
- $$ \mathbb{E}_{x \sim P_{r}}[f(x)] - \mathbb{E}_{x \sim P_{g}}[f(x)] $$ のプロットが収束するので,いつGANの学習を終えればいいか判断できる.
- mode collapse問題を避けられる
Wasserstein GANの実装
論文に記載のアルゴリズムは以下のようなものになっている
上のアルゴリズムを見ると分かるように、オリジナルのGANと極めて似ている。
実装ではGANのコードに以下変更を加えるだけで済む。
- 誤差関数からlogをなくす。
- Discriminatorの出力は確率ではなくなるので、Discriminatorの出力に活性化関数を施さない。
- (K-リプシッツ連続性を保つためにDiscriminatorの重みをクリップする
- GeneratorよりDiscriminatorをよく学習させる
- OptimizerとしてADAMではなくRMSPropを使う
- 学習率は小さい値を使う。論文では0.00005を使っている。
いらすとや画像の生成
ネットワーク構造は割と何でもよいようなので、
DCGANを試した際の構造をそのまま使った。
その時はkerasで実装していたので、tensorflowに移植した上で修正を施した。
またデータセットについては、 DCGAN, UnrolledGANのときはいらすとやから取得できる全画像を使っていたが,今回は人間の顔画像のみをデータセットとして用いる. 具体的には64x64x3のサイズの画像1814点を用いた.
ミニバッチサイズは64.
GTX780Tiで7時間ほどかけて13万epoch分計算した。
0 epoch
500 epoch
1000 epoch
2000 epoch
13000 epoch
WGANによる生成画像アニメーション
0〜7000 epochまでの生成画像のgifアニメがこちら。
下のDCGANの生成画像と比較すると
- 振動が抑えられている
- ぼやけている
のがわかる。
(比較用)DCGANによる生成画像アニメーション