using ZigZagBoomerang
using Distributions
using Random
using LinearAlgebra
using Statistics  # just for sure
using StatsFuns
"""
    ∇U(i,j,ξ,x,y)
        i ∈ [d]: 次元を表すインデックス
        j ∈ [n]: サンプル番号を表すインデックス
        ξ: パラメータ空間 R^d 上の位置
        他,観測 (x,y) を引数にとる.
    この関数を実装する際,log の中身をそのまま計算しようとすると大変大きくなり,数値的に不安定になる(除算の後は 1 近くになるはずだが,Inf になってしまう)
"""
∇U(i::Int64, j::Int64, ξ, x::Matrix{Float64}, y::Vector{Float64}) = length(y) * x[i,j] * (logistic(dot(x[:,j],ξ)) - y[j])
"""
    ∇U(i,ξ,x,y):∇U(i,j,ξ,x,y) を全データ j ∈ [n] について足し合わせたもの
        i ∈ [d]: 次元を表すインデックス
        ξ: パラメータ空間 R^d 上の位置
        他,観測 (x,y) を引数にとる.
"""
function ∇U(i::Int64, ξ, x::Matrix{Float64}, y::Vector{Float64})
    n = length(y)
    U_list = []
    for j in 1:n
        push!(U_list, ∇U(i, j, ξ, x, y))
    end
    return mean(U_list)
end
function  ∇U(ξ, x::Matrix{Float64}, y::Vector{Float64})  # 1次元の場合のショートカット
    return ∇U(1, ξ, x, y)
end
pos(x) = max(zero(x), x)
"""
    λ(i, ξ, θ, ∇U, x, y):第 i ∈ [d] 次元のレート関数
        i ∈ [d]: 次元を表すインデックス
        (ξ,θ): E 上の座標
        ∇U
        (x,y): 観測
"""
λ(i::Int64, ξ, θ, ∇U, x, y) = pos(θ[i] * ∇U(i, ξ, x, y))
λ(ξ, θ, ∇U, x, y) = pos(θ * ∇U(ξ, x, y))  # 1次元の場合のショートカット
"""
    λ(τ, a, b):代理レート関数の時刻 τ における値
        τ: 時間
        a,b: 1次関数の係数
"""
λ_bar(τ, a, b) = pos(a + b*τ)
"""
`x`: current location, `θ`: current velocity, `t`: current time,
"""
function move_forward(τ, t, ξ, θ, ::ZigZag1d)
    τ + t, ξ + θ*τ , θ
end
"""
    ZZ1d(∇U, ξ, θ, T, x, y, Flow; rng=Random.GLOBAL_RNG, ab=ab_Global):ZigZag sampler without subsampling
        `∇U`: gradient of the negative log-density
        `(ξ,θ)`: initial state
        `T`: Time Horizon
        `(x,y)`: observation
        `Flow`: continuous dynamics
        `a+bt`: computational bound for intensity m(t)
        `num`: ポアソン時刻に到着した回数
        `acc`: 受容回数.`acc/num` は acceptance rate
"""
function ZZ1d(∇U, ξ, θ, T::Float64, x::Matrix{Float64}, y::Vector{Float64}, Flow::ZigZagBoomerang.ContinuousDynamics; rng=Random.GLOBAL_RNG, ab=ab_Global)
    t = zero(T)
    Ξ = [(t, ξ, θ)]
    num = acc = 0
    epoch_list = [num]
    a, b = ab(ξ, θ, x, y, Flow)
    t′ =  t + poisson_time(a, b, rand())  # イベントは a,b が定める affine proxy に従って生成する
    while t < T
        τ = t′ - t
        t, ξ, θ = move_forward(τ, t, ξ, θ, Flow)
        l, lb = λ(ξ, θ, ∇U, x, y), λ_bar(τ, a, b)  # λ が真のレート, λ_bar が affine proxy
        num += 1
        if rand()*lb < l
            acc += 1
            if l > lb + 0.01
                println(l-lb)
                println(l)
            end
            θ = -θ
            push!(Ξ, (t, ξ, θ))
            push!(epoch_list, num)
        end
        a, b = ab(ξ, θ, x, y, Flow)
        t′ = t + poisson_time(a, b, rand())
    end
    return Ξ, epoch_list, acc/num
end