A Blog Entry on Bayesian Computation by an Applied Mathematician
$$
$$
1 はじめに
1.1 輸送問題としての生成モデリング
生成モデリングは畢竟,2つの分布 \(P_0,P_1\in\mathcal{P}(\mathbb{R}^d)\) を結ぶ写像 \[ \phi_*P_0=P_1 \] を,フロー \((\phi_t),\phi_1=\phi\) によって学習することに帰着する.
1.2 ベクトル場を学習する方法
このようなフロー \((\phi_t)\) を定めるベクトル場 \(F_t\) \[ \frac{\partial \phi_t(x)}{\partial t}=F_t(\phi_t(x)) \] は,\(P_0,P_1\) を補間する確率密度 \[ P_t:=(\phi_t)_*P_0 \] に対して,連続方程式 というPDE \[ \frac{\partial p_t}{\partial t}+\operatorname{div}(F_tp_t)=0. \] を満たす必要がある.
この PDE から適切なベクトル場 \(F_t\) を特定し,これを目的にベクトル場 \(F_\theta(x_t,t)\) を学習してサンプリングに活かすことが,FM (Flow Matching) により可能になる.
1.3 歴史:拡散モデル
決定論的なフロー \((\phi_t)\) による \(P_0\) の \(P_1\) への変換は,元々は 拡散モデル によって考えられていた.
この方法では \(P_0\) に収束するエルゴード性を持つ OU 過程を用いて \(P_1\) の \(P_0\) への変換を学習し,この時間反転を学習することでノイズ分布 \(P_0\) からのサンプルから,データ分布 \(P_1\) からのサンプルを生成をする.
DDPM (Ho et al., 2020) は確率モデルとしてこの枠組みを定式化し最尤推定を目指した.SGM (Y. Song et al., 2019) は直接データ分布のスコア \(\nabla_x\log p_t(x)\) をスコアマッチングにより学習することを提案した.
このモデルにはノイズスケジュールなどの不要なパラメータや調節可能なハイパーパラメータが多く,等価な分布変換を定める ODE が存在する (Y. Song et al., 2021) ことが自覚されると,ODE とベクトル場による方法が志向された.これで Flow Matching (Lipman et al., 2023) に至る.
1.4 Schrödinger 橋
分布 \(P_1\) から開始する OU 過程は,エルゴード性をもてど,有限時間内で \(P_0\) との誤差が消えるわけではない.実際,拡散モデルの時間極限 \(T>0\) はなるべく大きく取ることが推奨されている (Y. Song and Ermon, 2020).
一方で,正確に2つの分布を繋ぐダイナミクスの1つを,Schrödinger 橋が与え,これをスコアマッチングによって学習することが (Heng et al., 2022) によって考えられた.
1.5 サンプリングへの応用
本稿では,最初に提案された拡散模型である DDPM から始まり,生成モデリングと密度推定に使える手法を考える.
一方で 次稿 では,正規化定数が不明な分布 \[ \pi(x)=\frac{\gamma(x)}{Z},\qquad Z:=\int_\mathcal{X}\gamma(x)\,dx \] に対しても使える汎用サンプラー DDS (Denoising Diffusion Sampler) (Vargas et al., 2023) とその Schrödinger 橋による改良を扱う:
2 雑音除去拡散 (DDPM)
潜在空間 \(\mathcal{X}\) 上の事前分布 \(\mu\) と,尤度が確率核 \(\mathcal{X}\to\mathcal{Y}\) \[ x\mapsto g(y|x)\,dy \] の形で与えられているとする.
2.1 雑音除去拡散 (DD)
拡散模型 は,次で定まる OU 過程によってデータ分布を \(\mathrm{N}_d(0,I_d)\) にまで破壊しているとみなせる (Y. Song et al., 2021): \[ dX_t=-\frac{1}{2}X_t\,dt+dB_t,\qquad X_0\sim p(x|y). \]
ただし,この過程は指数エルゴード性を持つと言っても,完全に \(\mathrm{N}_d(0,I_d)\) に従うようになるのは \(t\to\infty\) の極限においてである.この極限においては,\(p(x|y)\) はもやは \(y\) に依らなくなる.
この \((X_t)\) の有限時区間 \([0,T]\) における時間反転は,\((X_t)\) の密度を \(p_t(x_t|y)\) で表すと, \[ dZ_t=\frac{1}{2}Z_t\,dt+\nabla_z\log p_{T-t}(Z_t|y)\,dt+dW_t,\qquad Z_0\sim p_T(x_T|y), \tag{1}\] の弱解になる (Anderson, 1982), (Haussmann and Pardoux, 1986).この \((Z_t)_{t\in[0,T]}\) を 雑音除去拡散 (Denoising Diffusion) という.
2.2 \((Z_t)\) からのサンプリング
すると残りの問題は,拡散過程 \((Z_t)_{t\in[0,T]}\) からのサンプリングになるが,これは \(\log p_{T-t}(Z_t|y)\) という項の評価と \(p_T(x_T|y)\) からのサンプリングが必要である.
\((Z_t)\) を \[ dZ_t=\frac{1}{2}Z_t\,dt+s_{T-t}^\theta(Z_t,y)\,dt+dW_t,\qquad Z_0\sim\mathrm{N}_d(0,I_d), \] で近似することが (Y. Song et al., 2021) の方法である.思い切って \(\mathrm{N}_d(0,I_d)\approx p_T(x_T|y)\) としてしまい,\(s_t^\theta(x_t,y)\) のモデリングに特化するのである.
この過程 \((Z_t)\) が定める測度を \(\mathbb{Q}_y^\theta\in\mathcal{P}(C([0,T];\mathcal{X}))\) と表すと,訓練目標は KL 乖離度の期待値 \[\begin{align*} \mathcal{L}(\theta)&:=2\operatorname{E}\biggl[\operatorname{KL}\biggr(\mathbb{P}_Y,\mathbb{Q}_Y^\theta\biggl)\biggr]\\ &=\int^T_0\operatorname{E}\biggl[\left\|s^\theta_t(X_t,Y)-\nabla_x\log p_{t|0}(X_t|X_0)\right\|^2\biggr]\,dt+\mathrm{const.} \end{align*}\] が考えられる.ただし,\(\mathbb{P}_Y\) は \((X_t)\) の分布,\(p_{t|0}\) は \((X_t)\) の遷移密度を表す.この損失は DSM (Vincent, 2011) で与えられたものに等しい.
\[ (X_0,Y)\sim p(x,y)=g(y|x)\mu(x) \] からのシミュレーションが可能であるならば,この目的関数は確率的最適化アルゴリズムによって最適化できる.
こうして,雑音除去拡散サンプラー (DDPS: Denoising Diffusion Posterior Sampler) を得る.
2.3 近似ベイズ計算への応用
事前分布と尤度 \(g(y|x)\) からのサンプリングが可能な状況は,生成モデリングの他に Simulation-based Inference などの近似推論でもあり得る.
実際,この DDPS は従来の ABC (Approximate Bayesian Computation) 法の代替になり得る.
さらに,拡散模型の加速法 (Progressive Distillation (Salimans and Ho, 2022) など)が DDPS にも応用可能である.
2.4 逆問題への応用
サンプルが画像だとしても,画像修復 (inpainting) や高解像度化 (super-resolution) などの逆問題応用が豊富に存在する.
このような,単一の \(Y=y\) を固定した状況で潜在変数 \(X_T\) からサンプリングをしたい場合では,\(\log p_t(x_t|y)\) を一緒くたに \(s^\theta_{t}(x_t,y)\) に取り替えてしまうのではなく,次の事前分布と尤度への分解に基づいて扱うこともできる:
\[ \nabla_x\log p_t(x_t|y)=\nabla_x\log\mu_t(x_t)+\nabla_x\log g_t(y|x_t), \] \[ \mu_t(x_t):=\int_\mathcal{X}\mu(x_0)p_{t|0}(x_t|x_0)\,dx_0,\qquad g_t(y|x_t):=\int_\mathcal{X}g(y|x_0)p_{0|t}(x_0,x_t)\,dx_0. \]
この第一項は \(s_t^\theta(x_t)\) により統一的にモデリングでき,同様に \(X_0\sim\mu(x)\) から始まる雑音化過程 \((X_t)\) の分布を \(\mathbb{P}\) として \(\operatorname{KL}(\mathbb{P},\mathbb{Q}^\theta)\) 最小化問題として処理できる.
\(g_t(y|x_t)\) の項も近似可能である.(Chung et al., 2023) では条件付き誘導が,(J. Song et al., 2023) では Monte Carlo 法が用いられている.
3 Schrödinger 橋による事後分布サンプリング (DSB-PS)
3.1 はじめに
\(p_T(x_T|y)\approx\mathrm{N}_d(0,I_d)\) の近似を成り立たせるために \(T\) を十分大きく取る必要がある問題は,OU 過程の代わりに Schrödinger 橋を用いることで解決できることが (Shi et al., 2022) で提案された.
Schrödinger 橋自体は,(De Bortoli et al., 2021) などから拡散模型への応用は議論されていた.
3.2 定義
Schrödinger 橋 (SB) とは, \[ \Pi^*:=\operatorname*{argmin}_{\Pi\in\mathcal{P}_0}\operatorname{KL}(\Pi,\mathbb{P}), \] \[ \mathcal{P}_0:=\biggl\{\Pi\in\mathcal{P}(C([0,T];\mathcal{X}\times\mathcal{Y}))\,\bigg|\,\Pi_0(x_0,y_0)=p(x_0,y_0),\Pi_T(x_T,y_T)=\mathrm{N}_d(0,I_d)p(y_T)\biggr\}, \] によって定まる確率分布に従う確率過程をいう.ただし,\(\mathbb{P}:=\mathbb{P}_{y_0}\otimes\delta_{p(y)}\) とした.\(\delta_{p(y)}\) は次で定まる確率分布である: \[ dY_t=0,\qquad Y_0\sim p(y). \]
これは表示 \[ \Pi^*=\mathbb{P}^*_{y_0}\otimes\delta_{p(y)} \] を持つから,\(Z_0\sim\mathrm{N}_d(0,I_d)\) に従う過程 \((Z_t)\) をシミュレーションすることで, \[ Z_T\sim\Pi^*_0(x|y)=p(x_0|y)\qquad p(y)\text{-a.s.} \] が成り立つ.
3.3 SB のシミュレーション
SB 問題の解 \(\Pi\) は 逐次的比例フィッティング (IPF: Iterative Proportional Fitting) により得られる.
3.3.1 IPF とは
IPF アルゴリズムは離散的な形で (Deming and Stephan, 1940) が分割表データ解析の研究で提案している.その手続きを (Ireland and Kullback, 1968) が距離の最小化として特徴付け,(Kullback, 1968) が確率密度に対しても一般化した.ただし,この確率密度に対するアルゴリズムは (Fortet, 1940) が Schrödinger 方程式の研究ですでに提案しているものである.
IPF は元々,指定した2つの確率ベクトル \(r\in(0,\infty)^{d_r},c\in(0,\infty)^{d_c}\) を周辺分布に持つ結合分布(カップリング)のうち,指定の行列 \(W\in M_{d_rd_c}(\mathbb{R}_+)\) に最も近い KL 乖離度を持つカップリングを見つけるための逐次アルゴリズムである (Kurras, 2015).
種々の分野で再発見され,複数の名前を持っているようである.例:Sheleikhovskii 法,Kruithof アルゴリズム,Furness 法,Sinkhorn-Knopp アルゴリズム,RAS 法など (Kurras, 2015).1
\(W\) の成分が正である場合は,(Sinkhorn, 1967) がアルゴリズムの収束と解の一意性を示している.2
しかし,\(W\) の成分が零を含む場合,零成分の位置に依存してアルゴリズムは収束しないことがあり得ることを,(Sinkhorn and Knopp, 1967) が \(d_r=d_c=1\) の場合について示している.
3.3.2 アルゴリズム
IPF アルゴリズムは,観念的には,2つの周辺分布のうち片方を制約に課しながら,KL 距離を最小にする射影を返していく:
\[ \Pi^{2n+1}:=\operatorname*{argmin}_{\Pi\in\mathcal{P}(C([0,T];\mathcal{X}\times\mathcal{Y}))}\biggl\{\operatorname{KL}(\Pi,\Pi^{2n})\,\bigg|\,\Pi_T=\mathrm{N}_d(0,I_d)\otimes p(y_T)dy_T\biggr\}, \] \[ \Pi^{2n+2}:=\operatorname*{argmin}_{\Pi\in\mathcal{P}(C([0,T];\mathcal{X}\times\mathcal{Y}))}\biggl\{\operatorname{KL}(\Pi,\Pi^{2n+1})\,\bigg|\,\Pi_0(x_0,y_0)=p(x_0,y_0)\biggr\}. \]
今回の場合, \[ \Pi^{2n+1}=\mathbb{P}_{y_T}^{2n+1}\otimes\delta_{p(y)},\qquad\Pi^{2n+2}=\mathbb{P}^{2n+2}_{y_0}\otimes\delta_{p(y)}, \] と分解される.ただし,\(\mathbb{P}_{y_T}^{2n+1}\) は次で定まる \((Z_t)\) の時間反転 \[ dZ_t=f_{T-t}^{2n+1}(Z_t,y_T)\,dt+dW_t,\qquad Z_0\sim\mathrm{N}_d(0,I_d),f_t^{2n+1}(x_t,y):=-f_t^{2n}(x_t,y)+\nabla_{x}\log\Pi^{2n}_t(x_t|y), \] \(\mathbb{P}_{y_0}^{2n+2}\) は次で定まる \((X_t)\) の経路測度となる: \[ dX_t=f^{2n+2}_t(X_t,y_0)\,dt+dB_t,\qquad X_0\sim p(x|y_0)\,dx,f_t^{2n+2}(x_t,y):=-f_t^{2n+1}(x_t,y)+\nabla_x\log\Pi_t^{2n+1}(x_t|y). \] ただし,\(f^0_t(x_t)=-x_t/2\).
3.3.3 DDPS との関係
最初のイテレーション \(n=0\) における \(\mathbb{P}^1_y\) が雑音除去拡散 (1) に対応する.
しかし,IPF アルゴリズムのイテレーションを繰り返していくごとに,\(T>0\) が十分に大きくない場合でも正確に \(\mathrm{N}_d(0,I_d)\) にデータ分布を還元する SB が得られるようになっていく.
3.3.4 DSB-PS
この際,スコア \(\nabla_z\log p_{T-t}(Z_t|y)\) から始まり,\(f^{n}_t\;(n\ge2)\) の推定も逐次的に行なわなければならない点については,mean-matching (De Bortoli et al., 2021), (Shi et al., 2022) という方法が考えられている.
この方法を用いて,IPF アルゴリズムが収束するまで実行して最終的に得るサンプラーを Schrödinger 橋サンプラー (DSB-PS: Diffusion Schrödinger Bridge Posterior Sampling) という.
4 文献紹介
I2SB (Liu et al., 2023) のプロジェクトページは こちら.Text-to-Speech にも応用されている (Chen et al., 2023).
SB はエントロピー正則化最適輸送の解になるが,このエントロピー/ラグランジアンの言葉で帰納バイアスを入れることができる (Koshizuka and Sato, 2023), (Isobe et al., 2024).
鈴木大慈氏のスライド では変分法的な見方が徹底されている.
- Langevin 動力学は,平衡分布との KL 乖離度を最小化する Wasserstein 勾配流になっている.
- その 時間反転過程 は OU 過程からの KL 乖離度とエネルギーの和を,\(\overline{Y}_T\sim p_0\) の境界条件の下最小化する過程になっている.
- さらに境界条件を加えたものが SB である.
References
Footnotes
また,行列スケーリングを通じた最小情報コピュラとの関連を (Bedford et al., 2016), (清智也, 2021) が指摘している.↩︎
ただし,(Deming and Stephan, 1940) にも (Fortet, 1940) にも言及しておらず,Markov 連鎖の遷移確率の推定という文脈で研究している.↩︎