大規模な不均衡データに対するロジスティック回帰(前編)

離散時間 MCMC から連続時間 MCMC へ

Bayesian
Computation
Python
MCMC
Statistics
Author

司馬博文

Published

7/12/2024

Modified

10/07/2024

概要
ロジットモデルやプロビットモデルの事後分布からのサンプリングには,その混合構造を利用したデータ拡張による Gibbs サンプラーが考案されている.しかし,このような Gibbs サンプラーは不明な理由で極めて収束が遅くなることがよく見られ,そのうちの1つのパターンが 大規模な不均衡データ である.この記事では,この現象がなぜ起こるかに関する考察を与え,次稿で代替手法として Zig-Zag サンプラーがうまくいくことをみる.

1 二項回帰モデル

1.1 はじめに

独立標本 \([n]:=\{1,\cdots,n\}\) 上に,2値データ \(y_i\in\{0,1\}=2\) が観測されているとする.\(y_i\) は例えば症状の有無などのように,2つのカテゴリーのみを持つデータである.

\(x_i\in\mathbb{R}^p\)\(p\) 次元の実ベクトルとし,これを説明変数として \(y_i\) をよく予測するモデルを作りたいとする.

予測確率を最大限上げることよりも,\(x_i\)\(p\) 個の成分のうちどの成分が \(y_i\) に影響を与えているかを知りたい場合,機械学習的な方法よりも,モデルを立てて推定する統計的な方法が適している (Breiman, 2001)

よく使われる統計モデルは,\(y_i=1\) となる確率を予測するという 二項回帰モデル (binary regression model) である: \[ \operatorname{P}[Y=1|X=x]=F(\xi^\top x). \tag{1}\]

ただし,\(F\) は分布関数とする.

これは \(Y\sim\mathrm{Ber}(\mu)\) というモデルの平均パラメータ \(\mu\) に対して,リンク関数 \(G:=F^{-1}\) を通じて線型予測をする 一般化線型モデル の一種である.

二項モデルの尤度

尤度は次のように表せる: \[ p(\{x_i,y_i\}|\xi)=\prod_{i=1}^n\biggr(y_iF(\xi^\top x_i)+(1-y_i)\left(1-F(\xi^\top x_i)\right)\biggl). \]

仮に \(y_i\in\{\pm1\}\) であった場合は,次のような簡単な表記もできる: \[ p(\{x_i,y_i\}|\xi)=\prod_{i=1}^nF(y_i\xi^\top x_i). \]

1.2 潜在変数モデルとしての解釈

二項回帰模型 (1) は,\(F\) を分布関数にもつ誤差項 \(\epsilon\) を通じて, \[ Y=1_{\left\{Y^*>0\right\}},\qquad Y^*=\xi^\top X+\epsilon,\;\epsilon\sim F, \tag{2}\] によって \(Y\) が観測されていると解釈することもできる.

この解釈は計量経済学において 二項選択モデル とも呼ばれる (Hansen, 2022, p. 804)

この \(Y^*\) という潜在変数を考慮することはベイズ計算上も大変有意義である.

\((\xi,Y^*)\) 上の拡張されたパラメータ空間上で事後分布推定をし,その後 \(\xi\) 上の周辺分布を求める方法は データ拡張 と呼ばれる.

問題の次元を上げているため一見非効率に思えるが,式 (2) が人間にとって読みやすいように,その分条件付き確率の構造が単純になっており,Gibbs サンプラーの構成が可能になる.

1.3 ロジットモデル

\(F\) を標準ロジスティック分布関数 \[ F(x)=\frac{1}{1+e^{-x}},\qquad G(x)=F^{-1}(x)=\log\frac{x}{1-x}, \] と取った場合,\(G\)ロジット,このモデルは ロジットモデル と呼ばれる: \[ \operatorname{P}[Y=1\,|\,X=x]=F(\xi^\top x)=\frac{\exp(\xi^\top x)}{1+\exp(\xi^\top x)}. \tag{3}\]

ロジット \(G(\mu)\)\(\mu\)対数オッズ ともいい,ロジットモデルは対数オッズに線型モデルをおいているとも解釈される.

このとき,\(e^0=1\) であることを利用すれば尤度は次のように表示できる: \[ p(\{x_i,y_i\}|\xi)=\prod_{i=1}^n\frac{\exp(y_i\xi^\top x_i)}{1+\exp(\xi^\top x_i)}. \]

(Rasch, 1960) のモデル

