大規模な不均衡データに対するロジスティック回帰(前編)

離散時間 MCMC から連続時間 MCMC へ

Bayesian
Computation
Python
MCMC
Statistics
Author

司馬博文

Published

7/12/2024

Modified

7/18/2024

概要
ロジットモデルやプロビットモデルの事後分布からのサンプリングには,その混合構造を利用したデータ拡張による Gibbs サンプラーが考案されている.しかし,このような Gibbs サンプラーは不明な理由で極めて収束が遅くなることがよく見られ,そのうちの1つのパターンが 大規模な不均衡データ である.この記事では,この現象がなぜ起こるかに関する考察を与え,次稿で代替手法として Zig-Zag サンプラーがうまくいくことをみる.

0.1 ロジットモデルのベイズ推定

応答変数 \(Y\in\mathcal{L}(\Omega;2)\) は,説明変数 \(X\in\mathcal{L}(\Omega;\mathbb{R}^p)\) の関数であるとして,係数 \(\xi\in\mathcal{L}(\Omega;\mathbb{R}^p)\) をロジスティック回帰モデル \[ \operatorname{P}[Y=1\,|\,X,\xi]=g^{-1}(X^\top\xi)=\frac{\exp(X^\top\xi)}{1+\exp(X^\top\xi)} \tag{1}\] を通じて定める.ただし, \[ g(x):=\log\frac{x}{1-x} \]ロジット関数\(g^{-1}\)ロジスティック関数 という.1

このパラメータ \(\xi\) をベイス推定することを考える.即ち,データ \(\{(y^i,x^i)\}_{i=1}^n\subset2\times\mathbb{R}^p\) と事前分布 \(p_0(\xi)d\xi\in\mathcal{P}(\mathbb{R}^p)\) を通じて,事後分布 \[ \pi(\xi)\,\propto\,p_0(\xi)\prod_{i=1}^n\frac{\exp(y^i(x^i)^\top\xi)}{1+\exp((x^i)^\top\xi)}=e^{-U(\xi)} \] を計算することを考える.ただし, \[\begin{align*} U(\xi)&:=-\log p_0(\xi)-\sum_{i=1}^n\log\left(\frac{\exp(y^i(x^i)^\top\xi)}{1+\exp((x^i)^\top\xi)}\right)\\ &=:U_0(\xi)+U_1(\xi) \end{align*}\] と定めた.

0.2 ロジットモデルの事後分布サンプラー

ロジットリンクによる変換が複雑であるため,ロジスティック回帰は(完全な)ベイズ推定を実行することが難しいモデルとして知られてきた.

一方で,リンク関数 \(g\) を標準正規分布 \(\mathrm{N}(0,1)\) の分布関数の逆関数に取り替えた プロビットモデル は,Gaussian data augmentation (Albert and Chib, 1993) と呼ばれるデータ拡張に基づく Gibbs サンプラーが早くから提案されており,これにより効率的なベイズ推論が可能となっていた.

プロビットモデルはロジットモデルに似ており,実用上はただ裾の重さが違うのみであると言って良い (Gelman et al., 2014, p. 407).そのこともあり,プロビットモデルのベイズ推論は計量経済学や政治科学で広く使われている手法となったが,ロジットモデルのベイズ推論の応用は遅れた (Polson et al., 2013)

しかし実は,ロジットモデルの事後分布 \(\pi\) も正規分布の Pólya-Gamma 混合として表すことができ,データ拡張によって効率的な Gibbs サンプラーを構成することができる (Polson et al., 2013).現在ではこのデータ拡張 Gibbs サンプラーが,標準的な事後分布サンプラーとなっている.

0.3 Gibbs サンプラーの課題:不均衡データ

データもモデルも大規模になっていく現代では,このようなデータ拡張に基づく Gibbs サンプラーは,特定の条件が揃うと極めて収束が遅くなる場面が少なくないことが明らかになってきている.

そのうちの1つのパターンが大規模な 不均衡データ (Johndrow et al., 2019),すなわち,特定のラベルが極めて稀少なカテゴリカルデータである.このようなデータに対しては,プロビットモデルやロジットモデルに限らず,ほとんど全てのデータ拡張に基づく Gibbs サンプラーが低速化することが報告されている:

We have found that this behavior occurs routinely, essentially regardless of the type and complexity of the statistical model, if the data are large and imbalanced. (Johndrow et al., 2019, p. 1395)

0.4 実験:不均衡データでの収束鈍化

ここでは問題を簡単にし,カテゴリーが2つの場合,即ち2値のスパースデータ \(y^i\in2=\{0,1\}\) の場合を考える.

