Zig-Zag サンプラーのサブサンプリングによるスケーラビリティ

大規模モデル・大規模データに対する MCMC を目指して

MCMC
Computation
Julia
Sampling
Author

司馬博文

Published

7/18/2024

Modified

8/29/2024

概要
Zig-Zag サンプラーは,その非対称なダイナミクスにより,収束が速くなることが期待されている MCMC 手法である.それだけでなく,対数尤度の勾配に対する不偏推定量をサブサンプリングにより構成することで,ベイズ推論においてサンプルサイズに依らない一定のコストで効率的な事後分布からのサンプリングが可能である.

MCMC の計算複雑性のボトルネックは,尤度の評価にある.各ステップで全てのデータを用いて尤度を計算する必要がある点が,MCMC を深層学習などの大規模データの設定への応用を難しくしている (Murphy, 2023, p. 647)

サブサンプリングが可能であることと,複数の効率的なサブサンプリング法の提案により,Zig-Zag 過程は次世代のサンプラーとして圧倒的なスケーラビリティ(Super Efficient Bayesian Inference (Bierkens et al., 2019))を示すのではないかと期待されている.1

1 対数尤度の勾配を不変推定する

\(p(x)\) を事前分布,\(p(y|x)\) を観測のモデル(または尤度)とし,データ \(y_1,\cdots,y_n\) は互いに独立であるとする.

このとき,事後分布 \(\pi(x):=p(x|y)\) と Hamiltonian \(U\) は次のように表せる: \[ \pi(x)\,\propto\,\left(\prod_{k=1}^n p(y_k|x)\right)p(x) \] \[\begin{align*} U(x)&=-\sum_{k=1}^n\log p(y_k|x)-\log p(x)\\ &=\frac{1}{n}\sum_{k=1}^n\biggr(-n\log p(y_k|x)-\log p(x)\biggl)=:\frac{1}{n}\sum_{k=1}^nU^k(x). \end{align*}\]

このとき,\(U\) の導関数 \(\partial_i U(x)\) は,独立な観測 \(y_1,\cdots,y_n\) について項別微分をして平均をとったものに等しい: \[ \partial_iU(x)=\frac{1}{n}\sum_{k=1}^nE^k_i(x), \tag{1}\] \[ E^k_i(x):=\partial_iU^k(x)=\frac{\partial }{\partial x_i}\biggr(-n\log p(y_k|x)-\log p(x)\biggl). \]

よって,精度は劣るかもしれないが,一様に選んだ \(K\sim\mathrm{U}([n])\) から定まる \(E^K_i\) の値は \(\partial_i U(x)\) の不偏推定量となっている.この発想により,ZZ-SS という新たなアルゴリズムを構成できる.

2 サブサンプリングを取り入れた Zig-Zag サンプラー

この各 \(E^K_i\) が定める強度関数 \[ m^K_i(t):=\biggr(\theta E^K_i(x+\theta t)\biggl)_+=\biggr(\theta\partial_iU^K(x+\theta t)\biggl)_+ \] を用いた Zig-Zag サンプラーを (Bierkens et al., 2019) では ZZ-SS (Zig-Zag with Sub-Sampling) と呼んでいる.

\[ \max_{k\in[n]}m^k_i\le M_i \] を満たす連続関数 \(M_i\) を用いて次のようにシミュレーションすることができる:

(Bierkens et al., 2019, p. 1303 アルゴリズム3)
  1. 代理強度関数 \(M_1,\cdots,M_d\) を持つ互いに独立な \(\mathbb{R}_+\) 上の非一様 Poisson 点過程の到着時刻 \(T_1,\cdots,T_d\) をシミュレーションする.
  2. 最初に到着した座標番号 \(j:=\operatorname*{argmin}_{i\in[d]}T_i\) について,確率 \[ \frac{m^K_j(T_j)}{M_j(T_j)},\qquad K\sim\mathrm{U}([n]), \] で時刻 \(T_j\) に速度成分 \(\theta_j\) の符号を反転させる.
  3. 1に \(t=T_j\) として戻って,繰り返す.

