diff --git a/Project.toml b/Project.toml index 3d5d5c6c..4119a543 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.5.4" +version = "0.6.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index 38361975..8deedafa 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -6,6 +6,10 @@ function test_allocations( seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) + # making seq_ends a tuple disables multithreading + seq_ends = ntuple(k -> seq_ends[k], Val(min(2, length(seq_ends)))) + control_seq = control_seq[1:last(seq_ends)] + @testset "Allocations" begin obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k t1, t2 = seq_limits(seq_ends, k) @@ -18,7 +22,7 @@ function test_allocations( f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends) allocs_f = @ballocated HMMs.forward!( - $f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 + $f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) evals = 1 samples = 1 @test allocs_f == 0 @@ -26,7 +30,7 @@ function test_allocations( v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) allocs_v = @ballocated HMMs.viterbi!( - $v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 + $v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) evals = 1 samples = 1 @test allocs_v == 0 @@ -34,7 +38,7 @@ function test_allocations( fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) allocs_fb = @ballocated HMMs.forward_backward!( - $fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 + $fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) evals = 1 samples = 1 @test allocs_fb == 0 @@ -48,7 +52,7 @@ function test_allocations( allocs_bw = @ballocated fit!( hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends ) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess)) - @test_broken allocs_bw == 0 + @test allocs_bw == 0 end end end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 9b2ee492..27a22006 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -63,18 +63,16 @@ function initialize_forward( return ForwardStorage(α, logL, B, c) end -""" -$(SIGNATURES) -""" -function forward!( +function _forward!( storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, ) - (; α, B, c) = storage + (; α, B, c, logL) = storage + t1, t2 = seq_limits(seq_ends, k) # Initialization Bₜ₁ = view(B, :, t1) @@ -88,7 +86,7 @@ function forward!( c[t1] = inv(sum(αₜ₁)) lmul!(c[t1], αₜ₁) - logL = -log(c[t1]) + logm + logL[k] = -log(c[t1]) + logm # Loop for t in t1:(t2 - 1) @@ -104,11 +102,11 @@ function forward!( c[t + 1] = inv(sum(αₜ₊₁)) lmul!(c[t + 1], αₜ₊₁) - logL += -log(c[t + 1]) + logm + logL[k] += -log(c[t + 1]) + logm end - @argcheck isfinite(logL) - return logL + @argcheck isfinite(logL[k]) + return nothing end """ @@ -121,16 +119,13 @@ function forward!( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) end end return nothing diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 2f64ab0d..4051e30a 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -29,22 +29,20 @@ function initialize_forward_backward( return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ) end -""" -$(SIGNATURES) -""" -function forward_backward!( +function _forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer; transition_marginals::Bool=true, ) where {R} (; α, β, c, γ, ξ, B, Bβ) = storage + t1, t2 = seq_limits(seq_ends, k) # Forward (fill B, α, c and logL) - logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) # Backward β[:, t2] .= c[t2] @@ -68,7 +66,7 @@ function forward_backward!( ξ[t2] .= zero(R) end - return logL + return nothing end """ @@ -82,19 +80,16 @@ function forward_backward!( seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals::Bool=true, ) - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + _forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals ) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + _forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals ) end end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index f1758097..c20632cd 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -36,18 +36,16 @@ function initialize_viterbi( return ViterbiStorage(q, logL, logB, ϕ, ψ) end -""" -$(SIGNATURES) -""" -function viterbi!( +function _viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, - t1::Integer, - t2::Integer; + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, ) where {R} - (; q, logB, ϕ, ψ) = storage + (; q, logB, ϕ, ψ, logL) = storage + t1, t2 = seq_limits(seq_ends, k) logBₜ₁ = view(logB, :, t1) obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1]) @@ -66,13 +64,13 @@ function viterbi!( ϕₜ₂ = view(ϕ, :, t2) q[t2] = argmax(ϕₜ₂) - logL = ϕ[q[t2], t2] + logL[k] = ϕ[q[t2], t2] for t in (t2 - 1):-1:t1 q[t] = ψ[q[t + 1], t + 1] end - @argcheck isfinite(logL) - return logL + @argcheck isfinite(logL[k]) + return nothing end """ @@ -85,16 +83,13 @@ function viterbi!( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) where {R} - (; logL) = storage if seq_ends isa NTuple for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + _viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) end else @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + _viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) end end return nothing diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 60d7bf12..9c1f47b4 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -61,13 +61,23 @@ function StatsAPI.fit!( ) (; γ, ξ) = fb_storage # Fit states - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - # use ξ[t2] as scratch space since it is zero anyway - scratch = ξ[t2] - fill!(scratch, zero(eltype(scratch))) - for t in t1:(t2 - 1) - scratch .+= ξ[t] + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway + fill!(scratch, zero(eltype(scratch))) + for t in t1:(t2 - 1) + scratch .+= ξ[t] + end + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + scratch = ξ[t2] # use ξ[t2] as scratch space since it is zero anyway + fill!(scratch, zero(eltype(scratch))) + for t in t1:(t2 - 1) + scratch .+= ξ[t] + end end end fill!(hmm.init, zero(eltype(hmm.init)))