Skip to content

Commit

Permalink
Decompose forward function into initialize, predict, update
Browse files Browse the repository at this point in the history
  • Loading branch information
THargreaves committed Sep 27, 2024
1 parent a8b048a commit 82fd465
Showing 1 changed file with 41 additions and 29 deletions.
70 changes: 41 additions & 29 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,43 +47,55 @@ function forward!(
t1::Integer,
t2::Integer;
)
(; α, B, c) = storage

# Initialization
Bₜ₁ = view(B, :, t1)
obs_logdensities!(Bₜ₁, hmm, obs_seq[t1], control_seq[t1])
logm = maximum(Bₜ₁)
Bₜ₁ .= exp.(Bₜ₁ .- logm)
_initialize!(storage, hmm, t1)
logL = zero(eltype(storage))

init = initialization(hmm)
αₜ₁ = view(α, :, t1)
αₜ₁ .= init .* Bₜ₁
c[t1] = inv(sum(αₜ₁))
lmul!(c[t1], αₜ₁)

logL = -log(c[t1]) + logm

# Loop
for t in t1:(t2 - 1)
Bₜ₊₁ = view(B, :, t + 1)
obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1])
logm = maximum(Bₜ₊₁)
Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm)

trans = transition_matrix(hmm, control_seq[t])
αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1)
mul!(αₜ₊₁, transpose(trans), αₜ)
αₜ₊₁ .*= Bₜ₊₁
c[t + 1] = inv(sum(αₜ₊₁))
lmul!(c[t + 1], αₜ₊₁)

logL += -log(c[t + 1]) + logm
# Filter step loop
for t in t1:t2
t > t1 && _predict!(storage, hmm, control_seq, t)
logL = _update!(storage, logL, hmm, obs_seq, control_seq, t)
end

@argcheck isfinite(logL)
return logL
end

function _initialize!(storage, hmm, t1)
(; α) = storage
αₜ₁ = view(α, :, t1)
αₜ₁ .= initialization(hmm)
return nothing
end

function _predict!(storage, hmm, control_seq, t)
(; α) = storage
αₜ₋₁, αₜ = view(α, :, t - 1), view(α, :, t)

trans = transition_matrix(hmm, control_seq[t])
mul!(αₜ, transpose(trans), αₜ₋₁)

return nothing
end

function _update!(storage, logL, hmm, obs_seq, control_seq, t)
(; α, B, c) = storage
Bₜ = view(B, :, t)
αₜ = view(α, :, t)

obs_logdensities!(Bₜ, hmm, obs_seq[t], control_seq[t])
logm = maximum(Bₜ)
Bₜ .= exp.(Bₜ .- logm)

αₜ .*= Bₜ
c[t] = inv(sum(αₜ))
lmul!(c[t], αₜ)

logL += -log(c[t]) + logm

return logL
end

"""
$(SIGNATURES)
"""
Expand Down

0 comments on commit 82fd465

Please sign in to comment.