ZZ-SS によってシミュレートされる過程は,レート関数 \[ \lambda_i(x,\theta)=\operatorname{E}[(\theta E^K_i(x))_+]=\frac{1}{n}\sum_{k=1}^n(\theta E^k_i(x))_+ \] を持った Zig-Zag 過程に等しい

これは,元々のレート関数に対して, \[ \gamma_i(x,\theta):=\frac{1}{n}\sum_{k=1}^n(\theta_iE^k_i(x))_+-\left(\theta_i\frac{1}{n}\sum_{k=1}^nE^k_i(x)\right)_+ \] という項を加えて得る Zig-Zag サンプラーともみなすことができる.非負性は関数 \((x)_+:=x\lor0\) の凸性から従う.最後に \(\gamma_i(x,\theta)=\gamma_i(x,F_i(\theta))\) を確認すれば良い.

これは \[\begin{align*} &\qquad\frac{1}{n}\sum_{k=1}^n\biggr(\theta_iE_i^k(x)\biggl)_+-\frac{1}{n}\sum_{k=1}^n\biggr(-\theta_iE_i^k(x)\biggl)_+\\ &=\frac{1}{n}\sum_{k=1}^n\left((\theta_iE_i^k(x))_+-(-\theta_iE_i^k(x))_+\right)=\frac{1}{n}\sum_{k=1}^n\theta_iE_i^k(x) \end{align*}\] であることから従う.

こうして,サブサンプリングの実行による精度の劣化が,(Andrieu and Livingstone, 2021) の枠組みで捉えられる,ということでもある(レート関数が増加したので,スイッチングイベントが増え,diffusive な動きが増加する).

例えば \(p(y_k|x)\) が Cauchy 密度であるなど \(\partial_iU\) が有界であるとき,\(M_i:=\max_{x\in\mathbb{R}^d}\partial_iU(x)\) などと選ぶことができる.\(M_i\) をより \(\partial_iU\) に近く選ぶほど剪定の効率は上がるが,\(M_i\) を複雑にしすぎると今度は \(M_i\) を強度とする Poisson 点過程のシミュレーションが困難になる.

そのため,ZZ-SS では代理レート関数 \(M_i\) は大きく取る必要があり,尤度関数の評価の回数が増える.そのため,アルゴリズムの計算複雑性は上がっていることに注意 (Bierkens et al., 2019, p. 1302 第4節)

3 制御変数による分散低減

\(\partial_iU(x)\) が Lipschitz 連続であるとき,\(E_i^k\) をある参照点 \(\partial_iU(x_*)\) とそこからの乖離と取ることで \(n\to\infty\) の極限で分散が抑えられる.

こうすることで,\(M_i\) を1次関数としたまま,より小さく \(E_i^k\) にフィットするように取ることができる.

命題

任意の \(i\in[d]\) について,ある \(C_i>0\) が存在して, \[ \lvert\partial_iU(x)-\partial_iU(y)\rvert\le C_i\lvert x-y\rvert,\quad(x,y)\in\mathbb{R}^{2d}, \] が成り立つとする.このとき, \[ M_i(t):=a_i+b_it \] \[ a_i:=(\theta_i\partial_iU(x_*))_++C_i\|x-x_*\|_p,\quad b_i:=C_id^{1/p} \] と定めれば, \[ m_i^k\le M_i \] が成り立つ.ただし, \[ \partial_iU(x)=\frac{1}{n}\sum_{k=1}^nE^k_i(x),\tag{1} \] \[ E^k_i(x):=\partial_iU(x_*)+\partial_iU^k(x)-\partial_iU^k(x_*), \] \[ m^k_i(t):=\biggr(\theta E_i^k(x+\theta t)\biggl)_+, \] とした.

この仮定は例えば \(\partial_iU\) が有界な導関数を持つならば成り立つ.\(p(y_k|x)\) が Gauss 密度であるやさらに裾が重いときは成り立つ.

