拡散モデルによる事後分布サンプリング

Langevin 拡散の時間反転を用いたシミュレーションベースのサンプリング法

Process
Sampling
P(X)
Author

司馬博文

Published

8/03/2024

Modified

10/12/2024

概要

拡散モデルから始まるフロー学習手法は,画像と動画に関して 2024 年時点で最良の性能を誇る. これは統計的に言えば事後分布からの近似的サンプリングを実行していることに相当する. 近似的ではなく,正確に2つの分布を補間するような拡散過程を推定するためには Schrödinger 橋がある. Schrödinger 橋については 次稿 に譲るとし,本稿ではサンプラーとしての拡散モデルを復習する.

関連記事一覧

1 雑音除去拡散 (DD: Denoising Diffusion)

1.1 はじめに

潜在空間 \(\mathcal{X}\) 上の事前分布 \(p(x)dx\) と,尤度が確率核 \(P:\mathcal{X}\to\mathcal{Y}\) \[ P(x,dy)=p(y|x)\,dy \] の形で与えられているとする.

この際の事後分布 \[ p(x|y)=\frac{g(y|x)\mu(x)}{\int_\mathcal{X}g(y|x)\mu(x)\,dx} \] を償却的に学習しサンプリングするためには,\(p(x|y)\) から開始する Langevin 拡散の時間反転をシミュレーションする 拡散模型 (DDPM / SGM) の方法が使える.

Langevin 拡散の時間反転は 雑音除去拡散 (DD: Denoising Diffusion) とも呼ばれ,拡散模型は正しい DD を学習すること・高速な DD のシミュレーション法を開発することが主眼である.

償却的に学習するため,一度モデルが訓練されれば,種々の \(y\in\mathcal{Y}\) に対応できる.例えば \(x\) が画像,\(y\) がクラスラベルなどの場合では,\(p(x|y)\) からのサンプリングは 条件付き生成 とも呼ばれる.これについては次稿参照:

だが \(p(x|y)\) を一般の事後分布と見れば,この DD を用いた拡散模型的アプローチは条件付き生成のタスクに限らず,逆問題や一般のベイズ推論にも応用可能だろう.

さらにスコアマッチングを通じた学習法では,事前分布 \(p(x)\) と尤度 \(p(y|x)\,dy\) の両方からサンプリングができれば実行可能であるため,シミュレーションベースの推論 (SBI: Simulation-based Inference) への応用も期待できる.

そこで,この事後分布 \(p(x|y)\) からの シミュレーション・ベースのサンプラーとしての拡散模型 を,DDPM をもじって 雑音除去拡散サンプラー (DDPS: Denoising Diffusion Posterior Sampler) と呼ぼう.

(Benton et al., 2024) では DDPS を,離散空間や Riemann 多様体を含む一般の状態空間 \(\mathcal{X}\) 上に一般化した DMM (Denoising Markov Model) を提案している.

本稿では生成モデルとしてはお馴染みの拡散模型をサンプラー DDPS として捉えなおし,近似ベイズ推論や逆問題への応用をみる.

しかしこの方法は正確な方法ではない上に,正確性を記すために時間 \(T>0\) を大きく取る必要があり,サンプリングに大変な時間がかかる.

次稿では DD の代わりに Schrödinger 橋 (SB: Schrödinger Bridge) を用いて,より正確で高速なサンプリングの達成を目指す:

1.2 雑音除去拡散 (DD) の定義

ひとまず,観測 \(y\in\mathcal{Y}\) を固定して議論する.

拡散模型 では,次で定まる OU 過程によってデータ分布を \(\mathrm{N}_d(0,I_d)\) にまで破壊することを考える (Y. Song et al., 2021)\[ dX_t=-\frac{1}{2}X_t\,dt+dB_t,\qquad X_0\sim p(x|y). \]

OU 過程は指数エルゴード性を持ち,十分な時間が経てば \(\mathrm{N}_d(0,I_d)\) に収束することは確約されている.とはいっても,完全に \(\mathrm{N}_d(0,I_d)\) に従うようになるのは \(t\to\infty\) の極限においてである.

