Metropolis-Hastings サンプラー
Julia と Turing エコシステムを用いて
離散時間 MCMC から連続時間 MCMC へ
司馬博文
7/12/2024
7/18/2024
A Blog Entry on Bayesian Computation by an Applied Mathematician
$$
$$
応答変数 \(Y\in\mathcal{L}(\Omega;2)\) は,説明変数 \(X\in\mathcal{L}(\Omega;\mathbb{R}^p)\) の関数であるとして,係数 \(\xi\in\mathcal{L}(\Omega;\mathbb{R}^p)\) をロジスティック回帰モデル \[ \operatorname{P}[Y=1\,|\,X,\xi]=g^{-1}(X^\top\xi)=\frac{\exp(X^\top\xi)}{1+\exp(X^\top\xi)} \tag{1}\] を通じて定める.ただし, \[ g(x):=\log\frac{x}{1-x} \] は ロジット関数,\(g^{-1}\) は ロジスティック関数 という.1
このパラメータ \(\xi\) をベイス推定することを考える.即ち,データ \(\{(y^i,x^i)\}_{i=1}^n\subset2\times\mathbb{R}^p\) と事前分布 \(p_0(\xi)d\xi\in\mathcal{P}(\mathbb{R}^p)\) を通じて,事後分布 \[ \pi(\xi)\,\propto\,p_0(\xi)\prod_{i=1}^n\frac{\exp(y^i(x^i)^\top\xi)}{1+\exp((x^i)^\top\xi)}=e^{-U(\xi)} \] を計算することを考える.ただし, \[\begin{align*} U(\xi)&:=-\log p_0(\xi)-\sum_{i=1}^n\log\left(\frac{\exp(y^i(x^i)^\top\xi)}{1+\exp((x^i)^\top\xi)}\right)\\ &=:U_0(\xi)+U_1(\xi) \end{align*}\] と定めた.
ロジットリンクによる変換が複雑であるため,ロジスティック回帰は(完全な)ベイズ推定を実行することが難しいモデルとして知られてきた.
一方で,リンク関数 \(g\) を標準正規分布 \(\mathrm{N}(0,1)\) の分布関数の逆関数に取り替えた プロビットモデル は,Gaussian data augmentation (Albert and Chib, 1993) と呼ばれるデータ拡張に基づく Gibbs サンプラーが早くから提案されており,これにより効率的なベイズ推論が可能となっていた.
プロビットモデルはロジットモデルに似ており,実用上はただ裾の重さが違うのみであると言って良い (Gelman et al., 2014, p. 407).そのこともあり,プロビットモデルのベイズ推論は計量経済学や政治科学で広く使われている手法となったが,ロジットモデルのベイズ推論の応用は遅れた (Polson et al., 2013).
しかし実は,ロジットモデルの事後分布 \(\pi\) も正規分布の Pólya-Gamma 混合として表すことができ,データ拡張によって効率的な Gibbs サンプラーを構成することができる (Polson et al., 2013).現在ではこのデータ拡張 Gibbs サンプラーが,標準的な事後分布サンプラーとなっている.
データもモデルも大規模になっていく現代では,このようなデータ拡張に基づく Gibbs サンプラーは,特定の条件が揃うと極めて収束が遅くなる場面が少なくないことが明らかになってきている.
そのうちの1つのパターンが大規模な 不均衡データ (Johndrow et al., 2019),すなわち,特定のラベルが極めて稀少なカテゴリカルデータである.このようなデータに対しては,プロビットモデルやロジットモデルに限らず,ほとんど全てのデータ拡張に基づく Gibbs サンプラーが低速化することが報告されている:
We have found that this behavior occurs routinely, essentially regardless of the type and complexity of the statistical model, if the data are large and imbalanced. (Johndrow et al., 2019, p. 1395)
ここでは問題を簡単にし,カテゴリーが2つの場合,即ち2値のスパースデータ \(y^i\in2=\{0,1\}\) の場合を考える.
この下で, \[ \sum_{i=1}^n y^i\,\bigg|\,n\sim\mathrm{Bin}(n,g^{-1}(\theta)),\qquad\theta\sim\mathrm{N}(0,B), \] すなわち,モデル (1) において \[ p=1,\qquad X=1,\qquad p_0(\xi)d\xi=\mathrm{N}(a,B), \] とした,説明変数なしの切片項のみでの回帰分析の場合を考える.この場合,ポテンシャルは次のように表される: \[ -U(\xi)=\xi\sum_{i=1}^ny^i-n\log(1+e^{\xi})-\frac{(\xi-a)^2}{2B}-\frac{1}{2}\log2\pi B. \]
ここまで単純化した設定でも,前述の Gibbs サンプラーの収束鈍化が見られることを検証する.ここでは
そして \[ \sum_{i=1}^ny^i=1 \] を保ちながら \(n\to\infty\) として実験するが,(Johndrow et al., 2019) では, \[ \sum_{i=1}^ny^i\ll n \] である大規模不均衡データである限り,\((a,B)\) の値に依らず同様の結果が得られることが報告されている.
Metropolis-Hastings 法は,Turing Institute による Julia の AdvancedMH.jl
パッケージなどを通じて実装することができる:
using AdvancedMH
using Distributions
using MCMCChains
using ForwardDiff
using StructArrays
using LinearAlgebra
using LogDensityProblems
using LogDensityProblemsAD
# Define the components of a basic model.
struct LogTargetDensity_Logistic
a::Float64
B::Float64
n::Int64
end
LogDensityProblems.logdensity(p::LogTargetDensity_Logistic, ξ) = -log(2π * p.B) - (ξ[1] - p.a)^2/(2 * p.B) + ξ[1] - p.n * log(1 + exp(ξ[1]))
LogDensityProblems.dimension(p::LogTargetDensity_Logistic) = 1
LogDensityProblems.capabilities(::Type{LogTargetDensity_Logistic}) = LogDensityProblems.LogDensityOrder{0}()
function MHSampler(n::Int64; discard_initial=30000)
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity_Logistic(a, B, n))
spl = RWMH(MvNormal(zeros(1), I))
chain = sample(model_with_ad, spl, 50000; chain_type=Chains, param_names=["ξ"])
return chain
end
# ξ_vector = MHSampler(10000)
# plot(ξ_vector, title="Plot of \$\\xi\$ values", xlabel="Index", ylabel="ξ", legend=false, color="#78C2AD")
MHSampler (generic function with 1 method)
using DataFrames
using Plots
n_list = [10, 100, 1000, 10000]
elapsed_time_Metropolis = @elapsed begin
chains = [MHSampler(n) for n in n_list]
end
autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]
MHChain = chains
combined_df = vcat(autos..., source=:chain)
lag_columns = names(combined_df)[2:101]
lags = 1:100
p_Metropolis = plot(
title = "Metropolis",
xlabel = "Lag",
ylabel = "Autocorrelation",
legend = :topright,
#background_color = "#F0F1EB"
)
for (i, n) in zip(1:4, n_list)
plot!(
p_Metropolis,
lags,
Array(combined_df[i, lag_columns]),
label = "n = $n",
linewidth = 2
)
end
Sampling: 1%|▍ | ETA: 0:00:10Sampling: 100%|█████████████████████████████████████████| Time: 0:00:00
パッケージ PolyaGammaSamplers
は現在,過去のバージョンの依存関係を必要とするので,グローバルの環境とは分離しておくのが良い.
ここでは,Pólya-Gamma 分布のサンプラーの実装 PolyaGammaSamplers
を参考にして,直接次のように定義する.
using Random
using StatsFuns
struct PolyaGammaPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
b::Int
z::T
end
struct JStarPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
z::T
end
function Base.rand(rng::AbstractRNG, s::PolyaGammaPSWSampler)
out = 0.0
s_aux = JStarPSWSampler(s.z / 2)
for _ in 1:s.b
out += rand(rng, s_aux) / 4
end
return out
end
function Base.rand(rng::AbstractRNG, s::JStarPSWSampler)
z = abs(s.z) # modified to avoid negative z
t = 0.64
μ = 1 / z
k = π^2 / 8 + z^2 / 2
p = (π / 2 / k) * exp(- k * t)
q = 2 * exp( - z) * cdf(InverseGaussian(μ, 1.0), t)
while true
# Simulate a candidate x
u = rand(rng)
v = rand(rng)
if (u < p / (p + q))
# (Truncated Exponential)
e = randexp(rng)
x = t + e / k
else
# (Truncated Inverse Gaussian)
x = randtigauss(rng, z, t)
end
# Evaluate if the candidate should be accepted
s = a_xnt(x, 0, t)
y = v * s
n = 0
while true
n += 1
if (n % 2 == 1)
s += a_xnt(x, n, t)
y > s && break
else
s -= a_xnt(x, n, t)
y < s && return x
end
end
end
end
# Return ``a_n(x)`` for a given t, see [1], eqs. (12)-(13)
# Equations (12)-(13) in [1]
# Note:
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt(x::Real, n::Int, t::Real)
x ≤ t ? a_xnt_left(x, n, t) : a_xnt_right(x, n, t)
end
# Return ``a_n(x)^L`` for a given t
# Equation (12) in [1]
# Note:
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_left(x::Real, n::Int, t::Real)
π * (n + 0.5) * (2 / π / x)^(3 / 2) * exp(- 2 * (n + 0.5)^2 / x)
end
# Return ``a_n(x)^R`` for a given t, see [1], eq. (13)
# Equation (13) in [1]
# Note:
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_right(x::Real, n::Int, t::Real)
π * (n + 0.5) * exp(- (n + 0.5)^2 * π^2 * x / 2)
end
# Simulate from an IG(μ, 1) distribution
# Algorithms 2-3 in [1]'s supplementary material
# Note:
# This is a literal transcription from the article's pseudo code
# except for the letter case
function randtigauss(rng::AbstractRNG, z::Real, t::Real)
1 / z > t ? randtigauss_v1(rng, z, t) : randtigauss_v2(rng, z, t)
end
# Simulate from an IG(μ, 1) distribution, for μ := 1 / z > t;
# Algorithms 2 in [1]'s supplementary material
# Note:
# This is a literal transcription from the article's pseudo code
# except for the letter case and one little a detail: the
# original condition `x > R` must be replaced by `x > t`
function randtigauss_v1(rng::AbstractRNG, z::Real, t::Real)
x = t + one(t)
α = zero(t)
while rand(rng) > α
e = randexp(rng) # In [1]: E
é = randexp(rng) # In [1]: E'
while e^2 > (2 * é / t)
e = randexp(rng)
é = randexp(rng)
end
x = t / (1 + t * e)^2
α = exp(- z^2 * x / 2)
end
return x
end
# Simulate from an IG(μ, 1) distribution, for μ := 1 / z ≤ t
# Algorithms 3 in [1]'s supplementary material
# Note: This is a literal transcription from the article's pseudo code
function randtigauss_v2(rng::AbstractRNG, z::Real, t::Real)
x = t + one(t)
μ = 1 / z
while x > t
y = randn(rng)^2
x = μ + μ^2 * y / 2 - μ * √(4 * μ * y + (μ * y)^2) / 2
if rand(rng) > μ / (μ + x)
x = μ^2 / x
end
end
return x
end
randtigauss_v2 (generic function with 1 method)
# using PolyaGammaSamplers
function PGSampler(n::Int64; discard_initial=30000, iter_number=50000, initial_ξ=0.0, B=100)
λ = 1 - n/2
ξ_list = [initial_ξ]
ω_list = []
while length(ξ_list) < iter_number
ξ = ξ_list[end]
ω_sampler = PolyaGammaPSWSampler(n, ξ)
ω_new = rand(ω_sampler)
push!(ω_list, ω_new)
ξ_sampler = Normal((ω_new + B^(-1))^(-1) * λ, (ω_new + B^(-1))^(-1))
ξ_new = rand(ξ_sampler)
push!(ξ_list, ξ_new)
end
return Chains(ξ_list[discard_initial+1:end])
end
function Distributions.mean(s::PolyaGammaPSWSampler)
s.b * inv(2.0 * s.z) * tanh(s.z / 2.0)
end
function Distributions.var(s::PolyaGammaPSWSampler)
s.b * inv(4 * s.z^3) * (sinh(s.z) - s.z) * (sech(s.z / 2)^2)
end
elapsed_time_PolyaGamma = @elapsed begin
chains = [PGSampler(n) for n in n_list]
end
autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]
PGChain = chains
combined_df = vcat(autos..., source=:chain)
lag_columns = names(combined_df)[2:101]
lags = 1:100
p_PolyaGamma = plot(
title = "Pólya-Gamma",
xlabel = "Lag",
ylabel = "Autocorrelation",
legend = (0.65, 0.35),
#background_color = "#F0F1EB"
)
for (i, n) in zip(1:4, n_list)
plot!(
p_PolyaGamma,
lags,
Array(combined_df[i, lag_columns]),
label = "n = $n",
linewidth = 2,
)
end
Elapsed time: 2.158250875 seconds v.s. 77.111793166 seconds
PG サンプラーは MH 法に比べ恐ろしいほどに時間がかかる.これは,Turing
のパッケージの最適化が優秀であるのか,Pólya-Gamma サンプラーの宿命であるのか,引き続き調べる必要がある.
前節の設定の下で,計算複雑性のオーダーが,Metropolis 法では最悪で \((\log n)^3\) であるのに対し,Gibbs サンプラーでは最高でも \(n^{3/2}(\log n)^{2.5}\) のオーダーになることが (Johndrow et al., 2019) で示されている.
その理由は,提案分布と対象分布のズレに由来することも (Johndrow et al., 2019) は明らかにしている.
\(\sum_{i=1}^ny^i\) の値を固定して \(n\to\infty\) の極限を取った場合,事後分布は次のように,負方向にスライドしながら,幅が狭まっていく.その幅の縮小レートは \(n^{-1/2}\) ではなく,約 \((\log n)^{-1}\) になる.
using StatsPlots
using LaTeXStrings
plot(
plot(MHChain[1], title=L"n=10", color="#78C2AD"),
plot(MHChain[2], title=L"n=100", color="#78C2AD"),
plot(MHChain[3], title=L"n=1000", color="#78C2AD"),
plot(MHChain[4], title=L"n=10000", color="#78C2AD"),
layout=(4,1),
size=(1000, 800),
#background_color = "#F0F1EB"
)
一方で,提案分布は \(\xi_t\) をモードとした場合,\(\xi_{t+1}\) もモードの周りに幅 \(\frac{(\log n)^{3/2}}{n^{1/2}}\) で集中してしまう.すなわち,提案のステップサイズが事後分布のスケールに比べて極めて小さくなってしまう.
plot(
plot(PGChain[1], title=L"n=10", color="#78C2AD"),
plot(PGChain[2], title=L"n=100", color="#78C2AD"),
plot(PGChain[3], title=L"n=1000", color="#78C2AD"),
plot(PGChain[4], title=L"n=10000", color="#78C2AD"),
layout=(4,1),
size=(1000, 800),
#background_color = "#F0F1EB"
)
MH 法でサンプルした事後分布に比べて,より鋭くなっていることがわかるだろう(\(y\) 軸のスケールに注目).
これにより,分布の十分な探索が阻害され,サンプル間の自己相関が高くなってしまうという問題が起こるようである.
ロジスティック回帰において,\(p\ll n\) が満たされない設定下では,最尤推定量(MAP 推定量)がバイアスを持つ (Sur and Candès, 2019).
後編に続く:
\(x\) が確率を表すとき \(\frac{x}{1-x}\) という量はオッズともいい,それゆえ \(p\) のロジットは対数オッズ比ともいう.↩︎