diff --git a/examples/controlled.jl b/examples/controlled.jl index 28ef8ebc..ffd0ab98 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -94,7 +94,7 @@ function StatsAPI.fit!( fb_storage::HMMs.ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) where {T} (; γ, ξ) = fb_storage N = length(hmm) diff --git a/examples/interfaces.jl b/examples/interfaces.jl index ba306051..4498478f 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -186,7 +186,7 @@ function StatsAPI.fit!( hmm::PriorHMM, fb_storage::HiddenMarkovModels.ForwardBackwardStorage, obs_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) ## initialize to defaults without observations hmm.init .= 0 diff --git a/examples/temporal.jl b/examples/temporal.jl index 80216dda..1cad38f6 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -109,7 +109,7 @@ function StatsAPI.fit!( fb_storage::HMMs.ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) where {T} (; γ, ξ) = fb_storage L, N = period(hmm), length(hmm) diff --git a/libs/HMMTest/src/HMMTest.jl b/libs/HMMTest/src/HMMTest.jl index aaeedb40..f342c238 100644 --- a/libs/HMMTest/src/HMMTest.jl +++ b/libs/HMMTest/src/HMMTest.jl @@ -2,6 +2,7 @@ module HMMTest using BenchmarkTools: @ballocated using HiddenMarkovModels +using HiddenMarkovModels: AbstractVectorOrNTuple import HiddenMarkovModels as HMMs using JET: @test_opt, @test_call using Random: AbstractRNG diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index ea3aeef9..38361975 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -3,7 +3,7 @@ function test_allocations( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Allocations" begin diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 2a39e336..a3384c0f 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -55,7 +55,7 @@ function test_coherent_algorithms( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, atol::Real=0.05, init::Bool=true, diff --git a/libs/HMMTest/src/jet.jl b/libs/HMMTest/src/jet.jl index 75820193..d6d29f4f 100644 --- a/libs/HMMTest/src/jet.jl +++ b/libs/HMMTest/src/jet.jl @@ -3,7 +3,7 @@ function test_type_stability( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Type stability" begin diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 3bec0ac2..3791071c 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -22,7 +22,7 @@ function baum_welch!( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, atol::Real, max_iterations::Integer, loglikelihood_increasing::Bool, @@ -55,7 +55,7 @@ function baum_welch( hmm_guess::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), atol=1e-5, max_iterations=100, loglikelihood_increasing=true, @@ -85,7 +85,7 @@ function StatsAPI.fit!( fb_storage::ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) return fit!(hmm, fb_storage, obs_seq; seq_ends) end diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index 424236e1..8816120f 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -4,7 +4,7 @@ function _params_and_loglikelihoods( hmm::AbstractHMM, obs_seq::Vector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) init = initialization(hmm) trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t @@ -22,7 +22,7 @@ function ChainRulesCore.rrule( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) _, pullback = rrule_via_ad( rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends diff --git a/src/inference/forward.jl b/src/inference/forward.jl index c7d4883c..290248fa 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -16,7 +16,35 @@ struct ForwardStorage{R} c::Vector{R} end +""" +$(TYPEDEF) + +# Fields + +Only the fields with a description are part of the public API. + +$(TYPEDFIELDS) +""" +struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} + "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" + γ::Matrix{R} + "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`" + ξ::Vector{M} + "one loglikelihood per observation sequence" + logL::Vector{R} + B::Matrix{R} + α::Matrix{R} + c::Vector{R} + β::Matrix{R} + Bβ::Matrix{R} +end + Base.eltype(::ForwardStorage{R}) where {R} = R +Base.eltype(::ForwardBackwardStorage{R}) where {R} = R + +const ForwardOrForwardBackwardStorage{R} = Union{ + ForwardStorage{R},ForwardBackwardStorage{R} +} """ $(SIGNATURES) @@ -25,7 +53,7 @@ function initialize_forward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) @@ -40,7 +68,7 @@ end $(SIGNATURES) """ function forward!( - storage, + storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, @@ -88,16 +116,23 @@ end $(SIGNATURES) """ function forward!( - storage, + storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + end end return nothing end @@ -113,7 +148,7 @@ function forward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends) forward!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 63ad978b..2f64ab0d 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -1,28 +1,3 @@ -""" -$(TYPEDEF) - -# Fields - -Only the fields with a description are part of the public API. - -$(TYPEDFIELDS) -""" -struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} - "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" - γ::Matrix{R} - "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`" - ξ::Vector{M} - "one loglikelihood per observation sequence" - logL::Vector{R} - B::Matrix{R} - α::Matrix{R} - c::Vector{R} - β::Matrix{R} - Bβ::Matrix{R} -end - -Base.eltype(::ForwardBackwardStorage{R}) where {R} = R - """ $(SIGNATURES) """ @@ -30,7 +5,7 @@ function initialize_forward_backward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals=true, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) @@ -100,19 +75,28 @@ end $(SIGNATURES) """ function forward_backward!( - storage::ForwardBackwardStorage{R}, + storage::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals::Bool=true, -) where {R} +) (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals - ) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward_backward!( + storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + ) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward_backward!( + storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + ) + end end return nothing end @@ -128,7 +112,7 @@ function forward_backward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) transition_marginals = false storage = initialize_forward_backward( diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index f43fb25c..ce153ff2 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -7,7 +7,7 @@ function DensityInterface.logdensityof( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) _, logL = forward(hmm, obs_seq, control_seq; seq_ends) return sum(logL) @@ -23,7 +23,7 @@ function joint_logdensityof( obs_seq::AbstractVector, state_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) R = eltype(hmm, obs_seq[1], control_seq[1]) logL = zero(R) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 09e18a26..c212c9bb 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -26,7 +26,7 @@ function initialize_viterbi( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) @@ -85,12 +85,19 @@ function viterbi!( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) where {R} (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + end end return nothing end @@ -106,7 +113,7 @@ function viterbi( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) viterbi!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index ca6d33c3..6a71135a 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -61,7 +61,7 @@ function StatsAPI.fit!( hmm::HMM, fb_storage::ForwardBackwardStorage, obs_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) (; γ, ξ) = fb_storage # Fit states diff --git a/src/utils/limits.jl b/src/utils/limits.jl index cbd40e50..f06c7b0f 100644 --- a/src/utils/limits.jl +++ b/src/utils/limits.jl @@ -3,7 +3,7 @@ $(SIGNATURES) Return a tuple `(t1, t2)` giving the begin and end indices of subsequence `k` within a set of sequences ending at `seq_ends`. """ -function seq_limits(seq_ends::AbstractVector{Int}, k::Integer) +function seq_limits(seq_ends::AbstractVectorOrNTuple{Int}, k::Integer) if k == 1 return 1, seq_ends[k] else diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 0c18596d..9ed23157 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -1,3 +1,5 @@ +const AbstractVectorOrNTuple{T} = Union{AbstractVector{T},NTuple{N,T}} where {N} + sum_to_one!(x) = ldiv!(sum(x), x) mynonzeros(x::AbstractArray) = x