Skip to content

Commit

Permalink
add max_stretch
Browse files Browse the repository at this point in the history
  • Loading branch information
francescoalemanno committed Sep 7, 2020
1 parent 9b71195 commit 4527abd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
20 changes: 17 additions & 3 deletions src/smc.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

macro cthreads(condition::Symbol, loop) #does not work well because of #15276, but seems to work on Julia v0.7
return esc(quote
($condition) && (Threads.@threads $loop; true) || ($loop; true)
Expand Down Expand Up @@ -65,6 +64,7 @@ function smc(
- `mcmc_tol` - stopping condition for SMC, if the fraction of accepted particles drops below `mcmc_tol` the algorithm terminates.
- `epstol` - stopping condition for SMC, if the adaptive cost threshold drops below `epstol` the algorithm has converged and thus it terminates.
- `min_r_ess` - whenever the fractional effective sample size drops below `min_r_ess`, a systematic resampling step is performed.
- `max_stretch` - the proposal distribution of `smc` is the stretch move of Foreman-Mackey et al. 2013, the larger the parameters the wider becomes the proposal distribution.
- `verbose` - if set to `true`, enables verbosity.
- `parallel` - if set to `true`, threaded parallelism is enabled, keep in mind that the cost function must be Thread-safe in such case.
Expand Down Expand Up @@ -96,6 +96,7 @@ function smc(
epstol = 0.0,
r_epstol = (1 - alpha) / 50,
min_r_ess = 0.55,
max_stretch = 2.0,
verbose::Bool = false,
parallel::Bool = false,
) where {Tprior<:Distribution}
Expand Down Expand Up @@ -183,7 +184,8 @@ function smc(
for (A,B) in ((s1,s2),(s2,s1))
new_p = map(A) do i
a = rand(rng,B)
Z = sample_g(rng, 2.0)
u = rand(rng)
Z = (u * (sqrt(max_stretch) - sqrt(1 / max_stretch)) + sqrt(1 / max_stretch))^2
W = op(*, op(-, θs[i], θs[a]), Z)
(log(rand(rng)), op(+, θs[a], W), (Np - 1) * log(Z))
end
Expand Down Expand Up @@ -219,13 +221,21 @@ function smc(

l = length(prior)
P = map(x -> Particles(x), getindex.(θs, i) for i = 1:l)
length(P)==1 && (P=first(P))
W = Particles(Ws[filter])
(P = P, W = W, ϵ = ϵ)
end

export smc

#=
using KissABC
pp=Normal(0,5)
cc(x) = 50*(x+randn()*0.01-1)^2
R=smc(pp,cc,verbose=true,alpha=0.95,nparticles=500).P
using KissABC
pp=Factored(Normal(0,5), Normal(0,5))
cc((x,y)) = 50*(x+randn()*0.01-y^2)^2+(y-1+randn()*0.01)^2
Expand Down Expand Up @@ -257,7 +267,11 @@ function costfun((u1, p1); raw=false)
sqrt(sum(abs2,[std(x)-2.2, median(x)-0.4]./[2.2,0.4]))
end
@time R=smc(Factored(Uniform(0,1), Uniform(0.5,1)), costfun, nparticles=100, M=1, verbose=true, alpha=0.6,epstol=0.01,parallel=true)
@time R=smc(Factored(Uniform(0,1), Uniform(0.5,1)), costfun, nparticles=100, M=1, verbose=true,epstol=0.01,alpha=0.55,parallel=true)
plan=ApproxPosterior(Factored(Uniform(0,1), Uniform(0.5,1)), costfun, 0.01)
@time res = sample(plan, AIS(25),MCMCThreads(),25,4,discard_initial=2500)
using PyPlot
pygui(true)
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
res = sample(abc, AIS(12), 500, discard_initial = 1000, progress = false)

@test sim(res) 1.5
@test smc(pri, cost, epstol = 0.1).P[1] 0.707
@test smc(pri, cost, epstol = 0.1).P 0.707
end

@testset "Normal dist -> Dirac Delta inference, MCMCThreads" begin
Expand Down Expand Up @@ -162,7 +162,7 @@ end
discard_initial = 5000,
progress = false,
)
ressmc = smc(prior, cost, nparticles = 2000, alpha = 0.99, epstol = 0.01).P[1]
ressmc = smc(prior, cost, nparticles = 2000, alpha = 0.99, epstol = 0.01).P
testst(alg, r) = begin
m = mean(abs, st(r) - st_n)
println(":", alg, ": testing m = ", m)
Expand Down

0 comments on commit 4527abd

Please sign in to comment.