この下で, \[ \sum_{i=1}^n y^i\,\bigg|\,n\sim\mathrm{Bin}(n,g^{-1}(\theta)),\qquad\theta\sim\mathrm{N}(0,B), \] すなわち,モデル (1) において \[ p=1,\qquad X=1,\qquad p_0(\xi)d\xi=\mathrm{N}(a,B), \] とした,説明変数なしの切片項のみでの回帰分析の場合を考える.この場合,ポテンシャルは次のように表される: \[ -U(\xi)=\xi\sum_{i=1}^ny^i-n\log(1+e^{\xi})-\frac{(\xi-a)^2}{2B}-\frac{1}{2}\log2\pi B. \]

ここまで単純化した設定でも,前述の Gibbs サンプラーの収束鈍化が見られることを検証する.ここでは

(a,B) = (0,100.0)

そして \[ \sum_{i=1}^ny^i=1 \] を保ちながら \(n\to\infty\) として実験するが,(Johndrow et al., 2019) では, \[ \sum_{i=1}^ny^i\ll n \] である大規模不均衡データである限り,\((a,B)\) の値に依らず同様の結果が得られることが報告されている.

Metropolis-Hastings 法は,Turing Institute による Julia の AdvancedMH.jl パッケージなどを通じて実装することができる:

Article Image

Metropolis-Hastings サンプラー

Julia と Turing エコシステムを用いて

using AdvancedMH
using Distributions
using MCMCChains
using ForwardDiff
using StructArrays
using LinearAlgebra
using LogDensityProblems
using LogDensityProblemsAD

# Define the components of a basic model.
struct LogTargetDensity_Logistic
    a::Float64
    B::Float64
    n::Int64
end

LogDensityProblems.logdensity(p::LogTargetDensity_Logistic, ξ) = -log(2π * p.B) - (ξ[1] - p.a)^2/(2 * p.B) + ξ[1] - p.n * log(1 + exp(ξ[1]))
LogDensityProblems.dimension(p::LogTargetDensity_Logistic) = 1
LogDensityProblems.capabilities(::Type{LogTargetDensity_Logistic}) = LogDensityProblems.LogDensityOrder{0}()

function MHSampler(n::Int64; discard_initial=30000)

    model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity_Logistic(a, B, n))

    spl = RWMH(MvNormal(zeros(1), I))

    chain = sample(model_with_ad, spl, 50000; chain_type=Chains, param_names=["ξ"])

    return chain
end

# ξ_vector = MHSampler(10000)
# plot(ξ_vector, title="Plot of \$\\xi\$ values", xlabel="Index", ylabel="ξ", legend=false, color="#78C2AD")
MHSampler (generic function with 1 method)
using DataFrames
using Plots

n_list = [10, 100, 1000, 10000]

elapsed_time_Metropolis = @elapsed begin
    chains = [MHSampler(n) for n in n_list]
end

autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]

MHChain = chains

combined_df = vcat(autos..., source=:chain)

lag_columns = names(combined_df)[2:101]
lags = 1:100

p_Metropolis = plot(
    title = "Metropolis",
    xlabel = "Lag",
    ylabel = "Autocorrelation",
    legend = :topright,
    #background_color = "#F0F1EB"
)

for (i, n) in zip(1:4, n_list)
    plot!(
        p_Metropolis,
        lags,
        Array(combined_df[i, lag_columns]),
        label = "n = $n",
        linewidth = 2
    )
end
Sampling:   1%|▍                                        |  ETA: 0:00:10Sampling: 100%|█████████████████████████████████████████| Time: 0:00:00

パッケージ PolyaGammaSamplers は現在,過去のバージョンの依存関係を必要とするので,グローバルの環境とは分離しておくのが良い.

ここでは,Pólya-Gamma 分布のサンプラーの実装 PolyaGammaSamplers を参考にして,直接次のように定義する.

using Random
using StatsFuns

struct PolyaGammaPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
    b::Int
    z::T
end

struct JStarPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
    z::T
end

function Base.rand(rng::AbstractRNG, s::PolyaGammaPSWSampler)
    out = 0.0
    s_aux = JStarPSWSampler(s.z / 2)
    for _ in 1:s.b
        out += rand(rng, s_aux) / 4
    end
    return out
end