次のようにして参照点 \(x_*\) を選ぶ事前処理を行うことで,データのサイズに依存しない計算複雑性で事後分布からの正確なサンプリングが可能になる.

preprocessing for ZZ-CV
  1. \(x_*:=\operatorname*{argmin}_{x\in\mathbb{R}^d}U(x)\) を探索する.
  2. \(\partial_iU(x_*),\partial_iU^k(x_*)\) を計算する.

この2つはいずれも \(O(n)\) の複雑性で実行できる.

4 ZZ-CV のスケーリング

このとき,\(x_*\) を定める事前処理が,\(\widehat{x}\) を最尤推定量として, \[ \|x_*-\widehat{x}\|_p=O(n^{-1/2})\quad(n\to\infty) \] 程度の正確性があれば,事後分布の最尤推定量周りの漸近展開 (Johnson, 1970) を通じて, \[ \|x-x_*\|_p=O_p(n^{-1/2})\quad(n\to\infty) \] \[ \partial_iU(x_*)=O_p(n^{1/2})\quad(n\to\infty) \] が成り立つ.

事後分布に対する Zig-Zag 過程は,\(\sqrt{n}\) だけ時間を加速したものが \(\mathrm{N}_d(0,i(x_0))\) を標的にする Zig-Zag 過程に収束するから,\(O(n^{-1/2})\) のタイムステップで区切ってサンプルとすることができる.

しかし \[ \max_{k\in[n]}\biggr(\theta_i\partial_iU^k(x+\theta t)\biggl)_+\le M_i \] を満たす \(M_i\)\(O(n^{\alpha})\;(\alpha\ge1/2)\) のスケールで増大していく.

各スイッチングイベントにおいて,全データにアクセスする \(O(n)\) の計算複雑性が必要であるから,総じて \(O(n^{\alpha+1/2})\) の計算複雑性となる.

ZZ-CV が平衡に至っている場合は \(x\) はほとんど \(x_*\) に集積するため, \[ \lvert E^k_i(x)\rvert=\biggl|\partial_iU(x_*)+\partial_iU^k(x)-\partial_iU^k(x_*)\biggr|=O(n^{1/2}) \] が成り立つ.よってこれを抑える \(M_i\)\(O(n^{1/2})\) で済み,必要以上に大きい代理レート関数を用意して剪定する必要がない.

全データにアクセスする \(O(n)\) のステップもないために,事前処理 3 と十分平衡に至っているとみなせるまでの burn-in を除いて,\(O(1)\) の計算複雑性でサンプリングが可能である.このことを (Bierkens et al., 2019) は super-efficiency と呼ぶ.

他に,事後分布の集中領域でうまくスイッチング回数が抑えられる \(\lambda_i\) が構成できたならば,低い計算複雑性を達成できるだろう.

ZZ-CV では,これに事後分布の Gauss 近似を用いたことになる.

また,\(U\) の2階微分が有界でない場合,この枠組みが使えない.実際,(Bierkens et al., 2019, pp. 1315 第6.5節) ではこの場合での数値実験の結果が示されており,事後分布が集積しないために super-efficiency は得られていない.

参照点 \(x_*\) を複数取る拡張なども (Bierkens et al., 2019, p. 第7節) で考えられている.

5 数値実験:MSE の比較

ある Gauss 分布に従うデータを生成する: \[ Y^j\overset{\text{iid}}{\sim}\mathrm{N}(x_0,\sigma^2),\qquad j\in[n], \] 分散 \(\sigma^2\) は既知として,位置母数 \(x\in\mathbb{R}\) を推定する問題を考える.

