1 速習 PDMP
1.1 はじめに
PDMP (Piecewise Deterministic Markov Process) または 連続時間 MCMC とは,名前の通り MCMC のサンプリングを連続時間で行うアルゴリズムである:
詳しくは次の記事も参照:
A Blog Entry on Bayesian Computation by an Applied Mathematician
$$
$$
1.2 使い方
using PDMPFlux
2 pdmp_jax
パッケージ
本節では PDMP のシミュレーションに自動微分を応用した (Andral and Kamatani, 2024) のアルゴリズムを紹介する.
2.1 デモ
(Neal, 2003) が slice sampling のデモ用に定義した 漏斗分布 を考える: \[ p(y,x)=\phi(y;0,3)\prod_{i=1}^9\phi(x_i;0,e^{y/2}),\qquad y\in\mathbb{R},x\in\mathbb{R}^9. \]
import jax
from jax.scipy.stats import multivariate_normal
from jax.scipy.stats import norm
import jax.numpy as jnp
def funnel(d=10, sig=3, clip_y=11):
"""Funnel distribution for testing. Returns energy and sample functions."""
def unbatched(x):
= x[0]
y = - y**2 / 6
log_density_y
= jnp.exp(y/2)
variance_other
= - jnp.sum(x[1:]**2) / (2 * variance_other)
log_density_other
return - log_density_y - log_density_other
def sample_data(n_samples):
# sample from Nd funnel distribution
= (sig * jnp.array(np.random.randn(n_samples, 1))).clip(-clip_y, clip_y)
y = jnp.array(np.random.randn(n_samples, d - 1)) * jnp.exp(-y / 2)
x return jnp.concatenate((y, x), axis=1)
return unbatched, sample_data
import pdmp_jax as pdmp
= 10
dim = funnel(d=dim)
U, _ = jax.grad(U)
grad_U = 8
seed = jnp.ones((dim,)) # initial position
xinit = jnp.ones((dim,)) # initial velocity
vinit = 0
grid_size = 100000 # number of skeleton points
N_sk = 100000 # number of samples
N = pdmp.ZigZag(dim, grad_U, grid_size)
sampler # sample the skeleton of the process
= sampler.sample_skeleton(N_sk, xinit, vinit, seed, verbose = True) # takes only 3 seconds on my M1 Mac
out # sample from the skeleton
= sampler.sample_from_skeleton(N,out) sample
import seaborn as sns
= sample[:,0],y = sample[:,1])
sns.jointplot(x plt.show()
number of error bound : 46817
2.2 サンプリングループの構造
ar=lambda_t/state.lambda_bar
によって Poisson thinning を行う.ただし,
lambda_bar
とは近似的な上界であり,「最も近い直前の grid 上の点での値」でしかない.当然lambda_t
を超過し得る.そのような場合にerror_acceptance()
に入る.error_acceptance()
に入った場合,horizon
を縮めてより慎重に同じ区間を Poisson thinning しなおす.adaptive=true
の場合はこのタイミングでhorizon
を恒久的に縮める.最後
if_reject()
に入った場合,horizon
に到達したらone_step_while()
まで戻るが,そうでない場合はinner_while()
まで戻る実装がなされている.
2.3 適応的なステップサイズ
3 PDMPFlux.jl
パッケージ
3.1 デモ
using PDMPFlux
using Random, Distributions, Plots, LaTeXStrings, Zygote, LinearAlgebra
"""
Funnel distribution for testing. Returns energy and sample functions.
For reference, see Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705–767.
"""
function funnel(d::Int=10, σ::Float64=3.0, clip_y::Int=11)
function neg_energy(x::Vector{Float64})
= x[1]
v = logpdf(Normal(0.0, 3.0), v)
log_density_v = exp(v)
variance_other = d - 1
other_dim = I * variance_other
cov_other = zeros(other_dim)
mean_other = logpdf(MvNormal(mean_other, cov_other), x[2:end])
log_density_other return - log_density_v - log_density_other
end
function sample_data(n_samples::Int)
# sample from Nd funnel distribution
= clamp.(σ * randn(n_samples, 1), -clip_y, clip_y)
y = randn(n_samples, d - 1) .* exp.(-y / 2)
x return hcat(y, x)
end
return neg_energy, sample_data
end
function plot_funnel(d::Int=10, n_samples::Int=10000)
= funnel(d)
_, sample_data = sample_data(n_samples)
data
# 最初の2次元を抽出(yとx1)
= data[:, 1]
y = data[:, 2]
x1
# 散布図をプロット
scatter(y, x1, alpha=0.5, markersize=1, xlabel=L"y", ylabel=L"x_1",
="Funnel Distribution (First Two Dimensions' Ground Truth)", grid=true, legend=false, color="#78C2AD")
title
# xlim と ylim を追加
xlims!(-8, 8) # x軸の範囲を -8 から 8 に設定
ylims!(-7, 7) # y軸の範囲を -7 から 7 に設定
end
plot_funnel()
function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000, d::Int=10, verbose::Bool=false)
= funnel(d)
U, _ grad_U(x::Vector{Float64}) = gradient(U, x)[1]
= ones(d)
xinit = ones(d)
vinit = 2024
seed = 0 # constant bounds
grid_size = ZigZag(d, grad_U, grid_size=grid_size)
sampler = sample_skeleton(sampler, N_sk, xinit, vinit, seed=seed, verbose = verbose)
out = sample_from_skeleton(sampler, N, out)
samples return out, samples
end
= run_ZigZag_on_funnel() # 4分かかる
output, samples
jointplot(samples)
このデモコードは Zygote.jl
による自動微分を用いると 5:29 かかっていたところが,ForwardDiff.jl
による自動微分を用いると 0:21 に短縮された.
3.2 Zygote.jl
と ForwardDiff.jl
による自動微分
Zygote.jl
は FluxML が開発する Julia の自動微分パッケージである.
using Zygote
@time Zygote.gradient(x -> 3x^2 + 2x + 1, 5)
0.830714 seconds (3.45 M allocations: 168.328 MiB, 3.95% gc time, 99.98% compilation time)
(32.0,)
f(x::Vector{Float64}) = 3x[1]^2 + 2x[2] + 1
g(x) = Zygote.gradient(f,x)
g([1.0,2.0])
([6.0, 2.0],)
大変柔軟な実装を持っており,広い Julia 関数を微分できる.
ForwardDiff.jl
(Revels et al., 2016) は Zygote.jl
よりも高速な自動微分を特徴としている.
using ForwardDiff
@time ForwardDiff.derivative(x -> 3x^2 + 2x + 1, 5)
0.068804 seconds (260.59 k allocations: 12.725 MiB, 16.89% gc time, 99.92% compilation time)
32
3.3 Brent の最適化
Optim.jl
は Julia の最適化パッケージであり,デフォルトで Brent の最適化アルゴリズムを提供する.
using Optim
f(x) = (x-1)^2
= optimize(f, 0.0, 1.0)
result result.minimizer
0.999999984947842
3.4 StatsPlots.jl
による可視化
StatsPlots
は現在 Plots.jl
に統合されている.
また PDMPFlux.jl
は marginalhist
を wrap した jointplot()
関数を提供する.
3.5 ProgressBars.jl
による進捗表示
ProgressBars.jl
は tqdm の Julia wrapper を提供する.PDMPFlux.jl
ではこちらを採用して,サンプリングの実行進捗を表示する.
なお ProgressMeter.jl
も同様の機能を提供しており,有名な別の PDMP パッケージである ZigZagBoomerang.jl
ではこちらを採用している.
4 終わりに
今後の確率的プログラミングの1つの焦点は自動微分かもしれない.
今回のパッケージ開発で,少なくとも v0.2.0
の時点では,プログラムに与える U_grad
は多くの場合(10 次元の多変量 Gauss,50 次元の Banana など) Zygote.jl
が少し速い(Funnel 分布では ForwardDiff.jl
が速い).
しかし上界を構成する際の func
の微分は ForwardDiff.jl
の方が圧倒的に速い.大変に不可思議である.
だから現在の実装は Zygote.jl
と ForwardDiff.jl
の両方を用いている.
5 ToDo
Zig-Zag 以外のサンプラーの実装
ZigZag(dim) は自動で知ってほしい
Try clause 内の else を用いているので Julia 1.8 以上が必要.
MCMCChains のような plot.jl を完成させる.
PDMPFlux.jl
のドキュメントを整備する.ZigZagBoomerang.jl
を見習って統合したり API をつけたり?Turing エコシステムと統合できたりしないか?
Rng
を指定できるようにする?pdmp-jax
では 37 秒前後かかる Banana density の例が,PDMPFlux.jl
では 2 分前後かかる.- しかし,Julia の方が数値誤差が少ないのか,banana potential の対称性がうまく結果に出る.尾が消えたりしない.
- →
ForwardDiff.jl
を採用したところ,02:05 から 10 分以上に変化した.ReverseDiff.jl
を採用したところ 4:44 になった.50 次元というのが微妙なところなのかもしれない.
Funnel 分布で試したところ,
PDMPFlux.jl
の棄却率が極めて高い.