\[ \operatorname{P}[Y=1\,|\,X=x,B=b,A=a]=F(ax-b)=\biggr(1+e^{b-ax}\biggl)^{-1} \] とした場合を 2母数応答モデル (2PLM: two-parameter logistic model) という.

このようなモデルは項目反応理論で考えられる.

ロジットモデルとの違いは,\(x\) も所与ではなく推定するパラメータである点である.\(x\) を能力パラメータ,\(a\) を項目識別パラメータ,\(b\) を難易度パラメータなどという.詳しくは次稿参照:

2 ロジットモデルのベイズ推定法

2.1 はじめに

ロジットモデルは比較的単純であるが,多くの深い性質を持っている.

まず \(p\ll n\) の仮定が満たされない場合,最尤推定量はバイアスを持つ (Sur and Candès, 2019)

比例的高次元極限では従来の漸近論は成り立たず,尤度比検定も実行不可能になる.

このように \(p\) の数が大きい場合などはベイズ推論が志向される理由がある (Firth, 1993), (Gelman et al., 2008)

2.2 ロジットモデルの事後分布

ロジットモデルではパラメータ \(\xi\) をベイズ推定することを考える.

即ち,データ \(\{(y_i,x_i)\}_{i=1}^n\subset2\times\mathbb{R}^p\) と事前分布 \(p(\xi)d\xi\in\mathcal{P}(\mathbb{R}^p)\) を通じて,事後分布 \[\begin{align*} p(\xi|\{y_i,x_i\})&\,\propto\,p(\xi)p(\{y_i,x_i\}_{i=1}^n|\xi)\\ &= p(\xi)\prod_{i=1}^n\frac{\exp(y_i(x_i)^\top\xi)}{1+\exp((x_i)^\top\xi)}=:e^{-U(\xi)} \end{align*}\] を計算することを考える.

事後分布は正規化定数を除いて \[ p(\xi|\{y_i,x_i\})\,\propto\,e^{-U(\xi)} \] と得られたことになる.

\(U\)ポテンシャル と呼び,次のように表せる: \[\begin{align*} U(\xi)&:=-\log p_0(\xi)-\sum_{i=1}^n\log\left(\frac{\exp(y_i(x_i)^\top\xi)}{1+\exp((x_i)^\top\xi)}\right)\\ &=:U_0(\xi)+U_1(\xi). \end{align*}\]

2.3 ロジットモデルの事後分布サンプラー

ロジットリンク \(G\) による変換が複雑であるため,ロジスティック回帰は(完全な)ベイズ推定を実行することが難しいモデルとして知られてきた.

一方でリンク関数 \(g\) を標準正規分布 \(\mathrm{N}(0,1)\) の分布関数 \(\Phi\) の逆関数に取り替えた プロビットモデル にはデータ拡張に基づく Gibbs サンプラーが早くから提案されており (Albert and Chib, 1993),これにより効率的なベイズ推論が可能となっていた.

プロビットモデルはロジットモデルに極めて似ており,ただ裾の重さが違うのみであると言って良い (Gelman et al., 2014, p. 407).そのこともあり,プロビットモデルのベイズ推論は計量経済学や政治科学で広く使われている手法となったが,ロジットモデルのベイズ推論の応用は遅れていた (Polson et al., 2013)

しかし実はロジットモデルの事後分布 \(\pi\) も正規分布の Pólya-Gamma 混合として表すことができ,データ拡張によって効率的な Gibbs サンプラーを構成することができる (Polson et al., 2013).現在ではこのデータ拡張 Gibbs サンプラーが,標準的な事後分布サンプラーとなっている.

2.4 Gibbs サンプラーの課題:不均衡データ

データもモデルも大規模になっていく現代では,このようなデータ拡張に基づく Gibbs サンプラーは特定の条件が揃うと極めて収束が遅くなる場面が少なくないことが明らかになってきている.

そのうちの1つのパターンが大規模な 不均衡データ (Johndrow et al., 2019),すなわち,特定のラベルが極めて稀少なカテゴリカルデータである.

このようなデータに対しては,プロビットモデルやロジットモデルに限らず,ほとんど全てのデータ拡張に基づく Gibbs サンプラーが低速化することが報告されている:

We have found that this behavior occurs routinely, essentially regardless of the type and complexity of the statistical model, if the data are large and imbalanced. (Johndrow et al., 2019, p. 1395)

2.5 実験:不均衡データでの収束鈍化

\(p=1\) 次元のプロビットモデル \[ \sum_{i=1}^n y_i\,\bigg|\,(n,\theta)\sim\mathrm{Bin}(n,F(\theta)),\qquad\theta\sim\mathrm{N}_1(0,B),B=100, \] しかも説明変数なしの切片項のみの状況を考える: \[ x_i\equiv1,\qquad p_0(\xi)d\xi=\mathrm{N}(a,B),\qquad a=0, \]

