Skip to content

Commit

Permalink
Update inference routines (#116)
Browse files Browse the repository at this point in the history
* Update `logL` in the inference routines

* Redocument

* Fix

* Bloody semicolon

* Fixes

* Remove last allocations
  • Loading branch information
gdalle authored Oct 1, 2024
1 parent 882f2f7 commit 38eb996
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.4"
version = "0.6.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
12 changes: 8 additions & 4 deletions libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ function test_allocations(
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
# making seq_ends a tuple disables multithreading
seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends))))
control_seq = control_seq[1:last(seq_ends)]

@testset "Allocations" begin
obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k
t1, t2 = seq_limits(seq_ends, k)
Expand All @@ -18,23 +22,23 @@ function test_allocations(

f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends)
allocs_f = @ballocated HMMs.forward!(
$f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_f == 0

## Viterbi

v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
allocs_v = @ballocated HMMs.viterbi!(
$v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_v == 0

## Forward-backward

fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
allocs_fb = @ballocated HMMs.forward_backward!(
$fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
$fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1
@test allocs_fb == 0

Expand All @@ -48,7 +52,7 @@ function test_allocations(
allocs_bw = @ballocated fit!(
hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess))
@test_broken allocs_bw == 0
@test allocs_bw == 0
end
end
end
27 changes: 11 additions & 16 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,16 @@ function initialize_forward(
return ForwardStorage(α, logL, B, c)
end

"""
$(SIGNATURES)
"""
function forward!(
function _forward!(
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
)
(; α, B, c) = storage
(; α, B, c, logL) = storage
t1, t2 = seq_limits(seq_ends, k)

# Initialization
Bₜ₁ = view(B, :, t1)
Expand All @@ -88,7 +86,7 @@ function forward!(
c[t1] = inv(sum(αₜ₁))
lmul!(c[t1], αₜ₁)

logL = -log(c[t1]) + logm
logL[k] = -log(c[t1]) + logm

# Loop
for t in t1:(t2 - 1)
Expand All @@ -104,11 +102,11 @@ function forward!(
c[t + 1] = inv(sum(αₜ₊₁))
lmul!(c[t + 1], αₜ₊₁)

logL += -log(c[t + 1]) + logm
logL[k] += -log(c[t + 1]) + logm
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -121,16 +119,13 @@ function forward!(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
25 changes: 10 additions & 15 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,20 @@ function initialize_forward_backward(
return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ)
end

"""
$(SIGNATURES)
"""
function forward_backward!(
function _forward_backward!(
storage::ForwardBackwardStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer;
transition_marginals::Bool=true,
) where {R}
(; α, β, c, γ, ξ, B, Bβ) = storage
t1, t2 = seq_limits(seq_ends, k)

# Forward (fill B, α, c and logL)
logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2)
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)

# Backward
β[:, t2] .= c[t2]
Expand All @@ -68,7 +66,7 @@ function forward_backward!(
ξ[t2] .= zero(R)
end

return logL
return nothing
end

"""
Expand All @@ -82,19 +80,16 @@ function forward_backward!(
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals::Bool=true,
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
_forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
_forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
end
Expand Down
25 changes: 10 additions & 15 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@ function initialize_viterbi(
return ViterbiStorage(q, logL, logB, ϕ, ψ)
end

"""
$(SIGNATURES)
"""
function viterbi!(
function _viterbi!(
storage::ViterbiStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
) where {R}
(; q, logB, ϕ, ψ) = storage
(; q, logB, ϕ, ψ, logL) = storage
t1, t2 = seq_limits(seq_ends, k)

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1])
Expand All @@ -66,13 +64,13 @@ function viterbi!(

ϕₜ₂ = view(ϕ, :, t2)
q[t2] = argmax(ϕₜ₂)
logL = ϕ[q[t2], t2]
logL[k] = ϕ[q[t2], t2]
for t in (t2 - 1):-1:t1
q[t] = ψ[q[t + 1], t + 1]
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -85,16 +83,13 @@ function viterbi!(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
) where {R}
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
_viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
24 changes: 17 additions & 7 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,23 @@ function StatsAPI.fit!(
)
(; γ, ξ) = fb_storage
# Fit states
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
# use ξ[t2] as scratch space since it is zero anyway
scratch = ξ[t2]
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end
end
fill!(hmm.init, zero(eltype(hmm.init)))
Expand Down

0 comments on commit 38eb996

Please sign in to comment.