事前分布を \(\mathrm{N}(0,\rho^2)\) とすると,定数の違いを除いて \[\begin{align*} U(x)&=\frac{x^2}{2\rho^2}+\frac{1}{2\sigma^2}\sum_{j=1}^n(x-y^j)^2\\ &=\frac{1}{n}\sum_{j=1}^n\left(\frac{x^2}{2\rho^2}+\frac{n}{2\sigma^2}(x-y^j)^2\right)=:\frac{1}{n}\sum_{j=1}^nU^j(x) \end{align*}\] であるから, \[\begin{align*} U'(x)&=\frac{x}{\rho^2}+\frac{1}{\sigma^2}\sum_{j=1}^n(x-y^j)\\ &=\frac{x}{\rho^2}+\frac{n}{\sigma^2}(x-\overline{y}), \end{align*}\] \[ U''(x)=\frac{1}{\rho^2}+\frac{n}{\sigma^2}. \]

従って,Zig-Zag 過程のイベントの強度関数は \[\begin{align*} m(t)&=\biggr(\theta U'(x+\theta t)\biggl)_+\\ &=\left(\frac{\theta(x+\theta t)}{\rho^2}+\frac{\theta}{\sigma^2}\sum_{j=1}^n(x+\theta t-y^j)\right)_+\\ &=\left(\frac{\theta x}{\rho^2}+\frac{\theta}{\sigma^2}\sum_{j=1}^n(x-y^j)+t\left(\frac{1}{\rho^2}+\frac{n}{\sigma^2}\right)\right)_+ \end{align*}\] と表せ,これは1次関数 \((a+bt)_+\) の形であるから直接のシミュレーションが可能である.4

サブサンプリングなしの Zig-Zag 過程のシミュレーションをする関数 ZZ() を定義
using ZigZagBoomerang
using Distributions
using Random

λ(∇U, x, θ, F::ZigZag1d) = pos(θ*∇U(x)) # rate function on E
λ_bar(τ, a, b) = pos(a + b*τ)  # affine proxy

"""
`x`: current location, `θ`: current velocity, `t`: current time,
"""
function move_forward(τ, t, x, θ, ::ZigZag1d)
    τ + t, x + θ*τ , θ
end

"""
    `∇U`: gradient of the negative log-density
    `(x,θ)`: initial state
    `T`: Time Horizon    
    `a+bt`: computational bound for intensity m(t)

    `num`: ポアソン時刻に到着した回数
    `acc`: 受容回数.`acc/num` は acceptance rate
"""
function ZZ(∇U, x::Float64, θ::Float64, T::Float64, y, Flow::ZigZagBoomerang.ContinuousDynamics; rng=Random.GLOBAL_RNG, ab=ab_ZZ)
    t = zero(T)
    Ξ = [(t, x, θ)]
    num = acc = 0
    epoch_list = [num]
    a, b = ab(x, θ, Flow)
    t′ =  t + poisson_time(a, b, rand())  # イベントは a,b が定める affine proxy に従って生成する

    while t < T
        τ = t′ - t
        t, x, θ = move_forward(τ, t, x, θ, Flow)
        l, lb = λ(∇U, x, θ, Flow), λ_bar(τ, a, b)  # λ が真のレート, λ_bar が affine proxy
        num += 1
        if rand()*lb < l
            acc += 1
            if l > lb + 0.01
                println(l-lb)
            end
            θ = -θ
            push!(Ξ, (t, x, θ))
            push!(epoch_list, num)
        end
        a, b = ab(x, θ, Flow)
        t′ = t + poisson_time(a, b, rand())
    end

    return Ξ, epoch_list, acc/num
end
今回の設定に応じたレート関数 (a+bt)+ を用意
pos(x) = max(zero(x), x)  # positive part
a(x, θ, ρ, σ, y) = θ * x / ρ^2 +/σ^2) * sum(x .- y)
b(x, θ, ρ, σ, y) = ρ^(-2) + length(y)/σ^2

ρ, σ, x0, θ0 = 1.0, 1.0, 1.0, 1.0
n1, n2 = 100, 10^4
TrueDistribution = Normal(x0, σ)
y1 = rand(TrueDistribution, n1)
y2 = rand(TrueDistribution, n2)

# computational bounds for intensity m(t)
ab_ZZ_n1(x, θ, ::ZigZag1d) = (a(x, θ, ρ, σ, y1), b(x, θ, ρ, σ, y1))
ab_ZZ_n2(x, θ, ::ZigZag1d) = (a(x, θ, ρ, σ, y2), b(x, θ, ρ, σ, y2))

∇U1(x) = x/ρ^2 + (length(y1)/σ^2) * (x - mean(y1)) 
∇U2(x) = x/ρ^2 + (length(y2)/σ^2) * (x - mean(y2)) 

# T = 2500.0
# trace_ZZ1, epochs_ZZ1, acc_ZZ1 = ZZ(∇U1, x0, θ0, T, ZigZag1d(); ab=ab_ZZ_n1)
# trace_ZZ2, num_ZZ2, acc_ZZ2 = ZZ(∇U2, x0, θ0, T, ZigZag1d(); ab=ab_ZZ_n2)
# dt = 0.01
# traj_ZZ1 = discretize(trace_ZZ1, ZigZag1d(), dt)
# traj_ZZ2 = discretize(trace_ZZ2, ZigZag1d(), dt)
N 回 ZZ() を実行して,その事後平均の MSE を計算する関数 experiment() を定義
function SquaredError(sample::Vector{Float64}, y)
    True_Posterior_Mean = sum(y) / (length(y) + 1)
    return (mean(sample) - True_Posterior_Mean)^2
end

"""
    epoch_list: 注目するエポック数のリスト
    N: 実験回数
"""
function experiment(epoch_list, T, dt, N, ∇U, x0, θ0, y, Sampler; ab=ab_ZZ_n1)
    SE_sum = zero(epoch_list)
    acc_list = []
    for _ in 1:N
        trace_ZZ1, epochs_ZZ1, acc_ZZ1 = Sampler(∇U, x0, θ0, T, y, ZigZag1d(); ab=ab)
        push!(acc_list, acc_ZZ1)
        traj_ZZ1 = discretize(trace_ZZ1, ZigZag1d(), dt)
        SE_list = []
        for T in epoch_list
            epoch = findfirst(x -> x > T, epochs_ZZ1) - 1
            t = findfirst(x -> x > trace_ZZ1[epoch][1], traj_ZZ1.t) - 1
            SE = SquaredError(traj_ZZ1.x[1:t], y)
            push!(SE_list, SE)
        end
        SE_sum += SE_list
    end
    return SE_sum ./ N, mean(acc_list)
end
実験の実行
using Plots

T = 3000.0
epoch_list = [10.0, 100.0, 1000.0, 10000.0]
dt = 0.01
N = 11

MSE_ZZ1, acc = experiment(epoch_list, T, dt, N, ∇U1, x0, θ0, y1, ZZ; ab=ab_ZZ_n1)
p = plot(#epoch_list, MSE_ZZ1,
    xscale=:log10,
    yscale=:log10,
    xlabel="epochs",
    ylabel="MSE"
    # ,background_color = "#F0F1EB"
    )
scatter!(p, epoch_list, MSE_ZZ1,
    marker=:circle,
    markersize=5,
    markeralpha=0.6,
    color="#78C2AD",
    label=nothing
    )

using GLM, DataFrames
df = DataFrame(X = log10.(epoch_list), Y = log10.(MSE_ZZ1))
model = lm(@formula(Y ~ X), df)
X_pred = range(minimum(df.X), maximum(df.X), length=100)
Y_pred = predict(model, DataFrame(X = X_pred))
plot!(p, 10 .^ X_pred, 10 .^ Y_pred,
    line=:solid,
    linewidth=2,
    color="#78C2AD",
    label="ZZ"
    )

# display(p)

println("Average acceptance rate: $acc")
Average acceptance rate: 1.0

より,たしかに剪定なしの正確なシミュレーションができている.

一方で, \[ U^j(x):=\frac{x^2}{2\rho^2}+\frac{n}{2\sigma^2}(x-y^j)^2, \] \[ \lambda^j(x,\theta):=\biggr(\theta(U^j)'(x)\biggl)_+ \] としてサブサンプリングを取り入れることを考えるが,これを同じ \((a+bt)_+\) ではバウンド出来ない:

ZZ-SS (ZigZag with Subsampling) の定義
λj(j,x,θ,y) = pos* (x/ρ^2 + length(y)/σ^2 * (x - y[j])))

function ZZ_SS(∇U, x::Float64, θ::Float64, T::Float64, y, Flow::ZigZagBoomerang.ContinuousDynamics; rng=Random.GLOBAL_RNG, ab=ab_ZZ)
    t = zero(T)
    Ξ = [(t, x, θ)]
    num = acc = 0
    epoch_list = [num]
    a, b = ab(x, θ, Flow)
    t′ =  t + poisson_time(a, b, rand())  # イベントは a,b が定める affine proxy に従って生成する

    while t < T
        τ = t′ - t
        t, x, θ = move_forward(τ, t, x, θ, Flow)
        j = rand(1:length(y))
        l, lb = λj(j, x, θ, y), λ_bar(τ, a, b)  # λ が真のレート, λ_bar が affine proxy
        num += 1
        if rand()*lb < l
            if l > lb + 0.01
                # println(l-lb)
                acc += 1  #  overflow を数えるように変更済み!注意!
            end
            θ = -θ
            push!(Ξ, (t, x, θ))
            push!(epoch_list, num)
        end
        a, b = ab(x, θ, Flow)
        t′ = t + poisson_time(a, b, rand())
    end

    return Ξ, epoch_list, acc/num
end
実験の実行
using LaTeXStrings

MSE_ZZ_SS, acc = experiment(epoch_list, T, dt, N, ∇U1, x0, θ0, y1, ZZ_SS; ab=ab_ZZ_n1)
println(L"上界 $(a+bt)_+$ を超えてしまう平均的割合: ", "$acc")
上界 $(a+bt)_+$ を超えてしまう平均的割合: 0.49500662479838825

しかし,ZZ-CV アルゴリズムではこのようなことは起こらない.実際,次の等式が成り立つ: \[ U'(x_*)+(U^j)'(x)-(U^j)'(x_*)=U'(x),\qquad x,x_*\in\mathbb{R}. \]

このモデルにおける MAP 推定量は \[ \widehat{x}:=\frac{\overline{y}}{1+\frac{\sigma^2}{n\rho^2}} \] である.

ZZ-CV (ZigZag with Control Variates) の定義
x_star = mean(y1) / (1 + σ^2/(length(y1) * ρ^2))

C(ρ, σ, y) = ρ^(-2) + length(y)/σ^2
a(x, θ, ρ, σ, y) = pos(θ*∇U1(x_star)) + C(ρ, σ, y) * abs(x - x_star)
b(x, θ, ρ, σ, y) = C(ρ, σ, y)

# New Computational Bounds for ZZ-CV
ab_ZZ_CV(x, θ, ::ZigZag1d) = (a(x, θ, ρ, σ, y1), b(x, θ, ρ, σ, y1))

function ZZ_CV(∇U, x::Float64, θ::Float64, T::Float64, y, Flow::ZigZagBoomerang.ContinuousDynamics; rng=Random.GLOBAL_RNG, ab=ab_ZZ_CV)
    t = zero(T)
    Ξ = [(t, x, θ)]
    num = acc = 0
    epoch_list = [num]
    a, b = ab(x, θ, Flow)
    t′ =  t + poisson_time(a, b, rand())  # イベントは a,b が定める affine proxy に従って生成する

    while t < T
        τ = t′ - t
        t, x, θ = move_forward(τ, t, x, θ, Flow)
        # j = rand(1:length(y))  # 今回はたまたま要らない
        l, lb =λ(∇U, x, θ, Flow), λ_bar(τ, a, b)  # λ が真のレート, λ_bar が affine proxy
        num += 1
        if rand()*lb < l
            acc += 1
            if l > lb + 0.01
                println(l-lb)
            end
            θ = -θ
            push!(Ξ, (t, x, θ))
            push!(epoch_list, num)
        end
        a, b = ab(x, θ, Flow)
        t′ = t + poisson_time(a, b, rand())
    end

    return Ξ, epoch_list, acc/num
end
実験の実行
MSE_ZZ_CV, acc = experiment(epoch_list, T, dt, N, ∇U1, x0, θ0, y1, ZZ_CV; ab=ab_ZZ_CV)

q = scatter(p, epoch_list, MSE_ZZ_CV,
    marker=:circle,
    markersize=5,
    markeralpha=0.6,
    color="#E95420",
    label=nothing
    )

df = DataFrame(X = log10.(epoch_list), Y = log10.(MSE_ZZ_CV))
model = lm(@formula(Y ~ X), df)
X_pred = range(minimum(df.X), maximum(df.X), length=100)
Y_pred = predict(model, DataFrame(X = X_pred))
plot!(q, 10 .^ X_pred, 10 .^ Y_pred,
    line=:dash,
    linewidth=2,
    color="#E95420",
    label="ZZ-CV (without amendment)"
    )

display(q)

一見すると ZZ-CV が負けているように見える.しかし,点線で描いているのは,横軸が epoch であることを正しく考慮していない間違ったプロットであるためである.

(Bierkens et al., 2019, p. 1310) において epoch とは,計算量の1単位分としており,ZZ における1回の到着時刻のシミュレーションは,ZZ-CV の \(n\) 回分に当たる.これを考慮に入れてプロットし直すと次の通りになる:

実験の実行
T_SuperEfficient = 300000.0
epoch_list_SuperEfficient = [1000.0, 10000.0, 100000.0, 1000000.0]

@time MSE_ZZ_CV, acc = experiment(epoch_list_SuperEfficient, T_SuperEfficient, dt, N, ∇U1, x0, θ0, y1, ZZ_CV; ab=ab_ZZ_CV)

scatter!(p, epoch_list, MSE_ZZ_CV,
    marker=:circle,
    markersize=5,
    markeralpha=0.6,
    color="#E95420",
    label=nothing
    )

df = DataFrame(X = log10.(epoch_list), Y = log10.(MSE_ZZ_CV))
model = lm(@formula(Y ~ X), df)
X_pred = range(minimum(df.X), maximum(df.X), length=100)
Y_pred = predict(model, DataFrame(X = X_pred))
plot!(p, 10 .^ X_pred, 10 .^ Y_pred,
    line=:solid,
    linewidth=2,
    color="#E95420",
    label="ZZ-CV"
    )

display(p)
 48.490704 seconds (2.48 G allocations: 46.290 GiB, 7.66% gc time)

これは換言すれば横軸が「ズルい」ということでもあるが,同時に \(n\to\infty\) の極限では,圧倒的に ZZ-CV が効率的になるということでもある.5

6 MALA との比較

MALA のセットアップ
using AdvancedHMC, AdvancedMH, ForwardDiff
using LogDensityProblems
using LogDensityProblemsAD
using StructArrays
using LinearAlgebra

struct LogTargetDensity
    y::Vector{Float64}
end

function U(i, x, y)
    x[1] * x[1] / (2 * ρ * ρ) + length(y) * (x[1] - y[i]) * (x[1] - y[i]) / (2 * σ * σ)  # 自動微分のために x は長さ1のベクトルと扱う必要がある
end

function U(x, y)
    vec = [U(i, x, y) for i in 1:length(y)]
    return mean(vec)
end

LogDensityProblems.logdensity(p::LogTargetDensity, x) = U(x, p.y)
LogDensityProblems.dimension(p::LogTargetDensity) = 1
LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()

model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity(y1))

# σ² = 0.1 # ほぼ横ばい
# σ² = 0.5 # 1回小さいエポック10で効率勝った+全く横ばいになった
σ² = 0.2  # すごく良い感じ
spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
実験の実行
function experiment_MALA(epoch_list, N, y)
    SE_sum = zero(epoch_list)
    for _ in 1:N
        chain = sample(model_with_ad, spl, Int64(epoch_list[end]); initial_params=[x0], chain_type=StructArray, param_names=["x"], stats=true)
        traj_MALA = Vector{Float64}(chain.x)
        SE_list = []
        for T in epoch_list
            SE = SquaredError(traj_MALA[1:T], y)
            push!(SE_list, SE)
        end
        SE_sum += SE_list
    end
    return SE_sum ./ N
end

MSE_MALA = experiment_MALA(Vector{Int64}(epoch_list), N, y1)
結果のプロット
scatter!(p, epoch_list, MSE_MALA,
    marker=:circle,
    markersize=5,
    markeralpha=0.6,
    color="blue",
    label=nothing
    )

df = DataFrame(X = log10.(epoch_list), Y = log10.(MSE_MALA))
model = lm(@formula(Y ~ X), df)
X_pred = range(minimum(df.X), maximum(df.X), length=100)
Y_pred = predict(model, DataFrame(X = X_pred))
plot!(p, 10 .^ X_pred, 10 .^ Y_pred,
    line=:solid,
    linewidth=2,
    color="blue",
    label="MALA"
    )

display(p)

7 非一様な部分サンプリング

当然,必ずしも一様な分解 \[ U(x)=\frac{1}{n}\sum_{j=1}^nU^j(x) \] に基づいた一様なサブサンプリング \(K\sim\mathrm{U}([n])\) を行う必要はない.

剪定の手続きを棄却法とみると,重点サンプリングのアイデアを導入することで制御変数に依らない分散低減が狙える (Sen et al., 2020 importance subsampling strategy)

特に,比例的高次元極限や,不均衡データに対するロジスティック回帰では,事後分布が十分な集中性を持たないために制御変数の方法 3 が十分な効率改善を示さないが,この重点サブサンプリングによれば効率の改善が見込める.

詳しくは,次稿参照:

Article Image

大規模な不均衡データに対するロジスティック回帰(後編)

離散時間 MCMC から連続時間 MCMC へ

References

Andrieu, C., and Livingstone, S. (2021). Peskun–Tierney ordering for Markovian Monte Carlo: Beyond the reversible scenario. The Annals of Statistics, 49(4), 1958–1981.
Bierkens, J., Fearnhead, P., and Roberts, G. (2019). The Zig-Zag Process and Super-Efficient Sampling for Bayesian Analysis of Big Data. The Annals of Statistics, 47(3), 1288–1320.
Fearnhead, P., Bierkens, J., Pollock, M., and Roberts, G. O. (2018). Piecewise deterministic markov processes for continuous-time monte carlo. Statistical Science, 33(3), 386–412.
Johnson, R. A. (1970). Asymptotic expansions associated with posterior distributions. The Annals of Mathematical Statistics, 41(3), 851–864.
Murphy, K. P. (2023). Probabilistic machine learning: Advanced topics. MIT Press.
Sen, D., Sachs, M., Lu, J., and Dunson, D. B. (2020). Efficient posterior sampling for high-dimensional imbalanced logistic regression. Biometrika, 107(4), 1005–1012.
Vasdekis, G. (2021). On zig-zag extensions and related ergodicity properties (PhD thesis). University of Warwick. Retrieved from http://webcat.warwick.ac.uk/record=b3714913

Footnotes

  1. この2点が両方肝心である.効率的なサブサンプリング推定量の開発が (Fearnhead et al., 2018) 以来議論の焦点になっている.↩︎

  2. (Vasdekis, 2021, p. 25)(Bierkens et al., 2019, pp. 1302 定理4.1) も参照.↩︎

  3. (Bierkens et al., 2019, pp. 1306–) 第5.1節参照.↩︎

  4. 実装は ZigZagBoomerang パッケージの zigzagboom1d.jl を参考にした.↩︎

  5. ただし,例えば今回も計算時間で言えば長くなっていることに注意.↩︎