n, p, pₑ = 200, 50, 10
using Random, StatsFuns, Distributions
β_true = vcat(randn(pₑ), zeros(p - pₑ))
X = randn(n, p)
η_true = X * β_true
π_true = logistic.(η_true)
y = rand.(Bernoulli.(π_true))
y = collect(Float64, y)A Blog Entry on Bayesian Computation by an Applied Mathematician
$$
$$
ベイズ・ロジスティック回帰において,MCMC により事後分布サンプリングを行う方法はいくつかある:
- Pólya-Gamma 拡大に基づく Gibbs サンプラー (Polson et al., 2013)
- HMC (Hamiltonian Monte Carlo) (Duane et al., 1987)
本稿ではこの2つのサンプラーの性能を比較する.
他にも,Zig-Zag Sampler with Control Variate (Bierkens et al., 2019), Zig-Zag Sampler with Importance Resampling (Sen et al., 2020) などの方法が提案されており,大標本 \(n\gg1\) に対するスケーラビリティが期待されるが,ここでは扱わない.
どちらかというと,中程度の次元 \(d\approx50\) にスパース性を仮定した設定の下で,サンプラーの性能を比較することを考える.
以降,次のように進む
Pólya-Gamma 拡大を解説し,
PolyaGammaHybridSamplers.jlパッケージを利用して Gibbs サンプラーを実装する.Turing.jlを用いて HMC サンプラーを実装する方法を解説する.Turing は完全 Julia ベースの確率的プログラミング言語である.このフレームワーク上で比較することで,Julia 言語の JIT コンパイル,型チェック,ガベージコレクションなどの条件をなるべく揃えた比較を狙う.最後に次の3点を指標にして,2つのサンプラーを比較検討する.
- 事後平均により真の回帰係数 \(\beta\) を推定したときの \(\ell^2\)-誤差
- 3000 サンプルあたりの有効サンプル数 (ESS: 独立サンプル何個分に当たるか)
- 単位実行時間あたりの有効サンプル数
これら3つの指標に対して,HMC / PG サンプラーはそれぞれ得意・不得意が違う.
一見 HMC の方がサンプルの質が良いが,実は計算時間で割ると PG サンプラーの方が効率が良いことを見る.
従って,我々の設定(中程度の次元 \(d\approx50\) にスパース性を仮定した設定で Gaussian sample of size \(n=200\))においては,PG サンプラーからより多くのサンプルを抽出して Monte Carlo 推定量を構成した方が,分散は小さくなると予想される.
1 Pólya-Gamma 拡大に基づく Gibbs サンプラー
Pólya-Gamma 拡大に基づく Gibbs サンプラーは,次のように定義される:
using PolyaGammaHybridSamplers, LinearAlgebra, MCMCChains, Dates, MCMCDiagnosticTools
function pg_logistic_gibbs(
X::Matrix{Float64},
y::Vector{Float64};
n_samples::Int = 5000,
burnin::Int = 1000,
σ_prior::Float64 = 10.0,
)
n, p = size(X)
# 事前: β ~ N(0, σ_prior^2 I)
V0_inv = (1 / σ_prior^2) * LinearAlgebra.I # precision of prior
# 初期値
β = zeros(p)
κ = y .- 0.5 # κ_i = y_i - 1/2
# サンプル保存用
n_iter = n_samples + burnin
β_samples = Matrix{Float64}(undef, n_samples, p)
t_start = time()
for it in 1:n_iter
# 1. PG 補助変数 ω_i | β のサンプル
η = X * β
ω = similar(η)
for i in 1:n
pg = PolyaGammaHybridSampler(1.0, η[i])
ω[i] = rand(pg)
end
# 2. β | ω, y のサンプル (多変量ガウス)
Ω = Diagonal(ω)
precision = X' * Ω * X + V0_inv # posterior precision
cov = inv(Matrix(precision)) # posterior covariance
m = cov * (X' * κ) # posterior mean (μ0=0 のため)
# β ~ N(m, cov)
β = rand(MvNormal(m, Symmetric(cov)))
# burn-in 後に保存
if it > burnin
β_samples[it - burnin, :] .= β
end
end
t_stop = time()
runtime_sec = t_stop - t_start
names = Symbol.("β[$i]" for i in 1:p)
values = reshape(β_samples, n_samples, p, 1)
chain = Chains(values, names)
chain = setinfo(chain, (
start_time = [t_start], # 1本チェインなら長さ1のベクトルでOK
stop_time = [t_stop],
))
return chain, runtime_sec
endσ_prior = 10.0
chain_pg, t_pg = pg_logistic_gibbs(X, y;
n_samples = 3000,
burnin = 3000,
σ_prior = σ_prior,
)
summarize(chain_pg)parameters mean std mcse ess_bulk ess_tail rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ β[1] 5.0938 1.0469 0.1640 41.4443 170.0661 1.0050 ⋯ β[2] 0.6911 0.6890 0.0412 279.7676 604.9084 1.0023 ⋯ β[3] -0.7582 0.6453 0.0313 432.3889 688.4741 1.0019 ⋯ β[4] 9.6110 1.6016 0.2042 60.8553 248.3164 1.0094 ⋯ β[5] 3.2738 0.7971 0.0690 130.5799 429.7454 1.0046 ⋯ β[6] -9.2844 1.6218 0.2528 41.1782 193.8003 1.0149 ⋯ β[7] 6.3970 1.1865 0.1728 47.5656 245.1486 1.0057 ⋯ β[8] -1.0989 0.6198 0.0364 298.6133 526.0545 1.0017 ⋯ β[9] 9.5024 1.5085 0.2210 47.1074 201.7003 1.0131 ⋯ β[10] 0.9900 0.8524 0.0484 310.7054 715.2504 1.0000 ⋯ β[11] -0.2928 0.5996 0.0283 447.7859 819.8250 1.0005 ⋯ β[12] -0.1322 0.6403 0.0316 411.7946 878.4689 1.0014 ⋯ β[13] 1.7252 0.5301 0.0394 184.6827 454.2496 1.0062 ⋯ β[14] 0.8363 0.5599 0.0315 317.5118 818.9495 1.0030 ⋯ β[15] 0.4966 0.6398 0.0457 196.0665 421.4479 1.0003 ⋯ β[16] 0.3339 0.6015 0.0347 299.0341 856.0848 1.0000 ⋯ β[17] -1.1701 0.6122 0.0421 211.7786 515.0791 1.0104 ⋯ β[18] -4.9203 1.1148 0.1460 56.6453 245.6118 1.0082 ⋯ β[19] -0.6173 0.6462 0.0357 342.7770 523.7577 1.0016 ⋯ β[20] 0.0537 0.6254 0.0417 226.5395 612.0113 1.0134 ⋯ β[21] -1.7961 0.8030 0.0550 212.0161 382.7236 1.0001 ⋯ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱ 1 column and 29 rows omitted
2 HMC
using Turing, LinearAlgebra
@model function logreg_turing(x, y, σ_prior)
n, p = size(x)
# 事前分布
β ~ MvNormal(zeros(p), (σ_prior^2) * I)
# ベクトル化した尤度(高速化)
η = x * β
y ~ arraydist(Bernoulli.(logistic.(η)))
end
model = logreg_turing(X, y, σ_prior)n_samples = 3000
n_adapt = 3000
chain_hmc = sample(
model,
NUTS(n_adapt, 0.6),
n_samples,
)summarize(chain_hmc)parameters mean std mcse ess_bulk ess_tail rhat ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ β[1] 4.8971 1.1114 0.0365 930.7941 1588.1634 0.9998 ⋯ β[2] 0.6469 0.6489 0.0131 2445.7267 2123.6187 1.0008 ⋯ β[3] -0.6992 0.6723 0.0139 2355.6245 2262.1939 1.0013 ⋯ β[4] 9.2684 1.8183 0.0640 822.0774 1250.5017 0.9997 ⋯ β[5] 3.1613 0.8446 0.0259 1088.2547 1823.5098 1.0003 ⋯ β[6] -8.9051 1.7579 0.0635 780.6818 1268.6961 1.0027 ⋯ β[7] 6.1628 1.3241 0.0459 846.6077 1284.8146 1.0001 ⋯ β[8] -1.0952 0.6145 0.0136 2070.1264 2358.9584 0.9997 ⋯ β[9] 9.1594 1.6972 0.0622 762.3070 1316.5157 1.0000 ⋯ β[10] 0.9604 0.8599 0.0169 2588.7571 2614.2082 1.0002 ⋯ β[11] -0.3022 0.6428 0.0123 2753.3895 2122.3385 1.0001 ⋯ β[12] -0.1215 0.6553 0.0123 2834.4546 2363.1820 1.0000 ⋯ β[13] 1.6817 0.5392 0.0143 1457.7747 2042.3019 0.9999 ⋯ β[14] 0.8117 0.6077 0.0132 2121.6542 2058.3458 1.0002 ⋯ β[15] 0.4104 0.6389 0.0139 2137.1197 2438.8069 1.0001 ⋯ β[16] 0.3791 0.5993 0.0124 2362.2376 2405.8028 0.9997 ⋯ β[17] -1.1051 0.6244 0.0158 1567.6081 2023.3167 0.9999 ⋯ β[18] -4.6668 1.1094 0.0395 818.3569 1339.5415 1.0016 ⋯ β[19] -0.5693 0.6325 0.0123 2687.2431 2121.0722 1.0008 ⋯ β[20] 0.0208 0.6205 0.0138 2030.4991 2267.4789 1.0009 ⋯ β[21] -1.7055 0.7714 0.0204 1440.4471 1931.9390 1.0024 ⋯ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱ 1 column and 29 rows omitted
3 結果の比較
using Statistics
# 真の β との誤差
mean_hmc = vec(mean(Array(chain_hmc), dims=1)) # ここは実際のパラメータ名に合わせて調整
mean_pg = vec(mean(Array(chain_pg), dims=1))
println("‖β̂_HMC - β_true‖₂ = ", norm(mean_hmc .- β_true))
println("‖β̂_PG - β_true‖₂ = ", norm(mean_pg .- β_true))
# ランタイムや ESS の比較も:
ess_hmc = ess_rhat(chain_hmc)
ess_pg = ess_rhat(chain_pg)
println("ESS (HMC) = ", mean(ess_hmc[:,:ess]))
println("ESS (PG) = ", mean(ess_pg[:,:ess]))
println("ESS/s (HMC) = ", mean(ess_hmc[:,:ess_per_sec]))
println("ESS/s (PG) = ", mean(ess_pg[:,:ess_per_sec]))‖β̂_HMC - β_true‖₂ = 17.631999619963004
‖β̂_PG - β_true‖₂ = 18.435891995080777
ESS (HMC) = 1853.1003569246561
ESS (PG) = 238.70064632682056
ESS/s (HMC) = 56.478021301534724
ESS/s (PG) = 146.71213664832237