Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update inference routines #116

Merged
merged 6 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.4"
version = "0.6.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
12 changes: 8 additions & 4 deletions libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ function test_allocations(
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
# making seq_ends a tuple disables multithreading
seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends))))
control_seq = control_seq[1:last(seq_ends)]

@testset "Allocations" begin
obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k
t1, t2 = seq_limits(seq_ends, k)
Expand All @@ -18,23 +22,23 @@ function test_allocations(

f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends)
allocs_f = @ballocated HMMs.forward!(
$f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_f == 0

## Viterbi

v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
allocs_v = @ballocated HMMs.viterbi!(
$v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_v == 0

## Forward-backward

fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
allocs_fb = @ballocated HMMs.forward_backward!(
$fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_fb == 0

Expand All @@ -48,7 +52,7 @@ function test_allocations(
allocs_bw = @ballocated fit!(
hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess))
@test_broken allocs_bw == 0
@test allocs_bw == 0
end
end
end
27 changes: 11 additions & 16 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,16 @@
return ForwardStorage(α, logL, B, c)
end

"""
$(SIGNATURES)
"""
function forward!(
function _forward!(
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
)
(; α, B, c) = storage
(; α, B, c, logL) = storage
t1, t2 = seq_limits(seq_ends, k)

# Initialization
Bₜ₁ = view(B, :, t1)
Expand All @@ -88,7 +86,7 @@
c[t1] = inv(sum(αₜ₁))
lmul!(c[t1], αₜ₁)

logL = -log(c[t1]) + logm
logL[k] = -log(c[t1]) + logm

# Loop
for t in t1:(t2 - 1)
Expand All @@ -104,11 +102,11 @@
c[t + 1] = inv(sum(αₜ₊₁))
lmul!(c[t + 1], αₜ₊₁)

logL += -log(c[t + 1]) + logm
logL[k] += -log(c[t + 1]) + logm
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -121,16 +119,13 @@
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)

Check warning on line 124 in src/inference/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward.jl#L124

Added line #L124 was not covered by tests
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
25 changes: 10 additions & 15 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,20 @@
return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ)
end

"""
$(SIGNATURES)
"""
function forward_backward!(
function _forward_backward!(
storage::ForwardBackwardStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer;
transition_marginals::Bool=true,
) where {R}
(; α, β, c, γ, ξ, B, Bβ) = storage
t1, t2 = seq_limits(seq_ends, k)

# Forward (fill B, α, c and logL)
logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)

# Backward
β[:, t2] .= c[t2]
Expand All @@ -68,7 +66,7 @@
ξ[t2] .= zero(R)
end

return logL
return nothing
end

"""
Expand All @@ -82,19 +80,16 @@
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals::Bool=true,
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
_forward_backward!(

Check warning on line 85 in src/inference/forward_backward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward_backward.jl#L85

Added line #L85 was not covered by tests
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
_forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
end
Expand Down
25 changes: 10 additions & 15 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@
return ViterbiStorage(q, logL, logB, ϕ, ψ)
end

"""
$(SIGNATURES)
"""
function viterbi!(
function _viterbi!(
storage::ViterbiStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
) where {R}
(; q, logB, ϕ, ψ) = storage
(; q, logB, ϕ, ψ, logL) = storage
t1, t2 = seq_limits(seq_ends, k)

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1])
Expand All @@ -66,13 +64,13 @@

ϕₜ₂ = view(ϕ, :, t2)
q[t2] = argmax(ϕₜ₂)
logL = ϕ[q[t2], t2]
logL[k] = ϕ[q[t2], t2]
for t in (t2 - 1):-1:t1
q[t] = ψ[q[t + 1], t + 1]
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -85,16 +83,13 @@
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
) where {R}
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)

Check warning on line 88 in src/inference/viterbi.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/viterbi.jl#L88

Added line #L88 was not covered by tests
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
24 changes: 17 additions & 7 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,23 @@
)
(; γ, ξ) = fb_storage
# Fit states
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
# use ξ[t2] as scratch space since it is zero anyway
scratch = ξ[t2]
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end

Check warning on line 72 in src/types/hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/hmm.jl#L65-L72

Added lines #L65 - L72 were not covered by tests
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end
end
fill!(hmm.init, zero(eltype(hmm.init)))
Expand Down
Loading