function Base.rand(rng::AbstractRNG, s::JStarPSWSampler)
    z = abs(s.z)  # modified to avoid negative z
    t = 0.64
    μ = 1 / z
    k = π^2 / 8 + z^2 / 2
    p = (π / 2 / k) * exp(- k * t) 
    q = 2 * exp( - z) * cdf(InverseGaussian(μ, 1.0), t)
    while true
        # Simulate a candidate x
        u = rand(rng)
        v = rand(rng)
        if (u < p / (p + q))
            # (Truncated Exponential)
            e = randexp(rng)
            x = t + e / k
        else
            # (Truncated Inverse Gaussian)
            x = randtigauss(rng, z, t)
        end
        # Evaluate if the candidate should be accepted
        s = a_xnt(x, 0, t)
        y = v * s
        n = 0
        while true
            n += 1
            if (n % 2 == 1)
                s += a_xnt(x, n, t)
                y > s && break
            else
                s -= a_xnt(x, n, t)
                y < s && return x
            end
        end
    end
end

# Return ``a_n(x)`` for a given t, see [1], eqs. (12)-(13)
# Equations (12)-(13) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt(x::Real, n::Int, t::Real)
    x  t ? a_xnt_left(x, n, t) : a_xnt_right(x, n, t)
end

# Return ``a_n(x)^L`` for a given t
# Equation (12) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_left(x::Real, n::Int, t::Real)
    π * (n + 0.5) * (2 / π / x)^(3 / 2) * exp(- 2 * (n + 0.5)^2 / x)
end

# Return ``a_n(x)^R`` for a given t, see [1], eq. (13)
# Equation (13) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_right(x::Real, n::Int, t::Real)
    π * (n + 0.5) * exp(- (n + 0.5)^2 * π^2 * x / 2)
end

# Simulate from an IG(μ, 1) distribution
# Algorithms 2-3 in [1]'s supplementary material
# Note: 
# This is a literal transcription from the article's pseudo code
# except for the letter case
function randtigauss(rng::AbstractRNG, z::Real, t::Real)
    1 / z > t ? randtigauss_v1(rng, z, t) : randtigauss_v2(rng, z, t)
end

# Simulate from an IG(μ, 1) distribution, for μ := 1 / z > t;
# Algorithms 2 in [1]'s supplementary material
# Note:
# This is a literal transcription from the article's pseudo code
# except for the letter case and one little a detail: the 
# original condition  `x > R` must be replaced by `x > t`
function randtigauss_v1(rng::AbstractRNG, z::Real, t::Real)
    x = t + one(t)
    α = zero(t)
    while rand(rng) > α
        e = randexp(rng) # In [1]: E 
        é = randexp(rng) # In [1]: E'
        while e^2 > (2 * é / t)
            e = randexp(rng)
            é = randexp(rng)
        end
        x = t / (1 + t * e)^2 
        α = exp(- z^2 * x / 2)
    end
    return x
end

# Simulate from an IG(μ, 1) distribution, for μ := 1 / z ≤ t
# Algorithms 3 in [1]'s supplementary material
# Note: This is a literal transcription from the article's pseudo code
function randtigauss_v2(rng::AbstractRNG, z::Real, t::Real)
    x = t + one(t)
    μ = 1 / z
    while x > t 
        y = randn(rng)^2
        x = μ + μ^2 * y / 2 - μ * (4 * μ * y +* y)^2) / 2
        if rand(rng) > μ /+ x)
            x = μ^2 / x
        end
    end
    return x
end
randtigauss_v2 (generic function with 1 method)
# using PolyaGammaSamplers

function PGSampler(n::Int64; discard_initial=30000, iter_number=50000, initial_ξ=0.0, B=100)

    λ = 1 - n/2

    ξ_list = [initial_ξ]
    ω_list = []

    while length(ξ_list) < iter_number
        ξ = ξ_list[end]
        ω_sampler = PolyaGammaPSWSampler(n, ξ)
        ω_new = rand(ω_sampler)
        push!(ω_list, ω_new)
        ξ_sampler = Normal((ω_new + B^(-1))^(-1) * λ, (ω_new + B^(-1))^(-1))
        ξ_new = rand(ξ_sampler)
        push!(ξ_list, ξ_new)
    end

    return Chains(ξ_list[discard_initial+1:end])
end

function Distributions.mean(s::PolyaGammaPSWSampler)
    s.b * inv(2.0 * s.z) * tanh(s.z / 2.0)
end

function Distributions.var(s::PolyaGammaPSWSampler)
    s.b * inv(4 * s.z^3) * (sinh(s.z) - s.z) * (sech(s.z / 2)^2)
end
elapsed_time_PolyaGamma = @elapsed begin
    chains = [PGSampler(n) for n in n_list]
end
autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]

PGChain = chains

combined_df = vcat(autos..., source=:chain)

lag_columns = names(combined_df)[2:101]
lags = 1:100