この場合,ポテンシャルは次のように表される: \[ -U(\xi)=\xi\sum_{i=1}^ny_i-n\log(1+e^{\xi})-\frac{(\xi-a)^2}{2B}-\frac{1}{2}\log2\pi B. \]

ここまで単純化した設定でも,観測 \(y_i\) のカテゴリが不均衡ならば,前述の Gibbs サンプラーの収束鈍化が見られることを検証する: \[ \sum_{i=1}^ny_i=1. \]

Metropolis-Hastings 法は,Turing Institute による Julia の AdvancedMH.jl パッケージなどを通じて実装することができる:

using AdvancedMH
using Distributions
using MCMCChains
using ForwardDiff
using StructArrays
using LinearAlgebra
using LogDensityProblems
using LogDensityProblemsAD

(a,B) = (0,100.0)

# Define the components of a basic model.
struct LogTargetDensity_Logistic
    a::Float64
    B::Float64
    n::Int64
end

LogDensityProblems.logdensity(p::LogTargetDensity_Logistic, ξ) = -log(2π * p.B) - (ξ[1] - p.a)^2/(2 * p.B) + ξ[1] - p.n * log(1 + exp(ξ[1]))
LogDensityProblems.dimension(p::LogTargetDensity_Logistic) = 1
LogDensityProblems.capabilities(::Type{LogTargetDensity_Logistic}) = LogDensityProblems.LogDensityOrder{0}()

function MHSampler(n::Int64; discard_initial=30000)

    model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity_Logistic(a, B, n))

    spl = RWMH(MvNormal(zeros(1), I))

    chain = sample(model_with_ad, spl, 50000; chain_type=Chains, param_names=["ξ"])

    return chain
end

# ξ_vector = MHSampler(10000)
# plot(ξ_vector, title="Plot of \$\\xi\$ values", xlabel="Index", ylabel="ξ", legend=false, color="#78C2AD")
MHSampler (generic function with 1 method)
using DataFrames
using Plots

n_list = [10, 100, 1000, 10000]

elapsed_time_Metropolis = @elapsed begin
    chains = [MHSampler(n) for n in n_list]
end

autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]

MHChain = chains

combined_df = vcat(autos..., source=:chain)

lag_columns = names(combined_df)[2:101]
lags = 1:100

p_Metropolis = plot(
    title = "Metropolis",
    xlabel = "Lag",
    ylabel = "Autocorrelation",
    legend = :topright,
    #background_color = "#F0F1EB"
)

for (i, n) in zip(1:4, n_list)
    plot!(
        p_Metropolis,
        lags,
        Array(combined_df[i, lag_columns]),
        label = "n = $n",
        linewidth = 2
    )
end

パッケージ PolyaGammaSamplers は現在,過去のバージョンの依存関係を必要とするので,グローバルの環境とは分離しておくのが良い.

ここでは,Pólya-Gamma 分布のサンプラーの実装 PolyaGammaSamplers を参考にして,直接次のように定義する.

using Random
using StatsFuns

struct PolyaGammaPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
    b::Int
    z::T
end

struct JStarPSWSampler{T <: Real} <: Sampleable{Univariate, Continuous}
    z::T
end

function Base.rand(rng::AbstractRNG, s::PolyaGammaPSWSampler)
    out = 0.0
    s_aux = JStarPSWSampler(s.z / 2)
    for _ in 1:s.b
        out += rand(rng, s_aux) / 4
    end
    return out
end

function Base.rand(rng::AbstractRNG, s::JStarPSWSampler)
    z = abs(s.z)  # modified to avoid negative z
    t = 0.64
    μ = 1 / z
    k = π^2 / 8 + z^2 / 2
    p = (π / 2 / k) * exp(- k * t) 
    q = 2 * exp( - z) * cdf(InverseGaussian(μ, 1.0), t)
    while true
        # Simulate a candidate x
        u = rand(rng)
        v = rand(rng)
        if (u < p / (p + q))
            # (Truncated Exponential)
            e = randexp(rng)
            x = t + e / k
        else
            # (Truncated Inverse Gaussian)
            x = randtigauss(rng, z, t)
        end
        # Evaluate if the candidate should be accepted
        s = a_xnt(x, 0, t)
        y = v * s
        n = 0
        while true
            n += 1
            if (n % 2 == 1)
                s += a_xnt(x, n, t)
                y > s && break
            else
                s -= a_xnt(x, n, t)
                y < s && return x
            end
        end
    end
