GAN:敵対的生成ネットワーク

深層生成モデル2

Deep
Sampling
Author

司馬博文

Published

2/11/2024

Modified

2/13/2024

概要
数学者のために,深層生成モデルの先駆けである GAN を概観する.

1 GAN (Goodfellow et al., 2014)

Samples from a GAN Taken from (Goodfellow et al., 2014, p. 7) Figure 2

1.1 導入

GAN 以前の深層生成モデルは,学習の難しさから,データ生成分布にパラメトリックな仮定をおき,その中で 最尤推定 を行うことが一般的であった.深層 Boltzmann マシン (Salakhutdinov and Hinton, 2009) もその例である.

複雑なモデルで尤度を解析的に計算することは困難である.そのために,MCMC によるサンプリングによりこれを回避することを考え,その Markov 連鎖の遷移核を学習するという生成確率的ネットワーク (GSN: Generative Stochastic Network) などのアプローチ (Bengio et al., 2014) も提案されていた.

GAN (Generative Adversarial Network) は,このような中で (Goodfellow et al., 2014) によって提案された深層生成モデルである.GAN も尤度の評価を必要としないが,MCMC などのサンプリング手法も用いず,ただ誤差逆伝播法のみによって学習が可能である.

同時の深層学習は,ImageNet コンペティションにおいて大成功を収めた AlexNet (Krizhevsky et al., 2012) など,主に識別のタスクにおいて大きな成功を収めていたが,生成モデルにおいては芳しくなかった.

主な障壁は

  1. 分布の近似が難しいこと
  2. 区分的線型な活性化関数を用いても勾配を通じた学習が難しいこと

の2点であったが,GAN はこの2つの問題を回避すべく提案された.

生成モデル \(G\) に対して,判別モデル \(D\) を対置し,加えて \((G,D)\) をセットで誤差逆伝播法とドロップアウト法 (Hinton et al., 2012)(当時深層識別モデルを最も成功させていた学習法)により学習可能にしたのである.

1.2 枠組み

データの空間を \(x\in\mathcal{X}\) とし,潜在変数の値域 \(\mathcal{Z}\) とその上の確率測度 \(P_z\in\mathcal{P}(\mathcal{Z})\),そして深層ニューラルネットワークのパラメータ空間 \(\Theta_g\) を用意して,生成モデルを写像 \(G:\mathcal{Z}\times\Theta_g\to\mathcal{X}\) とする.

生成モデル \(G\) は押し出しによりモデル \(\{G(-,\theta_g)_*P_z\}_{\theta_g\in\Theta_g}\) を定める.

このモデルの密度(尤度)の評価を回避するために,これに判別モデル \(D\) を対置する.これは,パラメータ \(\theta_d\in\Theta_d\) を通じて学習される写像 \(D:\mathcal{X}\times\Theta_d\to[0,1]\) とし,あるデータ \(x\in\mathcal{X}\) を観測した際に,これが \(G\) から生成されたものではなく,実際の訓練データである確率を \(D(x)\) によって近似することを目指す.