この \((X_t)\) の有限時区間 \([0,T]\) における時間反転は,\((X_t)\) の密度を \(p_t(x_t|y)\) で表すと, \[ dZ_t=\frac{1}{2}Z_t\,dt+\nabla_z\log p_{T-t}(Z_t|y)\,dt+dW_t,\qquad Z_0\sim p_T(x_T|y), \tag{1}\] の弱解になる (Anderson, 1982), (Haussmann and Pardoux, 1986)

この時間反転過程 \((Z_t)_{t\in[0,T]}\)雑音除去拡散 (Denoising Diffusion) という.

数学的な詳細は次の記事参照.

1.3 スコアマッチングによる DD からのサンプリング

すると残りの問題は,DD \((Z_t)_{t\in[0,T]}\) からのサンプリングになる.

そのためDD からのサンプリングのためにはドリフト項 \(\log p_{T-t}(Z_t|y)\) の評価と初期分布 \(p_T(x_T|y)\) からのサンプリングの2点が必要である.

\((Z_t)\) を推定されたスコア \(s^\theta\) を用いた Gauss 過程から開始する拡散 \[ dZ_t=\frac{1}{2}Z_t\,dt+s_{T-t}^\theta(Z_t,y)\,dt+dW_t,\qquad Z_0\sim\mathrm{N}_d(0,I_d), \] で近似することが SGM (Score-based Generative Model) (Y. Song et al., 2021) の方法である.

思い切って \(\mathrm{N}_d(0,I_d)\approx p_T(x_T|y)\) としてしまい,\(s_t^\theta(x_t,y)\) のモデリングに集中するのである.

具体的には,この過程 \((Z_t)\) が定める測度を \(\mathbb{Q}_y^\theta\in\mathcal{P}(C([0,T];\mathcal{X}))\) と表すと,訓練目標を KL 乖離度の期待値 \[\begin{align*} \mathcal{L}(\theta)&:=2\operatorname{E}\biggl[\operatorname{KL}\biggr(\mathbb{P}_Y,\mathbb{Q}_Y^\theta\biggl)\biggr]\\ &=\int^T_0\operatorname{E}\biggl[\left\|s^\theta_t(X_t,Y)-\nabla_x\log p_{t|0}(X_t|X_0)\right\|^2\biggr]\,dt+\mathrm{const.} \end{align*}\] とする (Th’m 1 Y. Song et al., 2021).ただし,\(\mathbb{P}_Y\)\((X_t)\) の分布,\(p_{t|0}\)\((X_t)\) の遷移密度を表す.

この損失は DSM (Denoising Score Matching) (Vincent, 2011) で与えられたものに等しく, \[ (X_0,Y)\sim p(x,y)=g(y|x)\mu(x) \] からのシミュレーションが可能であるならば,この目的関数は確率的最適化アルゴリズムによって最適化できる.

こうして 雑音除去拡散サンプラー (DDPS: Denoising Diffusion Posterior Sampler) を得る.

1.4 近似ベイズ計算への応用

DDPS を学習するためには,事前分布と尤度 \(g(y|x)\) からのサンプリングが可能であればよく,関数として評価することは一度もなかった.

従って DDPS は生成モデリングの他に Simulation-based Inference などの近似推論に用いることができる (Shi et al., 2022)

実際,この DDPS は従来の ABC (Approximate Bayesian Computation) 法の代替になり得る.(Benton et al., 2024), (Sharrock et al., 2024) が良いレビューである.

さらに 拡散模型の加速法 (Progressive Distillation (Salimans and Ho, 2022) など)が DDPS にも応用可能である.

1.5 逆問題への応用

サンプルが画像だとしても,画像修復 (inpainting) や高解像度化 (super-resolution) などの逆問題応用が豊富に存在する.

このような条件付き生成のタスクに対しては 誘導 (guidance) の技術が考えられている.

誘導の方法では,精度を向上させるために \[ \nabla_x\log p_t(x_t|y)=\nabla_x\log p_t(x_t)+\nabla_x\log p_t(y|x_t). \] の分解に注目する.

Classifier-Free Diffusion Guidance (Ho and Salimans, 2021) では,\(\nabla_x\log p_t(x_t|y)\) の形のスコアをモデリングする中で,\(y=\emptyset\) のものを2割ほど作り,これを \(\nabla_x\log p_t(x_t)\) の推定値としてサンプリングに用いる.

