Skip to content

Commit

Permalink
abcdesmc! | early exit mcmc, for issue #5
Browse files Browse the repository at this point in the history
  • Loading branch information
mauricelanghinrichs committed Apr 10, 2023
1 parent 95b1a83 commit 7f619db
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions src/abcdez_smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,15 @@ The particles have to be weighted (via `r.Wns`) for valid posterior samples.
- `nsims_max::Int=10^7`: maximal number of `dist!` evaluations (not counting initial samples from prior);
algorithm stops if `ϵ_target` or `nsims_max` is reached.
- `Kmcmc::Int=3`: number of MCMC (Markov chain Monte Carlo) steps at each sequential
target distribution specified by current ϵ and ABC kernel type.
target distribution specified by current ϵ and ABC kernel type; actual number can be
lower due to early exits as specified by `Kmcmc_min`.
- `Kmcmc_min=1.0`: minimal value for accumulated acceptances per alive particle; if this value
is reached earlier (before completing the total `Kmcmc` MCMC steps) the loop is early-exited;
set `Kmcmc_min=Inf` to compute exactly `Kmcmc` MCMC steps at each ϵ target.
- `ABCk=ABCdeZ.Indicator0toϵ`: ABC kernel to be specified by ϵ widths that receives distance values.
- `facc_min=0.25`: if the fraction of accepted MCMC proposals drops below `facc_min`, diffential evolution
- `facc_min=0.15`: if the fraction of accepted MCMC proposals drops below `facc_min`, diffential evolution
proposals are reduced by a factor of `facc_tune`.
- `facc_tune=0.95`: factor to reduce the jump distance of the diffential evolution
- `facc_tune=0.975`: factor to reduce the jump distance of the diffential evolution
proposals in the MCMC step (used if `facc_min` is reached).
- `verbose::Bool=true`: if set to `true`, enables verbosity (printout to REPL).
- `verboseout::Bool=true`: if set to `true`, algorithm returns a more detailed inference output.
Expand All @@ -209,9 +213,9 @@ julia> evidence = exp(r.logZ);
```
"""
function abcdesmc!(prior, dist!, ϵ_target, varexternal;
nparticles::Int=100, α=0.95,
δess=0.5, nsims_max::Int=10^7, Kmcmc::Int=3,
ABCk=Indicator0toϵ, facc_min=0.25, facc_tune=0.95,
nparticles::Int=100, α=0.95, δess=0.5,
nsims_max::Int=10^7, Kmcmc::Int=3, Kmcmc_min=1.0,
ABCk=Indicator0toϵ, facc_min=0.15, facc_tune=0.975,
verbose::Bool=true, verboseout::Bool=true,
rng=Random.GLOBAL_RNG, parallel::Bool=false)

Expand All @@ -222,14 +226,16 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
0.0 facc_tune 1.0 || error("facc_tune must be in 0 <= facc_tune <= 1")
0.0 ϵ_target || error("ϵ_target must be non-negative") # TODO/NOTE: like this or adaptive termination?
1 Kmcmc || error("Kmcmc must be at least 1")
0.0 Kmcmc_min Inf || error("Kmcmc_min must be in 0 <= Kmcmc_min <= Inf")
1 nsims_max || error("nsims_max must be at least 1")
Kmcmc_min > facc_min || @warn("Kmcmc_min should be larger than facc_min")

nparticles_min = ceil(Int, 3 * length(prior) / (min(α, δess)))
nparticles_min nparticles || error("nparticles must be at least $(nparticles_min)")

parallel ? ex=ThreadedEx() : ex=SequentialEx()
verbose && (@info("Running abcdesmc! with executor ($(Threads.nthreads()) threads available) ", typeof(ex)))
verbose && (@info "Running abcdesmc! with" ϵ_target nparticles α δess nsims_max Kmcmc ABCk facc_min facc_tune rng parallel verboseout)
verbose && (@info "Running abcdesmc! with" ϵ_target nparticles α δess nsims_max Kmcmc Kmcmc_min ABCk facc_min facc_tune rng parallel verboseout)

# draw prior parameters for each particle, and calculate logprior values
θs = [op(float, Particle(rand(rng, prior))) for i in 1:nparticles]
Expand Down Expand Up @@ -267,6 +273,7 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
nsims = zeros(Int, nparticles)
naccs = zeros(Int, nparticles)
facc = 1.0
Ki = Kmcmc

# parameters for DE move
γ0 = 2.38 / sqrt(2 * length(prior))
Expand All @@ -280,6 +287,7 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
esss = [get_ess(Wns)]
faccs = [facc]
γ0s = [γ0]
Kmcmcs = [Ki]
end

iters = 0
Expand All @@ -305,8 +313,9 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
# (NOTE: log-space may make this estimate biased (Jensen ineq.), but ok...)
logZ += log(wnorm)

# reset naccs and tune proposal if it dropped below facc_min in previous step
# reset naccs, Ki and tune proposal if it dropped below facc_min in previous step
naccs .= 0
Ki = Kmcmc
facc < facc_min && (γ0 *= facc_tune)

# resample if effective sample size too low
Expand All @@ -323,7 +332,7 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
# see KissABC; however (also for Z estimate) it may be better in general
# to equilibrate better to current target distribution, so Kmcmc>1 in every
# step may be wanted in some cases)
for __ in 1:Kmcmc
for i in 1:Kmcmc
nθs = identity.(θs) # vector of particles, where θs[i].x are parameters (as tuple)
nΔs = identity.(Δs) # vector of floats with distance values (model/data)
nlogπ = identity.(logπ) # vector of floats with log prior values of above particles
Expand All @@ -338,11 +347,13 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
Δs = nΔs
logπ = nlogπ
blobs = nblobs

(sum(naccs)/sum(alive) Kmcmc_min) && (Ki = i; break)
end

# compute acceptance fraction of the last Kmcmc Markov steps
# among live particles
facc = sum(naccs)/(sum(alive)*Kmcmc)
facc = sum(naccs)/(sum(alive)*Ki)

# update kernel
ϵ_k = ϵ_k_new
Expand All @@ -354,6 +365,7 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;
push!(esss, ess)
push!(faccs, facc)
push!(γ0s, γ0)
push!(Kmcmcs, Ki)
end

verbose && (@info "Finished run:" iteration = iters nsim = sum(nsims) ϵ = ϵ range_ϵ = extrema(Δs) ess = ess facc = facc logZ = logZ)
Expand All @@ -373,7 +385,7 @@ function abcdesmc!(prior, dist!, ϵ_target, varexternal;

if verboseout
(P = θs, Wns = Wns, C = Δs, ϵ = ϵ, logZ = logZ, blobs = blobs,
ϵs = ϵs, ranges_ϵ = ranges_ϵ, logZs = logZs, esss = esss, faccs = faccs, γ0s = γ0s)
ϵs = ϵs, ranges_ϵ = ranges_ϵ, logZs = logZs, esss = esss, faccs = faccs, γ0s = γ0s, Kmcmcs = Kmcmcs)
else
(P = θs, Wns = Wns, C = Δs, ϵ = ϵ, logZ = logZ, blobs = blobs)
end
Expand Down

0 comments on commit 7f619db

Please sign in to comment.