end

# Return ``a_n(x)`` for a given t, see [1], eqs. (12)-(13)
# Equations (12)-(13) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt(x::Real, n::Int, t::Real)
    x  t ? a_xnt_left(x, n, t) : a_xnt_right(x, n, t)
end

# Return ``a_n(x)^L`` for a given t
# Equation (12) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_left(x::Real, n::Int, t::Real)
    π * (n + 0.5) * (2 / π / x)^(3 / 2) * exp(- 2 * (n + 0.5)^2 / x)
end

# Return ``a_n(x)^R`` for a given t, see [1], eq. (13)
# Equation (13) in [1]
# Note: 
# This is a literal transcription from the article's formula
# except for the letter case
function a_xnt_right(x::Real, n::Int, t::Real)
    π * (n + 0.5) * exp(- (n + 0.5)^2 * π^2 * x / 2)
end

# Simulate from an IG(μ, 1) distribution
# Algorithms 2-3 in [1]'s supplementary material
# Note: 
# This is a literal transcription from the article's pseudo code
# except for the letter case
function randtigauss(rng::AbstractRNG, z::Real, t::Real)
    1 / z > t ? randtigauss_v1(rng, z, t) : randtigauss_v2(rng, z, t)
end

# Simulate from an IG(μ, 1) distribution, for μ := 1 / z > t;
# Algorithms 2 in [1]'s supplementary material
# Note:
# This is a literal transcription from the article's pseudo code
# except for the letter case and one little a detail: the 
# original condition  `x > R` must be replaced by `x > t`
function randtigauss_v1(rng::AbstractRNG, z::Real, t::Real)
    x = t + one(t)
    α = zero(t)
    while rand(rng) > α
        e = randexp(rng) # In [1]: E 
        é = randexp(rng) # In [1]: E'
        while e^2 > (2 * é / t)
            e = randexp(rng)
            é = randexp(rng)
        end
        x = t / (1 + t * e)^2 
        α = exp(- z^2 * x / 2)
    end
    return x
end

# Simulate from an IG(μ, 1) distribution, for μ := 1 / z ≤ t
# Algorithms 3 in [1]'s supplementary material
# Note: This is a literal transcription from the article's pseudo code
function randtigauss_v2(rng::AbstractRNG, z::Real, t::Real)
    x = t + one(t)
    μ = 1 / z
    while x > t 
        y = randn(rng)^2
        x = μ + μ^2 * y / 2 - μ * (4 * μ * y +* y)^2) / 2
        if rand(rng) > μ /+ x)
            x = μ^2 / x
        end
    end
    return x
end
randtigauss_v2 (generic function with 1 method)
# using PolyaGammaSamplers

function PGSampler(n::Int64; discard_initial=30000, iter_number=50000, initial_ξ=0.0, B=100)

    λ = 1 - n/2

    ξ_list = [initial_ξ]
    ω_list = []

    while length(ξ_list) < iter_number
        ξ = ξ_list[end]
        ω_sampler = PolyaGammaPSWSampler(n, ξ)
        ω_new = rand(ω_sampler)
        push!(ω_list, ω_new)
        ξ_sampler = Normal((ω_new + B^(-1))^(-1) * λ, (ω_new + B^(-1))^(-1))
        ξ_new = rand(ξ_sampler)
        push!(ξ_list, ξ_new)
    end

    return Chains(ξ_list[discard_initial+1:end])
end

function Distributions.mean(s::PolyaGammaPSWSampler)
    s.b * inv(2.0 * s.z) * tanh(s.z / 2.0)
end

function Distributions.var(s::PolyaGammaPSWSampler)
    s.b * inv(4 * s.z^3) * (sinh(s.z) - s.z) * (sech(s.z / 2)^2)
end
elapsed_time_PolyaGamma = @elapsed begin
    chains = [PGSampler(n) for n in n_list]
end
autos = [DataFrame(autocor(chain, lags=1:100)) for chain in chains]

PGChain = chains

combined_df = vcat(autos..., source=:chain)

lag_columns = names(combined_df)[2:101]
lags = 1:100

p_PolyaGamma = plot(
    title = "Pólya-Gamma",
    xlabel = "Lag",
    ylabel = "Autocorrelation",
    legend = (0.65, 0.35),
    #background_color = "#F0F1EB"
)

for (i, n) in zip(1:4, n_list)
    plot!(
        p_PolyaGamma,
        lags,
        Array(combined_df[i, lag_columns]),
        label = "n = $n",
        linewidth = 2,
    )