残りはベイズの公式を通じて \(\nabla_x\log p_t(y|x_t)\) の推定値とする.詳しくは次稿参照:

1.6 非償却的な推論

一方で単一の \(Y=y\) を想定した状況ではさらに精度を上げる方法が考えられる.

\(\log p_t(x_t|y)\) を一緒くたに \(s^\theta_{t}(x_t,y)\) に取り替えてしまうのではなく,まず第一項 \(\nabla_x\log p_t(x_t|y)\)\(s_t^\theta(x_t)\) により統一的にモデリングする.

そして \(\nabla_x\log p_t(y|x_t)\) の項は Tweedie の推定量 \[ \widehat{x}_0:=\operatorname{E}[x_0|x_t]=\frac{1}{\sqrt{\overline{\alpha}_t}}\biggr(x_t+(1-\overline{\alpha}_t)\nabla_{x_t}\log p_t(x_t)\biggl) \tag{2}\] を通じて \[ p(y|x_t)\approx p(y|\widehat{x}_0) \] によって近似する.式 (2) の \(\nabla_{x_t}\log p_t(x_t)\) に事前に訓練したスコアネットワーク \(s_t^\theta(x_t)\) を用いる.

(Chung et al., 2023) はこの方法を Computer Vision における非線型逆問題に適用している.

(J. Song et al., 2023) では Monte Carlo 法が用いられている.

2 雑音除去 Markov モデル

DDPS の議論は \(\mathcal{X}=\mathbb{R}^d\) 上の OU 過程の場合に限っていた.

全く同様の構成は,\(\mathcal{X}\) が離散空間や Riemann 多様体とし,その上の一般の Markov 過程に対しても実行可能である.

この事実を,一般化されたスコアマッチングの方法とともに提案したのが (Benton et al., 2024) である.

References

Anderson, B. D. O. (1982). Reverse-time diffusion equation models. Stochastic Processes and Their Applications, 12(3), 313–326.
Benton, J., Shi, Y., De Bortoli, V., Deligiannidis, G., and Doucet, A. (2024). From denoising diffusions to denoising Markov models. Journal of the Royal Statistical Society Series B: Statistical Methodology, 86(2), 286–301.
Chung, H., Kim, J., Mccann, M. T., Klasky, M. L., and Ye, J. C. (2023). Diffusion posterior sampling for general noisy inverse problems. In The eleventh international conference on learning representations.
Haussmann, U. G., and Pardoux, E. (1986). Time Reversal of Diffusions. The Annals of Probability, 14(4), 1188–1205.
Ho, J., and Salimans, T. (2021). Classifier-free diffusion guidance. In NeurIPS 2021 workshop on deep generative models and downstream applications.
Salimans, T., and Ho, J. (2022). Progressive distillation for fast sampling of diffusion models. In International conference on learning representations.
Sharrock, L., Simons, J., Liu, S., and Beaumont, M. (2024). Sequential neural score estimation: Likelihood-free inference with conditional score based diffusion models. In R. Salakhutdinov, Z. Kolter, K. Heller, A. Weller, N. Oliver, J. Scarlett, and F. Berkenkamp, editors, Proceedings of the 41st international conference on machine learning,Vol. 235, pages 44565–44602. PMLR.
Shi, Y., De Bortoli, V., Deligiannidis, G., and Doucet, A. (2022). Conditional simulation using diffusion Schrödinger bridges. In J. Cussens and K. Zhang, editors, Proceedings of the thirty-eighth conference on uncertainty in artificial intelligence,Vol. 180, pages 1792–1802. PMLR.
Song, J., Zhang, Q., Yin, H., Mardani, M., Liu, M.-Y., Kautz, J., … Vahdat, A. (2023). Loss-guided diffusion models for plug-and-play controllable generation. In A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett, editors, Proceedings of the 40th international conference on machine learning,Vol. 202, pages 32483–32498. PMLR.
Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., and Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. In International conference on learning representations.
Vincent, P. (2011). A connection between score matching and denoising autoencoders. Neural Computation, 23(7), 1661–1674.