この組 \((G,D)\) に対して, \[ V(D,G):=\operatorname{E}[\log D(X)]+\operatorname{E}[\log(1-D(G(Z))] \] \[ X\sim P_{\text{data}},\quad Z\sim P_z \] を目的関数とし, \[ \min_{G\in\mathrm{Hom}_\mathrm{Mark}(\mathcal{Z}\times\mathcal{G}_g,\mathcal{X})}\max_{D\in\mathcal{L}(\mathcal{X};[0,1])}V(D,G) \tag{1}\] を解く,ミニマックスゲームを考える.1

1.3 理論

\(G\)\(D\) が表現するモデルが十分に大きいとき,すなわち \(\Theta_g,\Theta_d\) が十分に大きく,殆どノンパラメトリックモデルであるとみなせる場合には,学習基準 Equation 1 は真の生成分布 \(P_{\text{data}}\) に収束するアルゴリズムを与える.

このことを示すには,\(P_{\text{data}}\) が,Equation 1 の大域的最適解であることを示せば良い.

定義 (Jensen-Shannon divergence)

確率測度 \(P,Q\in\mathcal{P}(\mathcal{X})\) に対して,

  1. \[ \operatorname{KL}(P,Q):=\begin{cases} \int_\mathcal{X}\log\left(\frac{d P}{d Q}\right)\,dP&P\ll Q,\\ \infty&\mathrm{otherwise}. \end{cases} \]Kullback-Leibler 乖離度 という.
  2. \[ \operatorname{JS}(P,Q):=\operatorname{KL}\left(P,\frac{P+Q}{2}\right)+\operatorname{KL}\left(Q,\frac{P+Q}{2}\right) \]Jensen-Shannon 乖離度 という.

このとき,\(\sqrt{\operatorname{JS}}\) は,任意の \(\sigma\)-有限測度 \(\mu\in\mathcal{M}(\mathcal{X})\) に関して, \[ \mathcal{P}_\mu(\mathcal{X}):=\left\{P\in\mathcal{P}(\mathcal{X})\mid P\ll\mu\right\} \] 上に距離を定める.

KL 乖離度は \(P\ne Q\Rightarrow\operatorname{KL}(P,Q)>0\) を満たすが,対称性も三角不等式も満たさない.そもそも,\(\mathbb{R}_+\)-値とは限らず,\(\infty\) を取り得る.

JS 乖離度は, \[ P\ll\frac{P+Q}{2} \] であるから,\(\mathcal{P}(\mathcal{X})^2\) 上で常に \(\mathbb{R}_+\)-値であることに注意.

以降,\(\sqrt{\operatorname{JS}}\) が距離であることを示す.

  1. \(P=Q\) のとき \(\operatorname{JS}(P,Q)=0\) であり,\(P\ne Q\) のとき, \[ P\ne\frac{P+Q}{2} \] であるから,\(\operatorname{JS}(P,Q)>0\) である.
  2. 対称性も直ちに従う.
  3. あとは三角不等式を示せば良いが,任意の \(P,Q\in\mathcal{P}_\mu(\mathcal{X})\) に関して,密度を \[ p:=\frac{d P}{d \mu},\quad q:=\frac{d Q}{d \mu} \] と表すと, \[ \sqrt{\operatorname{JS}(P,Q)}=\left\|\sqrt{L(p,q)}\right\|_{L^2(\mu)} \] であることより,次の補題と \(\|-\|_{L^2(\mu)}\) の三角不等式より従う.

非負実数 \(p,q\in\mathbb{R}_+\) について, \[ L(p,q):=p\log\frac{2p}{p+q}+q\log\frac{2q}{p+q} \] で定まる関数 \(L:\mathbb{R}_+^2\to\mathbb{R}_+\) は,任意の \(r\in\mathbb{R}_+\) について, \[ \sqrt{L(p,q)}\le\sqrt{L(p,r)}+\sqrt{L(r,q)} \] を満たす.

右辺を \[ f(p,q,r):=\sqrt{L(p,r)}+\sqrt{L(r,q)} \] とおいて,\(r\) に関する偏導関数の符号変化を調べる. \[ \begin{align*} \frac{\partial f}{\partial r}&=\frac{1}{2\sqrt{L(p,r)}}\frac{\partial L}{\partial r}(p,r)+\frac{1}{2\sqrt{L(r,q)}}\frac{\partial L}{\partial r}(r,q)\\ &=\frac{\log\frac{2r}{p+r}}{2\sqrt{L(p,r)}}+\frac{\log\frac{2r}{r+q}}{2\sqrt{L(r,q)}}\\ &=\frac{1}{\sqrt{r}}\left(\frac{\log\frac{2}{x+1}}{2\sqrt{L(x,1)}}+\frac{\log\frac{2}{\beta x+1}}{2\sqrt{L(\beta x,1)}}\right). \end{align*} \tag{2}\] ここで,\(x:=\frac{p}{r},\beta x=\frac{q}{r}\) とおいた. \[ p<q\quad\Leftrightarrow\quad\beta>1 \] と仮定しても一般性は失われない.

そこで,\(x\in(-1,\infty)\setminus\{1\}\) の関数 \[ \begin{align*} g(x)&:=\frac{\log\frac{2}{x+1}}{\sqrt{L(x,1)}}\\ &=\frac{\log\frac{2}{x+1}}{\sqrt{x\log\frac{2x}{x+1}+\log\frac{2}{x+1}}} \end{align*} \] の性質を調べる.

実は \(g'>0\;\mathrm{on}\;\mathbb{R}_+\setminus\{1\}\) であり, \[ \lim_{x\to0+}g(x)=\sqrt{\log 2}>0 \] \[ \lim_{x\to\infty}g(x)=0 \] \[ \lim_{x\to1\mp}g(x)=\pm1 \] と併せると,\(g((0,1))\subset(0,1)\)\(g((1,\infty))\subset(-1,0)\) である.特に \(\lvert g\rvert<1\)

これより, Equation 2\(x=1,\beta\) と,その間で1回の計3回符号変化し,\(x\to\infty\) の極限では負である.

よって,\(f\)\(r\) の関数として,\(r=p\) で極小値,\(r\in(p,q)\) のどこかで極大値を取り,\(r=q\) で再び極小値を取る. \[ f(p,q,p)=f(p,q,q)=\sqrt{L(p,q)} \] であるから,結論を得る.

命題

\(P_0,P_1\in\mathcal{P}(\mathcal{X})\) を確率測度で,それぞれ密度 \(p_0,p_1\) を持つとする.\(X_0\sim P_0,X_1\sim P_1\) とする.このとき,

  1. 最大化問題 \[ L:=\sup_{D\in\mathcal{L}(\mathcal{X};[0,1])}\biggr(\operatorname{E}[\log D(X_0)]+\operatorname{E}[\log(1-D(X_1))]\biggl) \] はただ一つの解 \[ D^*(x)=\frac{p_0(x)}{p_0(x)+p_1(x)} \] を持つ.

  2. \(\operatorname{JS}(P_0,P_1)\)\(L\) と定数の差を除いて一致する.

  1. 目的関数は \[ \begin{align*} &\operatorname{E}[\log D(X_0)]+\operatorname{E}[\log(1-D(X_1))]\\ =&\int_\mathcal{X}\log D\cdot p_0\,d\mu+\int_\mathcal{X}\log(1-D)\cdot p_1\,d\mu\\ =&\int_\mathcal{X}\biggr(p_0\log D+p_1\log(1-D)\biggl)\,d\mu \end{align*} \] と変形できる.いま,任意の \(a,b\in(0,1]\) に関して, \[ f(t):=a\log t+b\log(1-t)\quad(t\in(0,1)) \]\(t=\frac{a}{a+b}\) 上で最大値を取る.\(a,b\) のどちらか一方のみが \(0\) である場合も含めてこの主張は成り立つ.よって, \[ D(x)=\frac{p_0(x)}{p_0(x)+p_1(x)} \] が目的関数を最大化することが判る.
  2. 1より,\(L\) の上限 \(\sup\) は達成されることがわかった: \[ \begin{align*} L&=\operatorname{E}[\log D^*(X_0)]+\operatorname{E}[\log(1-D^*(X_1))]\\ &=\int_\mathcal{X}\left(p_0\log\frac{p_0}{p_0+p_1}+p_1\log\frac{p_1}{p_0+p_1}\right)\,d\mu\\ &=\int_\mathcal{X}\biggr(p_0\log\frac{2p_0}{p_0+p_1}+p_1\log\frac{2p_1}{p_0+p_1}-p_0\log 2-p_1\log 2\biggl)\,d\mu\\ &=-2\log2+\operatorname{JS}(P_0,P_1). \end{align*} \]

これより,訓練基準 Equation 1 はただ一つの大域的な最適解を持ち,これは \(P_{\text{data}}=G_*P_z\) かつ \(D^*=\frac{1}{2}\) のときに最小値 \(-2\log2\) を取るということが判る.

1.4 アルゴリズムとその収束

\((G,D)\) を勾配降下法により同時に学習するには,

  1. 判別器 \(D\) の最大化ステップ
    1. ミニバッチ \(\{z^i\}_{i=1}^m\)\(\{x^i\}_{i=1}^m\) をそれぞれ \(P_z\)\(P_{\text{data}}\) からサンプリングする.
    2. 確率的勾配 \[ D_{\theta_d}\frac{1}{m}\sum_{i=1}^m\left(\log D(x^i)+\log(1-D(G(z^i)))\right) \] の増加方向にパラメータ \(\theta_d\) を更新する.
  2. 生成モデル \(G\) の最小化ステップ
    1. ミニパッチ \(\{z^i\}_{i=1}^m\)\(P_z\) からサンプリングする.
    2. 確率的勾配 \[ D_{\theta_g}\sum_{i=1}^m\log\biggr(1-D(G(z^i))\biggl) \] の減少方向にパラメータ \(\theta_g\) を更新する.

というアルゴリズムを実行すれば良い.(Goodfellow et al., 2014 p.) の数値実験ではモーメンタム法 (Rumelhart et al., 1987, p. 330) が用いられている.

このアルゴリズムは,次の3条件が成り立つならば,\(G_*P_z\)\(P_{\text{data}}\) に収束する:

  1. モデル \(G,D\) の表現力が十分大きい.
  2. 判別器 \(D\) の最大化ステップにおいて,必ず \(\max_{D\in\mathcal{L}(\mathcal{X};[0,1])}V(D,G)\) が達成される.
  3. 生成モデル \(G\) の最大化ステップにおいても,必ず \(V(D,G)\) が改善される.

実際は,\(G\) はパラメトリックモデル \(\{G_*P_z(\theta,-)\}_{\theta\in\Theta_g}\) であるから,その分の誤差は残ることになる.

また,\(D\) が最適化されていない状況で \(G\) が学習されすぎると,多くの \(z\in\mathcal{Z}\) の値を \(D\) が不得意な判別点 \(x\in\mathcal{X}\) に対応させすぎてしまうことがあり得る.

\(P_{\text{data}}\) が強い多峰性を持つ場合でも効率よく学習することができる.これは同じ確率分布からのサンプリング手法として,MCMC にはない美点になり得る (Goodfellow et al., 2014, p. 6)

1.5 補遺:Jensen-Shannon 乖離度のその他の性質

1.5.1 情報理論からの導入

乖離度としての Jensen-Shannon 乖離度は (Lin, 1991) で最初に導入されたようである.

が,その以前から, \[ \operatorname{JS}(P,Q)=2H\left(\frac{P+Q}{2}\right)-H(P)-H(Q) \] という関係を通じて,(Rao, 1982, p. 25) などは右辺を Jensen 差分 (difference) と呼んでいたようである.(Rao, 1987, p. 222) は,\(H\) が Shannon のエントロピーではなくとも,有用な性質を持つことを情報幾何学の立場から議論している.

1.5.2 JS 乖離度が定める距離

\[ \biggr(\operatorname{JS}(P,Q)\biggl)^\alpha \]\(\alpha=\frac{1}{2}\) において距離をなすことを示したが,実は一般の \(\alpha\in(0,1/2]\) に関して距離をなす (Osán et al., 2018)

1.5.3 変分問題としての特徴付け

任意の \(P,Q\in\mathcal{P}_\mu(\mathcal{X})\) について,

\[ \operatorname{JS}(P,Q)=\min_{R\in\mathcal{P}_\mu(\mathcal{X})}\left\{\operatorname{KL}(P,R)+\operatorname{KL}(Q,R)\right\} \]

1.5.4 有界な距離である

\(\operatorname{JS}:\mathcal{P}_\mu(\mathcal{X})^2\to\mathbb{R}_+\) は最大値 \(\sqrt{2\log 2}\) を持つ.

1.5.5 \(\chi^2\)-距離に漸近する (Endres and Schindelin, 2003, p. 1859)

1.5.6 \(f\)-乖離度の例である

\(f\)-乖離度の考え方は (Rényi, 1961, p. 561) で導入された.他,(Csiszár, 1963), (Morimoto, 1963), (Ali and Silvey, 1966) なども独立に導入している.

定義 (\(f\)-divergence)

\(P\ll Q\) とする.凸関数 \(f:\mathbb{R}_+\to\mathbb{R}\) に対して,

\[ D_f(P,Q):=\int_\mathcal{X}f\left(\frac{d P}{d Q}\right)\,dQ \]\(f\)-乖離度 という.

KL-乖離度は \[ f(x)=x\log x \] について,JS-乖離度は \[ f(x)=x\log\frac{2x}{x+1}+\log\frac{2}{x+1} \] についての \(f\)-乖離度である.

全変動ノルムも \[ f(x)=\lvert x-1\rvert \] に関する \(f\)-乖離度である.

さらには,\(\alpha\)-乖離度\(f\)-乖離度の例である.

2 GAN の改良

(Nowozin et al., 2016) による \(f\)-GAN,(Arjovsky et al., 2017) による Wasserstein GAN など,GAN の改良が続いている.

2.1 \(f\)-GAN

JS-乖離度に限らず一般の \(f\)-乖離度 Section 1.5.6 に関して,GAN が構成できる (Nowozin et al., 2016)

この一般化により,GAN の枠組みの本質は凸解析に基づくものであることが明らかになる.

2.2 GAN の学習の問題点

  • やはり多峰性に弱く,モードのうちいくつかが再現されないことがある (Mode collapse).
  • 収束判定が困難である.これは学習基準が最小化ではなく均衡点を求めることにあることにも起因する.
  • 勾配消失が起こる.

2.3 Wasserstein GAN

最後の勾配消失の問題は,JS-乖離度の性質にあるとして,これを Wasserstein 距離に取り替える形で提案されたのが Wasserstein GAN である (Arjovsky et al., 2017)

References

Ali, S. M., and Silvey, S. D. (1966). A general class of coefficients of divergence of one distribution from another. Journal of the Royal Statistical Society. Series B (Methodological), 28(1), 131–142.
Arjovsky, M., Chintala, S., and Bottou, L. (2017). Wasserstein generative adversarial networks. In Proceedings of the 34th international conference on machine learning,Vol. 70, pages 214–223.
Bengio, Y., Laufer, E., Alain, G., and Yosinski, J. (2014). Deep generative stochastic networks trainable by backprop. In Proceedings of the 31st international conference on machine learning,Vol. 32, pages 226–234.
Csiszár, I. (1963). Eine informationstheoretische ungleichung und ihre anwendung auf beweis der ergodizitaet von markoffschen ketten. Magyár Tudomá Akadémia Mahematikai Kutató Intézetének Köezleményei, 6, 85–108.
Endres, D. M., and Schindelin, J. E. (2003). A new metric for probability distributions. IEEE Transactions on Information Theory, 49(7), 1858–1860.
Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … Bengio, Y. (2014). Generative adversarial nets. In Advances in neural information processing systems,Vol. 27, pages 1–9.
Hinton, G. E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. R. (2012). Improving neural networks by preventing co-adaptation of feature detectors.
Krizhevsky, A., Sutskever, I., and Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks. In Advances in neural information processing systems,Vol. 25.
Lin, J. (1991). Divergence measures based on the shannon entropy. IEEE Transactions on Information Theory, 37(1), 145–151.
Morimoto, T. (1963). Markov processes and the \(H\)-theorem. Journal of the Physical Society of Japan, 18(3), 328–331.
Nielsen, F. (2021). On a variational definition for the jensen-shannon symmetrization of distances based on the information radius. Entropy, 23(4), 464.
Nowozin, S., Cseke, B., and Tomioka, R. (2016). F-GAN: Training generative neural samplers using variational divergence minimization. In Advances in neural information processing systems,Vol. 29.
Osán, T. M., Bussandri, D. G., and Lamberti, P. W. (2018). Monoparametric family of metrics derived from classical jensen-shannon divergence. Physica A: Statistical Mechanics and Its Applications, 495, 336–344.
Rao, C. R. (1982). Diversity and dissimilarity coefficients: A unified approach. Theoretical Population Biology, 21(1), 24–43.
Rao, C. R. (1987). Differential metrics in probability spaces. IMS Lecture Notes Monograph Series, 10, 217–240.
Rényi, A. (1961). On measures of entropy and information. In Proceedings of the fourth berkeley symposium on mathematical statistics and probability,Vol. 1, pages 547–561.
Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1987). Parallel distributed processing: Explorations in the microstructure of cognition: foundations. In D. E. Rumelhart and J. L. McClelland, editors, pages 318–362. MIT Press.
Salakhutdinov, R., and Hinton, G. (2009). Deep boltzmann machines. In Proceedings of the twelth international conference on artificial intelligence and statistics,Vol. 5, pages 448–455.

Footnotes

  1. この基準にしたがって学習すると,\(G\) が外れすぎている際,\(\log(1-D(G(z)))\) が殆ど \(0\) になり得る.そのような場合は,\(\log D(G(z))\) の最大化を代わりに考えることで,学習が進むことがある (Goodfellow et al., 2014, p. 3)↩︎