From a19450d4d81a5d3fc2c6a2801551ba8c1c58af26 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:42:15 +0200 Subject: [PATCH] Update `logL` in the inference routines --- Project.toml | 2 +- src/inference/forward.jl | 23 +++++++++-------------- src/inference/forward_backward.jl | 23 +++++++++-------------- src/inference/viterbi.jl | 21 ++++++++------------- 4 files changed, 27 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 3d5d5c6c..4119a543 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.5.4" +version = "0.6.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 9b2ee492..74818c00 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -63,18 +63,16 @@ function initialize_forward( return ForwardStorage(α, logL, B, c) end -""" -$(SIGNATURES) -""" function forward!( storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, ) (; α, B, c) = storage + t1, t2 = seq_limits(seq_ends, k) # Initialization Bₜ₁ = view(B, :, t1) @@ -88,7 +86,7 @@ function forward!( c[t1] = inv(sum(αₜ₁)) lmul!(c[t1], αₜ₁) - logL = -log(c[t1]) + logm + logL[k] = -log(c[t1]) + logm # Loop for t in t1:(t2 - 1) @@ -104,11 +102,11 @@ function forward!( c[t + 1] = inv(sum(αₜ₊₁)) lmul!(c[t + 1], αₜ₊₁) - logL += -log(c[t + 1]) + logm + logL[k] += -log(c[t + 1]) + logm end - @argcheck isfinite(logL) - return logL + @argcheck isfinite(logL[k]) + return nothing end """ @@ -121,16 +119,13 @@ function forward!( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) end end return nothing diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 2f64ab0d..5aa4b7bf 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -29,22 +29,20 @@ function initialize_forward_backward( return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ) end -""" -$(SIGNATURES) -""" function forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, transition_marginals::Bool=true, ) where {R} (; α, β, c, γ, ξ, B, Bβ) = storage + t1, t2 = seq_limits(seq_ends, k) # Forward (fill B, α, c and logL) - logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2) + forward!(storage, hmm, obs_seq, control_seq, t1, t2) # Backward β[:, t2] .= c[t2] @@ -68,7 +66,7 @@ function forward_backward!( ξ[t2] .= zero(R) end - return logL + return nothing end """ @@ -82,19 +80,16 @@ function forward_backward!( seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals::Bool=true, ) - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals ) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals ) end end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index f1758097..9ef829cb 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -36,18 +36,16 @@ function initialize_viterbi( return ViterbiStorage(q, logL, logB, ϕ, ψ) end -""" -$(SIGNATURES) -""" function viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, ) where {R} (; q, logB, ϕ, ψ) = storage + t1, t2 = seq_limits(seq_ends, k) logBₜ₁ = view(logB, :, t1) obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1]) @@ -66,13 +64,13 @@ function viterbi!( ϕₜ₂ = view(ϕ, :, t2) q[t2] = argmax(ϕₜ₂) - logL = ϕ[q[t2], t2] + logL[k] = ϕ[q[t2], t2] for t in (t2 - 1):-1:t1 q[t] = ψ[q[t + 1], t + 1] end - @argcheck isfinite(logL) - return logL + @argcheck isfinite(logL[k]) + return nothing end """ @@ -85,16 +83,13 @@ function viterbi!( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) where {R} - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) end end return nothing