HMC vs Pólya-Gamma Sampler

Julia によるベイズ・ロジスティック回帰での計算効率比較

Computation
MCMC
Author

司馬 博文

Published

12/03/2025

ベイズ・ロジスティック回帰において,MCMC により事後分布サンプリングを行う方法はいくつかある:

本稿ではこの2つのサンプラーの性能を比較する.

他にも,Zig-Zag Sampler with Control Variate (Bierkens et al., 2019), Zig-Zag Sampler with Importance Resampling (Sen et al., 2020) などの方法が提案されており,大標本 \(n\gg1\) に対するスケーラビリティが期待されるが,ここでは扱わない.

どちらかというと,中程度の次元 \(d\approx50\) にスパース性を仮定した設定の下で,サンプラーの性能を比較することを考える.

n, p, pₑ = 200, 50, 10

using Random, StatsFuns, Distributions
β_true = vcat(randn(pₑ), zeros(p - pₑ))
X = randn(n, p)

η_true = X * β_true
π_true = logistic.(η_true)

y = rand.(Bernoulli.(π_true))
y = collect(Float64, y)

以降,次のように進む

  1. Pólya-Gamma 拡大を解説し,PolyaGammaHybridSamplers.jl パッケージを利用して Gibbs サンプラーを実装する.

  2. Turing.jl を用いて HMC サンプラーを実装する方法を解説する.Turing は完全 Julia ベースの確率的プログラミング言語である.このフレームワーク上で比較することで,Julia 言語の JIT コンパイル,型チェック,ガベージコレクションなどの条件をなるべく揃えた比較を狙う.

  3. 最後に次の3点を指標にして,2つのサンプラーを比較検討する.

    1. 事後平均により真の回帰係数 \(\beta\) を推定したときの \(\ell^2\)-誤差
    2. 3000 サンプルあたりの有効サンプル数 (ESS: 独立サンプル何個分に当たるか)
    3. 単位実行時間あたりの有効サンプル数

これら3つの指標に対して,HMC / PG サンプラーはそれぞれ得意・不得意が違う.

一見 HMC の方がサンプルの質が良いが,実は計算時間で割ると PG サンプラーの方が効率が良いことを見る.

従って,我々の設定(中程度の次元 \(d\approx50\) にスパース性を仮定した設定で Gaussian sample of size \(n=200\))においては,PG サンプラーからより多くのサンプルを抽出して Monte Carlo 推定量を構成した方が,分散は小さくなると予想される.

1 Pólya-Gamma 拡大に基づく Gibbs サンプラー

Pólya-Gamma 拡大に基づく Gibbs サンプラーは,次のように定義される:

using PolyaGammaHybridSamplers, LinearAlgebra, MCMCChains, Dates, MCMCDiagnosticTools