end
println("Elapsed time: $elapsed_time_Metropolis seconds v.s. $elapsed_time_PolyaGamma seconds")
Elapsed time: 2.0604135 seconds v.s. 79.162613875 seconds

PG サンプラーは MH 法に比べ恐ろしいほどに時間がかかる.これは,Turing のパッケージの最適化が優秀であるのか,Pólya-Gamma サンプラーの宿命であるのか,引き続き調べる必要がある.

2.6 理論:収束鈍化の理由

前節の \(p=1\) の不均衡プロビットモデルの下での計算複雑性のオーダーは,

  • Metropolis 法では最悪で \((\log n)^3\)
  • Gibbs サンプラーでは最高でも \(n^{3/2}(\log n)^{2.5}\)

になることが示されている (Section 4 Johndrow et al., 2019)

その理由は,提案分布と対象分布のズレに由来することも (Johndrow et al., 2019) は明らかにしている.

\(\sum_{i=1}^ny_i\) の値を固定して \(n\to\infty\) の極限を取った場合,事後分布は次のように負方向にスライドしながら,幅が狭まっていく.

その幅の縮小レートは \(n^{-1/2}\) ではなく,約 \((\log n)^{-1}\) になる.

一方で提案分布は \(\xi_t\) をモードとした場合,\(\xi_{t+1}\) もモードの周りに幅 \(\frac{(\log n)^{3/2}}{n^{1/2}}\) で集中してしまう.すなわち,提案のステップサイズが事後分布のスケールに比べて極めて小さくなってしまう.

MH 法でサンプルした事後分布に比べて,より鋭くなっていることがわかるだろう(\(y\) 軸のスケールに注目).

Gibbs サンプラーではステップサイズがあまりに小さくなってしまい,分布の十分な探索が阻害され,サンプル間の自己相関が高くなってしまうという問題が起こるようである.

3 Zig-Zag サンプラーによる解決

後編に続く:

4 文献紹介

ロジスティック回帰やプロビット回帰などの2項回帰モデルは,広くベイズ計算手法のベンチマークとして用いられるモデルである (Chopin and Ridgway, 2017)

二項回帰モデルにおいて観測 \(y_{i}\) の値が激しく偏っていた場合,Gibbs サンプラーの収束に問題が起きることは (Johndrow et al., 2019) が指摘している.

この問題が Zig-Zag サンプラーを用いれば,重点荷重+ミニバッチサブサンプリングのアイデアで比較的簡単に解決できることが (Sen et al., 2020) で実証された.

References

Albert, J. H., and Chib, S. (1993). Bayesian analysis of binary and polychotomous response data. Journal of the American Statistical Association, 88(422), 669–679.
Breiman, L. (2001). Statistical Modeling: The Two Cultures (with comments and a rejoinder by the author). Statistical Science, 16(3), 199–231.
Chopin, N., and Ridgway, J. (2017). Leave Pima Indians Alone: Binary Regression as a Benchmark for Bayesian Computation. Statistical Science, 32(1), 64–87.
Firth, D. (1993). Bias reduction of maximum likelihood estimates. Biometrika, 80(1), 27–38.
Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., and Rubin, D. B. (2014). Bayesian data analysis. Boca Raton : CRC Press.
Gelman, A., Jakulin, A., Pittau, M. G., and Su, Y.-S. (2008). A weakly informative default prior distribution for logistic and other regression models. The Annals of Applied Statistics, 2(4), 1360–1383.
Hansen, B. E. (2022). Econometrics. Princeton University Press.
Johndrow, J. E., Smith, A., Pillai, N., and Dunson, D. B. (2019). MCMC for imbalanced categorical data. Journal of the American Statistical Association, 114(527), 1394–1403.
Polson, N. G., Scott, J. G., and Windle, J. (2013). Bayesian inference for logistic models using pólya–gamma latent variables. Journal of the American Statistical Association, 108(504), 1339–1349.
Rasch, G. W. (1960). Studies in mathematical psychology: I. Probabilistic models for some intelligence and attainment tests. Nielsen & Lydiche.
Sen, D., Sachs, M., Lu, J., and Dunson, D. B. (2020). Efficient posterior sampling for high-dimensional imbalanced logistic regression. Biometrika, 107(4), 1005–1012.
Sur, P., and Candès, E. J. (2019). A Modern Maximum-Likelihood Theory for High-Dimensional Logistic Regression. Proceedings of the National Academy of Sciences, 116(29), 14516–14525.