離散空間上のフローベース模型

位相構造を取り入れた次世代の構造生成へ

Deep
Sampling
Nature
Author

司馬博文

Published

8/09/2024

Modified

8/10/2024

概要
画像と動画に関してだけでなく,化学分子の構造生成の分野でも拡散模型が state of the art となっている.これは,連続空間上だけでなく,グラフなどの離散空間上でも拡散模型が拡張されたことが大きい.本稿では,離散データを連続潜在空間に埋め込むことなく,直接離散空間上に拡散模型をデザインする方法をまとめる.

1 離散雑音除去拡散模型 (D3PM) (Austin et al., 2021)

Minimal Implementation of a D3PM by Simo Ryu (Ryu, 2024) (Tap to image to visit his repository)

Minimal Implementation of a D3PM by Simo Ryu (Ryu, 2024) (Tap to image to visit his repository)

1.1 はじめに

離散データ上のフローベースのサンプリング法として,Argmax Flows と Multinomial Diffusion が (Hoogeboom et al., 2021) により提案された.

D3PM (Austin et al., 2021) はこの拡張として提案されたものである.

その結果,D3PM は BERT (Lewis et al., 2020) などのマスク付き言語モデルと等価になる.

1.2 ノイズ過程

1.2.1 設計意図

効率的な訓練のために,

  1. \(q(x_t|x_0)\) からシミュレーション可能
  2. \(q(x_{t-1}|x_t,x_0)\) が評価可能

であるとする.これにより, \[ L_{t-1}(x_0):=\int_\mathcal{X}\operatorname{KL}\biggr(q(x_{t-1}|x_t,x_0),p_\theta(x_{t-1}|x_t)\biggl)\,q(x_t|x_0)\,dx_t \]
の Monte Carlo 近似が可能になる.

\(p(x_T)=q(x_T|x_0)\) を一様分布など,簡単にシミュレーション可能な分布とする.

1.2.2 実装

\(x_0\in\mathcal{X}\) は,\([K]\)-値の離散ベクトル \(x_0^{(i)}\)\(D\) 個集まったものとする.ただし,\(x_0^{(i)}\) は one-hot encoding による横ベクトルとする.

すると,ある確率行列 \(Q_t\) に関して, \[ Q(-|x_{t-1})=x_{t-1}Q_t=\cdots=x_0Q_1\cdots Q_t \] と表せる.右辺の第 \(i\) 行は,次 \(k\in[K]\) の状態に至る確率を表す確率ベクトルとなっている.

するとこの逆は,ベイズの定理より \[ q(x_{t-1}|x_t,x_0)=\frac{q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0)}{q(x_t|x_0)} \] \[ Q(-|x_t,x_0)= \]

1.2.3\(Q\) の取り方

\[ Q_t:=(1-\beta_t)I_K+\frac{\beta_t}{K} \] と取った場合を一様核という.

または,\(Q_t\) として 脱落核 を取ることもできる.これは1つの点 \(m\in[K]\) を吸収点とする方法である: \[ (Q_t)_{ij}:=\begin{cases}1&i=j=m,\\ 1-\beta_t&i=j\in[K]\setminus\{m\}\\ \beta_t&\mathrm{otherwise} \end{cases} \]

1.3 除去過程

\(p_\theta(x_{t-1}|x_t)\) をモデリングするのではなく,\(\widetilde{p}_\theta(x_0|x_t)\) をモデリングし, \[ p_\theta(x_{t-1}|x_t)\,\propto\,\sum_{\widetilde{x}_0\in[K]}q(x_{t-1}|x_t,\widetilde{x}_0)\widetilde{p}_\theta(\widetilde{x}_0|x_t) \] は間接的にモデリングする.

これにより,ステップ数を小さく取った場合でも,\(k\) ステップをまとめて \(p_\theta(x_{t-k}|x_t)\) をいきなりサンプリングするということも十分に可能になるためである.

1.4 BERT (Devlin et al., 2019) との対応

\(Q_t\) として,一様核と脱落核を重ね合わせたとする.

すなわち,各トークンを各ステップで \(\alpha=10\%\) でマスクし,\(\beta=5\%\) で一様にリサンプリングし,これを元に戻す逆過程を学習する.

これは BERT (Devlin et al., 2019) と全く同じ目的関数を定める.

MaskGIT (Masked Generative Image Transformer) (Chang et al., 2022) も,画像をベクトル量子化した後に,全く同様の要領でマスク・リサンプリングをし,これを回復しようとする.これはトランスフォーマーなどの自己回帰的モデルを用いて逐次的に生成するより,サンプリングがはるかに速くなるという.

2 参考文献

(Ryu, 2024) に素晴らしい教育的リポジトリがある.D3PM の 425 行での PyTorch での実装を提供している.

(Campbell et al., 2024) は最新の論文の一つである.

References

Austin, J., Johnson, D. D., Ho, J., Tarlow, D., and Berg, R. van den. (2021). Structured denoising diffusion models in discrete state-spaces. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. S. Liang, and J. W. Vaughan, editors, Advances in neural information processing systems,Vol. 34, pages 17981–17993. Curran Associates, Inc.
Campbell, A., Yim, J., Barzilay, R., Rainforth, T., and Jaakkola, T. (2024). Generative flows on discrete state-spaces: Enabling multimodal flows with applications to protein co-design.
Chang, H., Zhang, H., Jiang, L., Liu, C., and Freeman, W. T. (2022). MaskGIT: Masked generative image transformer. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR), pages 11315–11325.
Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 conference of the north american chapter of the association for computational linguistics: Human language technologies,Vol. 1, pages 4171–4186.
Hoogeboom, E., Nielsen, D., Jaini, P., Forré, P., and Welling, M. (2021). Argmax flows and multinomial diffusion: Learning categorical distributions. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, editors, Advances in neural information processing systems.
Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., … Zettlemoyer, L. (2020). BART: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In D. Jurafsky, J. Chai, N. Schluter, and J. Tetreault, editors, Proceedings of the 58th annual meeting of the association for computational linguistics, pages 7871–7880. Online: Association for Computational Linguistics.
Ryu, S. (2024). Minimal implementation of a D3PM (structured denoising diffusion models in discrete state-spaces), in pytorch.