function pg_logistic_gibbs(
  X::Matrix{Float64},
  y::Vector{Float64};
  n_samples::Int = 5000,
  burnin::Int = 1000,
  σ_prior::Float64 = 10.0,
)
  n, p = size(X)

  # 事前: β ~ N(0, σ_prior^2 I)
  V0_inv = (1 / σ_prior^2) * LinearAlgebra.I  # precision of prior

  # 初期値
  β = zeros(p)
  κ = y .- 0.5  # κ_i = y_i - 1/2

  # サンプル保存用
  n_iter = n_samples + burnin
  β_samples = Matrix{Float64}(undef, n_samples, p)

  t_start = time()
  for it in 1:n_iter
    # 1. PG 補助変数 ω_i | β のサンプル
    η = X * β
    ω = similar(η)
    for i in 1:n
      pg = PolyaGammaHybridSampler(1.0, η[i])
      ω[i] = rand(pg)
    end

    # 2. β | ω, y のサンプル (多変量ガウス)
    Ω = Diagonal(ω)
    precision = X' * Ω * X + V0_inv          # posterior precision
    cov = inv(Matrix(precision))             # posterior covariance
    m = cov * (X' * κ)                       # posterior mean (μ0=0 のため)

    # β ~ N(m, cov)
    β = rand(MvNormal(m, Symmetric(cov)))

    # burn-in 後に保存
    if it > burnin
      β_samples[it - burnin, :] .= β
    end
  end
  t_stop = time()
  runtime_sec = t_stop - t_start

  names = Symbol.("β[$i]" for i in 1:p)
  values = reshape(β_samples, n_samples, p, 1)
  chain = Chains(values, names)
  chain = setinfo(chain, (
    start_time = [t_start],  # 1本チェインなら長さ1のベクトルでOK
    stop_time  = [t_stop],
  ))

  return chain, runtime_sec
end
σ_prior = 10.0
chain_pg, t_pg = pg_logistic_gibbs(X, y;
    n_samples = 3000,
    burnin = 3000,
    σ_prior = σ_prior,
)
summarize(chain_pg)
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e     Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯
        β[1]    5.0938    1.0469    0.1640    41.4443   170.0661    1.0050     ⋯
        β[2]    0.6911    0.6890    0.0412   279.7676   604.9084    1.0023     ⋯
        β[3]   -0.7582    0.6453    0.0313   432.3889   688.4741    1.0019     ⋯
        β[4]    9.6110    1.6016    0.2042    60.8553   248.3164    1.0094     ⋯
        β[5]    3.2738    0.7971    0.0690   130.5799   429.7454    1.0046     ⋯
        β[6]   -9.2844    1.6218    0.2528    41.1782   193.8003    1.0149     ⋯
        β[7]    6.3970    1.1865    0.1728    47.5656   245.1486    1.0057     ⋯
        β[8]   -1.0989    0.6198    0.0364   298.6133   526.0545    1.0017     ⋯
        β[9]    9.5024    1.5085    0.2210    47.1074   201.7003    1.0131     ⋯
       β[10]    0.9900    0.8524    0.0484   310.7054   715.2504    1.0000     ⋯
       β[11]   -0.2928    0.5996    0.0283   447.7859   819.8250    1.0005     ⋯
       β[12]   -0.1322    0.6403    0.0316   411.7946   878.4689    1.0014     ⋯
       β[13]    1.7252    0.5301    0.0394   184.6827   454.2496    1.0062     ⋯
       β[14]    0.8363    0.5599    0.0315   317.5118   818.9495    1.0030     ⋯
       β[15]    0.4966    0.6398    0.0457   196.0665   421.4479    1.0003     ⋯
       β[16]    0.3339    0.6015    0.0347   299.0341   856.0848    1.0000     ⋯
       β[17]   -1.1701    0.6122    0.0421   211.7786   515.0791    1.0104     ⋯
       β[18]   -4.9203    1.1148    0.1460    56.6453   245.6118    1.0082     ⋯
       β[19]   -0.6173    0.6462    0.0357   342.7770   523.7577    1.0016     ⋯
       β[20]    0.0537    0.6254    0.0417   226.5395   612.0113    1.0134     ⋯
       β[21]   -1.7961    0.8030    0.0550   212.0161   382.7236    1.0001     ⋯
           ⋮         ⋮         ⋮         ⋮          ⋮          ⋮         ⋮     ⋱
                                                    1 column and 29 rows omitted

2 HMC

using Turing, LinearAlgebra

@model function logreg_turing(x, y, σ_prior)
    n, p = size(x)
    
    # 事前分布
    β ~ MvNormal(zeros(p), (σ_prior^2) * I)
    
    # ベクトル化した尤度(高速化)
    η = x * β
    y ~ arraydist(Bernoulli.(logistic.(η)))
end

model = logreg_turing(X, y, σ_prior)
n_samples = 3000
n_adapt   = 3000

chain_hmc = sample(
    model,
    NUTS(n_adapt, 0.6),
    n_samples,
)
summarize(chain_hmc)
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯
        β[1]    4.8971    1.1114    0.0365    930.7941   1588.1634    0.9998   ⋯
        β[2]    0.6469    0.6489    0.0131   2445.7267   2123.6187    1.0008   ⋯
        β[3]   -0.6992    0.6723    0.0139   2355.6245   2262.1939    1.0013   ⋯
        β[4]    9.2684    1.8183    0.0640    822.0774   1250.5017    0.9997   ⋯
        β[5]    3.1613    0.8446    0.0259   1088.2547   1823.5098    1.0003   ⋯
        β[6]   -8.9051    1.7579    0.0635    780.6818   1268.6961    1.0027   ⋯
        β[7]    6.1628    1.3241    0.0459    846.6077   1284.8146    1.0001   ⋯
        β[8]   -1.0952    0.6145    0.0136   2070.1264   2358.9584    0.9997   ⋯
        β[9]    9.1594    1.6972    0.0622    762.3070   1316.5157    1.0000   ⋯
       β[10]    0.9604    0.8599    0.0169   2588.7571   2614.2082    1.0002   ⋯
       β[11]   -0.3022    0.6428    0.0123   2753.3895   2122.3385    1.0001   ⋯
       β[12]   -0.1215    0.6553    0.0123   2834.4546   2363.1820    1.0000   ⋯
       β[13]    1.6817    0.5392    0.0143   1457.7747   2042.3019    0.9999   ⋯
       β[14]    0.8117    0.6077    0.0132   2121.6542   2058.3458    1.0002   ⋯
       β[15]    0.4104    0.6389    0.0139   2137.1197   2438.8069    1.0001   ⋯
       β[16]    0.3791    0.5993    0.0124   2362.2376   2405.8028    0.9997   ⋯
       β[17]   -1.1051    0.6244    0.0158   1567.6081   2023.3167    0.9999   ⋯
       β[18]   -4.6668    1.1094    0.0395    818.3569   1339.5415    1.0016   ⋯
       β[19]   -0.5693    0.6325    0.0123   2687.2431   2121.0722    1.0008   ⋯
       β[20]    0.0208    0.6205    0.0138   2030.4991   2267.4789    1.0009   ⋯
       β[21]   -1.7055    0.7714    0.0204   1440.4471   1931.9390    1.0024   ⋯
           ⋮         ⋮         ⋮         ⋮           ⋮           ⋮         ⋮   ⋱
                                                    1 column and 29 rows omitted

3 結果の比較

using Statistics

# 真の β との誤差
mean_hmc = vec(mean(Array(chain_hmc), dims=1))  # ここは実際のパラメータ名に合わせて調整
mean_pg = vec(mean(Array(chain_pg), dims=1))

println("‖β̂_HMC - β_true‖₂ = ", norm(mean_hmc .- β_true))
println("‖β̂_PG  - β_true‖₂ = ", norm(mean_pg  .- β_true))

# ランタイムや ESS の比較も:
ess_hmc = ess_rhat(chain_hmc)
ess_pg  = ess_rhat(chain_pg)
println("ESS (HMC) = ", mean(ess_hmc[:,:ess]))
println("ESS (PG) = ", mean(ess_pg[:,:ess]))
println("ESS/s (HMC) = ", mean(ess_hmc[:,:ess_per_sec]))
println("ESS/s (PG) = ", mean(ess_pg[:,:ess_per_sec]))
‖β̂_HMC - β_true‖₂ = 17.631999619963004
‖β̂_PG  - β_true‖₂ = 18.435891995080777
ESS (HMC) = 1853.1003569246561
ESS (PG) = 238.70064632682056
ESS/s (HMC) = 56.478021301534724
ESS/s (PG) = 146.71213664832237

References

Bierkens, J., Fearnhead, P., and Roberts, G. (2019). The Zig-Zag Process and Super-Efficient Sampling for Bayesian Analysis of Big Data. The Annals of Statistics, 47(3), 1288–1320.
Duane, S., Kennedy, A. D., Pendleton, B. J., and Roweth, D. (1987). Hybrid monte carlo. Physics Letters B, 195(2), 216–222.
Polson, N. G., Scott, J. G., and Windle, J. (2013). Bayesian inference for logistic models using pólya–gamma latent variables. Journal of the American Statistical Association, 108(504), 1339–1349.
Sen, D., Sachs, M., Lu, J., and Dunson, D. B. (2020). Efficient posterior sampling for high-dimensional imbalanced logistic regression. Biometrika, 107(4), 1005–1012.