p_PolyaGamma = plot(
    title = "Pólya-Gamma",
    xlabel = "Lag",
    ylabel = "Autocorrelation",
    legend = (0.65, 0.35),
    #background_color = "#F0F1EB"
)

for (i, n) in zip(1:4, n_list)
    plot!(
        p_PolyaGamma,
        lags,
        Array(combined_df[i, lag_columns]),
        label = "n = $n",
        linewidth = 2,
    )
end
println("Elapsed time: $elapsed_time_Metropolis seconds v.s. $elapsed_time_PolyaGamma seconds")
Elapsed time: 2.158250875 seconds v.s. 77.111793166 seconds

PG サンプラーは MH 法に比べ恐ろしいほどに時間がかかる.これは,Turing のパッケージの最適化が優秀であるのか,Pólya-Gamma サンプラーの宿命であるのか,引き続き調べる必要がある.

Code
plot(p_Metropolis, p_PolyaGamma, layout=(1,2), #background_color = "#F0F1EB"
)
#savefig("Logistic_WhiteBackground.svg")

0.5 理論:収束鈍化の理由

前節の設定の下で,計算複雑性のオーダーが,Metropolis 法では最悪で \((\log n)^3\) であるのに対し,Gibbs サンプラーでは最高でも \(n^{3/2}(\log n)^{2.5}\) のオーダーになることが (Johndrow et al., 2019) で示されている.

その理由は,提案分布と対象分布のズレに由来することも (Johndrow et al., 2019) は明らかにしている.

\(\sum_{i=1}^ny^i\) の値を固定して \(n\to\infty\) の極限を取った場合,事後分布は次のように,負方向にスライドしながら,幅が狭まっていく.その幅の縮小レートは \(n^{-1/2}\) ではなく,約 \((\log n)^{-1}\) になる.

using StatsPlots
using LaTeXStrings

plot(
    plot(MHChain[1], title=L"n=10", color="#78C2AD"),
    plot(MHChain[2], title=L"n=100", color="#78C2AD"),
    plot(MHChain[3], title=L"n=1000", color="#78C2AD"),
    plot(MHChain[4], title=L"n=10000", color="#78C2AD"),
    layout=(4,1),
    size=(1000, 800),
    #background_color = "#F0F1EB"
)

一方で,提案分布は \(\xi_t\) をモードとした場合,\(\xi_{t+1}\) もモードの周りに幅 \(\frac{(\log n)^{3/2}}{n^{1/2}}\) で集中してしまう.すなわち,提案のステップサイズが事後分布のスケールに比べて極めて小さくなってしまう.

plot(
    plot(PGChain[1], title=L"n=10", color="#78C2AD"),
    plot(PGChain[2], title=L"n=100", color="#78C2AD"),
    plot(PGChain[3], title=L"n=1000", color="#78C2AD"),
    plot(PGChain[4], title=L"n=10000", color="#78C2AD"),
    layout=(4,1),
    size=(1000, 800),
    #background_color = "#F0F1EB"
)

MH 法でサンプルした事後分布に比べて,より鋭くなっていることがわかるだろう(\(y\) 軸のスケールに注目).

これにより,分布の十分な探索が阻害され,サンプル間の自己相関が高くなってしまうという問題が起こるようである.

0.6 比例的高次元では MAP にバイアスが残る

ロジスティック回帰において,\(p\ll n\) が満たされない設定下では,最尤推定量(MAP 推定量)がバイアスを持つ (Sur and Candès, 2019)

後編に続く:

References

Albert, J. H., and Chib, S. (1993). Bayesian analysis of binary and polychotomous response data. Journal of the American Statistical Association, 88(422), 669–679.
Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., and Rubin, D. B. (2014). Bayesian data analysis. Boca Raton : CRC Press.
Johndrow, J. E., Smith, A., Pillai, N., and Dunson, D. B. (2019). MCMC for imbalanced categorical data. Journal of the American Statistical Association, 114(527), 1394–1403.
Polson, N. G., Scott, J. G., and Windle, J. (2013). Bayesian inference for logistic models using pólya–gamma latent variables. Journal of the American Statistical Association, 108(504), 1339–1349.
Sur, P., and Candès, E. J. (2019). A Modern Maximum-Likelihood Theory for High-Dimensional Logistic Regression. Proceedings of the National Academy of Sciences, 116(29), 14516–14525.

Footnotes

  1. \(x\) が確率を表すとき \(\frac{x}{1-x}\) という量はオッズともいい,それゆえ \(p\) のロジットは対数オッズ比ともいう.↩︎