using ZigZagBoomerang
using Distributions
using Random
using LinearAlgebra
using Statistics # just for sure
using StatsFuns
"""
∇U(i,j,ξ,x,y)
i ∈ [d]: 次元を表すインデックス
j ∈ [n]: サンプル番号を表すインデックス
ξ: パラメータ空間 R^d 上の位置
他,観測 (x,y) を引数にとる.
この関数を実装する際,log の中身をそのまま計算しようとすると大変大きくなり,数値的に不安定になる(除算の後は 1 近くになるはずだが,Inf になってしまう)
"""
∇U(i::Int64, j::Int64, ξ, x::Matrix{Float64}, y::Vector{Float64}) = length(y) * x[i,j] * (logistic(dot(x[:,j],ξ)) - y[j])
"""
∇U(i,ξ,x,y):∇U(i,j,ξ,x,y) を全データ j ∈ [n] について足し合わせたもの
i ∈ [d]: 次元を表すインデックス
ξ: パラメータ空間 R^d 上の位置
他,観測 (x,y) を引数にとる.
"""
function ∇U(i::Int64, ξ, x::Matrix{Float64}, y::Vector{Float64})
n = length(y)
U_list = []
for j in 1:n
push!(U_list, ∇U(i, j, ξ, x, y))
end
return mean(U_list)
end
function ∇U(ξ, x::Matrix{Float64}, y::Vector{Float64}) # 1次元の場合のショートカット
return ∇U(1, ξ, x, y)
end
pos(x) = max(zero(x), x)
"""
λ(i, ξ, θ, ∇U, x, y):第 i ∈ [d] 次元のレート関数
i ∈ [d]: 次元を表すインデックス
(ξ,θ): E 上の座標
∇U
(x,y): 観測
"""
λ(i::Int64, ξ, θ, ∇U, x, y) = pos(θ[i] * ∇U(i, ξ, x, y))
λ(ξ, θ, ∇U, x, y) = pos(θ * ∇U(ξ, x, y)) # 1次元の場合のショートカット
"""
λ(τ, a, b):代理レート関数の時刻 τ における値
τ: 時間
a,b: 1次関数の係数
"""
λ_bar(τ, a, b) = pos(a + b*τ)
"""
`x`: current location, `θ`: current velocity, `t`: current time,
"""
function move_forward(τ, t, ξ, θ, ::ZigZag1d)
τ + t, ξ + θ*τ , θ
end
"""
ZZ1d(∇U, ξ, θ, T, x, y, Flow; rng=Random.GLOBAL_RNG, ab=ab_Global):ZigZag sampler without subsampling
`∇U`: gradient of the negative log-density
`(ξ,θ)`: initial state
`T`: Time Horizon
`(x,y)`: observation
`Flow`: continuous dynamics
`a+bt`: computational bound for intensity m(t)
`num`: ポアソン時刻に到着した回数
`acc`: 受容回数.`acc/num` は acceptance rate
"""
function ZZ1d(∇U, ξ, θ, T::Float64, x::Matrix{Float64}, y::Vector{Float64}, Flow::ZigZagBoomerang.ContinuousDynamics; rng=Random.GLOBAL_RNG, ab=ab_Global)
t = zero(T)
Ξ = [(t, ξ, θ)]
num = acc = 0
epoch_list = [num]
a, b = ab(ξ, θ, x, y, Flow)
t′ = t + poisson_time(a, b, rand()) # イベントは a,b が定める affine proxy に従って生成する
while t < T
τ = t′ - t
t, ξ, θ = move_forward(τ, t, ξ, θ, Flow)
l, lb = λ(ξ, θ, ∇U, x, y), λ_bar(τ, a, b) # λ が真のレート, λ_bar が affine proxy
num += 1
if rand()*lb < l
acc += 1
if l > lb + 0.01
println(l-lb)
println(l)
end
θ = -θ
push!(Ξ, (t, ξ, θ))
push!(epoch_list, num)
end
a, b = ab(ξ, θ, x, y, Flow)
t′ = t + poisson_time(a, b, rand())
end
return Ξ, epoch_list, acc/num
end