Skip to content

Commit

Permalink
Update logL in the inference routines
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 1, 2024
1 parent 882f2f7 commit a19450d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.4"
version = "0.6.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
23 changes: 9 additions & 14 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,16 @@ function initialize_forward(
return ForwardStorage(α, logL, B, c)
end

"""
$(SIGNATURES)
"""
function forward!(
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
)
(; α, B, c) = storage
t1, t2 = seq_limits(seq_ends, k)

# Initialization
Bₜ₁ = view(B, :, t1)
Expand All @@ -88,7 +86,7 @@ function forward!(
c[t1] = inv(sum(αₜ₁))
lmul!(c[t1], αₜ₁)

logL = -log(c[t1]) + logm
logL[k] = -log(c[t1]) + logm

# Loop
for t in t1:(t2 - 1)
Expand All @@ -104,11 +102,11 @@ function forward!(
c[t + 1] = inv(sum(αₜ₊₁))
lmul!(c[t + 1], αₜ₊₁)

logL += -log(c[t + 1]) + logm
logL[k] += -log(c[t + 1]) + logm
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -121,16 +119,13 @@ function forward!(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down
23 changes: 9 additions & 14 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,20 @@ function initialize_forward_backward(
return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ)
end

"""
$(SIGNATURES)
"""
function forward_backward!(
storage::ForwardBackwardStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
transition_marginals::Bool=true,
) where {R}
(; α, β, c, γ, ξ, B, Bβ) = storage
t1, t2 = seq_limits(seq_ends, k)

# Forward (fill B, α, c and logL)
logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2)
forward!(storage, hmm, obs_seq, control_seq, t1, t2)

# Backward
β[:, t2] .= c[t2]
Expand All @@ -68,7 +66,7 @@ function forward_backward!(
ξ[t2] .= zero(R)
end

return logL
return nothing
end

"""
Expand All @@ -82,19 +80,16 @@ function forward_backward!(
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals::Bool=true,
)
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
forward_backward!(
storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals
)
end
end
Expand Down
21 changes: 8 additions & 13 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@ function initialize_viterbi(
return ViterbiStorage(q, logL, logB, ϕ, ψ)
end

"""
$(SIGNATURES)
"""
function viterbi!(
storage::ViterbiStorage{R},
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
t1::Integer,
t2::Integer;
seq_ends::AbstractVectorOrNTuple{Int},
k::Integer,
) where {R}
(; q, logB, ϕ, ψ) = storage
t1, t2 = seq_limits(seq_ends, k)

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1])
Expand All @@ -66,13 +64,13 @@ function viterbi!(

ϕₜ₂ = view(ϕ, :, t2)
q[t2] = argmax(ϕₜ₂)
logL = ϕ[q[t2], t2]
logL[k] = ϕ[q[t2], t2]
for t in (t2 - 1):-1:t1
q[t] = ψ[q[t + 1], t + 1]
end

@argcheck isfinite(logL)
return logL
@argcheck isfinite(logL[k])
return nothing
end

"""
Expand All @@ -85,16 +83,13 @@ function viterbi!(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
) where {R}
(; logL) = storage
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k)
end
end
return nothing
Expand Down

0 comments on commit a19450d

Please sign in to comment.