From 27d6e138b98679b4f5175cabf0d282fb179f351f Mon Sep 17 00:00:00 2001
From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date: Sat, 11 Nov 2023 19:26:24 +0100
Subject: [PATCH] Stuff

---
 Project.toml                               |  1 -
 docs/Project.toml                          |  1 -
 docs/src/api.md                            | 27 +++++++-
 ext/HiddenMarkovModelsChainRulesCoreExt.jl | 16 ++---
 src/HiddenMarkovModels.jl                  |  7 +-
 src/inference/baum_welch.jl                | 57 +++++++++++-----
 src/inference/forward.jl                   | 62 ++++++-----------
 src/inference/forward_backward.jl          | 78 +++++++++------------
 src/inference/viterbi.jl                   | 42 ++++--------
 src/types/abstract_hmm.jl                  |  3 +-
 src/utils/check.jl                         |  6 --
 src/utils/linalg.jl                        |  8 +--
 test/allocations.jl                        | 36 +++++-----
 test/arrays.jl                             | 15 ++---
 test/correctness.jl                        | 48 +++++--------
 15 files changed, 187 insertions(+), 220 deletions(-)

diff --git a/Project.toml b/Project.toml
index 7f01013f..ca5f9c42 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,6 @@ DocStringExtensions = "0.9"
 LinearAlgebra = "1.6"
 PrecompileTools = "1.1"
 Random = "1.6"
-RequiredInterfaces = "0.1.3"
 Requires = "1.3"
 SimpleUnPack = "1.1"
 StatsAPI = "1.6"
diff --git a/docs/Project.toml b/docs/Project.toml
index ea2baba9..68e4008f 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -4,7 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
diff --git a/docs/src/api.md b/docs/src/api.md
index a1c78398..7454a86e 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -40,17 +40,38 @@ fit!
 ```@docs
 rand_prob_vec
 rand_trans_mat
+HiddenMarkovModels.fit_element_from_sequence!
+HiddenMarkovModels.LightDiagNormal
 ```
 
-## Internals
+## In-place algorithms (internals)
+
+### Storage types
 
 ```@docs
 HiddenMarkovModels.ForwardStorage
 HiddenMarkovModels.ViterbiStorage
 HiddenMarkovModels.ForwardBackwardStorage
 HiddenMarkovModels.BaumWelchStorage
-HiddenMarkovModels.fit_element_from_sequence!
-HiddenMarkovModels.LightDiagNormal
+```
+
+### Initializing storage
+
+```@docs
+HiddenMarkovModels.initialize_forward
+HiddenMarkovModels.initialize_viterbi
+HiddenMarkovModels.initialize_forward_backward
+HiddenMarkovModels.initialize_baum_welch
+HiddenMarkovModels.initialize_logL_evolution
+```
+
+### Modifying storage
+
+```@docs
+HiddenMarkovModels.forward!
+HiddenMarkovModels.viterbi!
+HiddenMarkovModels.forward_backward!
+HiddenMarkovModels.baum_welch!
 ```
 
 ## Notations
diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl
index 23b925b4..1e716ed1 100644
--- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl
+++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl
@@ -1,7 +1,6 @@
 module HiddenMarkovModelsChainRulesCoreExt
 
-using ChainRulesCore:
-    ChainRulesCore, NoTangent, ZeroTangent, RuleConfig, rrule_via_ad, @not_implemented
+using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
 using DensityInterface: logdensityof
 using HiddenMarkovModels
 import HiddenMarkovModels as HMMs
@@ -13,7 +12,7 @@ function obs_logdensities_matrix(hmm::AbstractHMM, obs_seq::Vector)
     return logB
 end
 
-function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq)
+function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq::Vector)
     p = initialization(hmm)
     A = transition_matrix(hmm)
     logB = obs_logdensities_matrix(hmm, obs_seq)
@@ -21,12 +20,12 @@
 function ChainRulesCore.rrule(
-    rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq
+    rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq::Vector
 )
     (p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq)
-    fb = HMMs.initialize_forward_backward(hmm, obs_seq)
-    HMMs.forward_backward!(fb, hmm, obs_seq)
-    @unpack α, β, γ, c, Bβ = fb
+    storage = HMMs.initialize_forward_backward(hmm, obs_seq)
+    HMMs.forward_backward!(storage, hmm, obs_seq)
+    @unpack logL, α, β, γ, c, Bβ = storage
     T = length(obs_seq)
 
     function logdensityof_hmm_pullback(ΔlogL)
@@ -42,7 +42,7 @@ function ChainRulesCore.rrule(
         return Δlogdensityof, Δhmm, Δobs_seq
     end
 
-    return fb.logL[], logdensityof_hmm_pullback
+    return logL[], logdensityof_hmm_pullback
 end
 
 end
diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl
index 83ff430e..58fb17dd 100644
--- a/src/HiddenMarkovModels.jl
+++ b/src/HiddenMarkovModels.jl
@@ -2,17 +2,12 @@ HiddenMarkovModels
 
 A Julia package for HMM modeling, simulation, inference and learning.
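+
+# Quick start
+
+A minimal sketch (this toy two-state Gaussian HMM is only an illustration; it assumes
+Distributions.jl is loaded alongside this package):
+
+    using Distributions, HiddenMarkovModels
+    hmm = HMM([0.6, 0.4], [0.9 0.1; 0.2 0.8], [Normal(-1.0), Normal(1.0)])
+    obs_seq = rand(hmm, 100).obs_seq   # simulate 100 observations
+    logL = logdensityof(hmm, obs_seq)  # forward algorithm
+    q = viterbi(hmm, obs_seq)          # most likely state sequence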
-
-# Exports
-
-$(EXPORTS)
 """
 module HiddenMarkovModels
 
 using Base: RefValue
 using Base.Threads: @threads
-using DensityInterface:
-    DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof
+using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
 using Distributions:
     Distributions,
     Categorical,
diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl
index b72d5c44..ae4de322 100644
--- a/src/inference/baum_welch.jl
+++ b/src/inference/baum_welch.jl
@@ -22,6 +22,15 @@ struct BaumWelchStorage{R,M<:AbstractMatrix{R}}
     limits::Vector{Int}
 end
 
+function check_nb_seqs(obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
+    if nb_seqs != length(obs_seqs)
+        throw(ArgumentError("Inconsistent sizes provided: `nb_seqs != length(obs_seqs)`"))
+    end
+end
+
+"""
+    initialize_baum_welch(hmm, obs_seqs, nb_seqs)
+"""
 function initialize_baum_welch(
     hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
 )
@@ -36,6 +45,9 @@ function initialize_baum_welch(
     return BaumWelchStorage(init_count, trans_count, state_marginals_concat, limits)
 end
 
+"""
+    initialize_logL_evolution(hmm, obs_seqs, nb_seqs; max_iterations)
+"""
 function initialize_logL_evolution(
     hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer
 )
@@ -47,16 +59,19 @@ function initialize_logL_evolution(
 end
 
 function update_sufficient_statistics!(
-    bw::BaumWelchStorage{R}, fbs::Vector{<:ForwardBackwardStorage}
+    bw_storage::BaumWelchStorage{R}, fb_storages::Vector{<:ForwardBackwardStorage}
 ) where {R}
-    @unpack init_count, trans_count, state_marginals_concat, limits = bw
+    @unpack init_count, trans_count, state_marginals_concat, limits = bw_storage
     init_count .= zero(R)
     trans_count .= zero(R)
     state_marginals_concat .= zero(R)
-    for k in eachindex(fbs) # TODO: ThreadsX?
-        init_count .+= fbs[k].init_count
-        mynonzeros(trans_count) .+= mynonzeros(fbs[k].trans_count)
-        state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= fbs[k].γ
+    for k in eachindex(fb_storages) # TODO: ThreadsX?
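+        # Accumulate sequence k's statistics: γ[:, 1] feeds the initial-state counts,
+        # the ξ[t] marginals feed the transition counts, and γ fills this sequence's
+        # slice of the concatenated state marginals.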
+        @unpack γ, ξ = fb_storages[k]
+        init_count .+= @view γ[:, 1]
+        for t in eachindex(ξ)
+            mynonzeros(trans_count) .+= mynonzeros(ξ[t])
+        end
+        state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= γ
     end
     return nothing
 end
@@ -76,15 +91,23 @@ function baum_welch_has_converged(
     return false
 end
 
-function StatsAPI.fit!(hmm::AbstractHMM, bw::BaumWelchStorage, obs_seqs_concat::Vector)
-    return fit!(
-        hmm, bw.init_count, bw.trans_count, obs_seqs_concat, bw.state_marginals_concat
-    )
+function StatsAPI.fit!(
+    hmm::AbstractHMM, bw_storage::BaumWelchStorage, obs_seqs_concat::Vector
+)
+    @unpack init_count, trans_count, state_marginals_concat = bw_storage
+    return fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat)
 end
 
+"""
+    baum_welch!(
+        fb_storages, bw_storage, logL_evolution,
+        hmm, obs_seqs, obs_seqs_concat;
+        atol, max_iterations, loglikelihood_increasing
+    )
+"""
 function baum_welch!(
-    fbs::Vector{<:ForwardBackwardStorage},
-    bw::BaumWelchStorage,
+    fb_storages::Vector{<:ForwardBackwardStorage},
+    bw_storage::BaumWelchStorage,
     logL_evolution::Vector,
     hmm::AbstractHMM,
     obs_seqs::Vector{<:Vector},
@@ -94,12 +117,12 @@ function baum_welch!(
     loglikelihood_increasing::Bool,
 )
     for _ in 1:max_iterations
-        @threads for k in eachindex(obs_seqs, fbs)
-            forward_backward!(fbs[k], hmm, obs_seqs[k])
+        for k in eachindex(obs_seqs, fb_storages)
+            forward_backward!(fb_storages[k], hmm, obs_seqs[k])
         end
-        update_sufficient_statistics!(bw, fbs)
-        push!(logL_evolution, sum(fb.logL[] for fb in fbs))
-        fit!(hmm, bw, obs_seqs_concat)
+        update_sufficient_statistics!(bw_storage, fb_storages)
+        push!(logL_evolution, sum(fb.logL[] for fb in fb_storages))
+        fit!(hmm, bw_storage, obs_seqs_concat)
         check_hmm(hmm)
         if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing)
             break
diff --git a/src/inference/forward.jl b/src/inference/forward.jl
index 2cba895a..9e71144f 100644
--- a/src/inference/forward.jl
+++ b/src/inference/forward.jl
@@ -7,7 +7,7 @@
 This storage is relative to a single sequence.
 
 # Fields
 
-The only fields useful outside of the algorithm are `αₜ` and `logL`.
+The only fields useful outside of the algorithm are `α` and `logL`; the rest does not belong to the public API.
 
 $(TYPEDFIELDS)
 """
@@ -22,6 +22,9 @@ struct ForwardStorage{R}
     α_next::Vector{R}
 end
 
+"""
+    initialize_forward(hmm, obs_seq)
+"""
 function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
     N = length(hmm)
     R = eltype(hmm, obs_seq[1])
@@ -30,15 +33,18 @@ function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
     logb = Vector{R}(undef, N)
     α = Vector{R}(undef, N)
     α_next = Vector{R}(undef, N)
-    f = ForwardStorage(logL, logb, α, α_next)
-    return f
+    storage = ForwardStorage(logL, logb, α, α_next)
+    return storage
 end
 
-function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
+"""
+    forward!(storage, hmm, obs_seq)
+"""
+function forward!(storage::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
     T = length(obs_seq)
     p = initialization(hmm)
     A = transition_matrix(hmm)
-    @unpack logL, logb, α, α_next = f
+    @unpack logL, logb, α, α_next = storage
 
     obs_logdensities!(logb, hmm, obs_seq[1])
     check_right_finite(logb)
@@ -63,58 +69,30 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
     return nothing
 end
 
-function forward!(
-    fs::Vector{<:ForwardStorage},
-    hmm::AbstractHMM,
-    obs_seqs::Vector{<:Vector},
-    nb_seqs::Integer,
-)
-    check_nb_seqs(obs_seqs, nb_seqs)
-    @threads for k in eachindex(fs, obs_seqs)
-        forward!(fs[k], hmm, obs_seqs[k])
-    end
-    return nothing
-end
-
 """
     forward(hmm, obs_seq)
-    forward(hmm, obs_seqs, nb_seqs)
 
-Run the forward algorithm to infer the current state of an HMM.
+Run the forward algorithm to infer the current state of `hmm` after sequence `obs_seq`.
 
-When applied on a single sequence, this function returns a tuple `(α, logL)` where
+This function returns a tuple `(α, logL)` where
 
 - `α[i]` is the posterior probability of state `i` at the end of the sequence
 - `logL` is the loglikelihood of the sequence
-
-When applied on multiple sequences, this function returns a vector of tuples.
 """
-function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
-    check_nb_seqs(obs_seqs, nb_seqs)
-    fs = [initialize_forward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
-    forward!(fs, hmm, obs_seqs, nb_seqs)
-    return [(f.α, f.logL[]) for f in fs]
-end
-
 function forward(hmm::AbstractHMM, obs_seq::Vector)
-    return only(forward(hmm, [obs_seq], 1))
+    storage = initialize_forward(hmm, obs_seq)
+    forward!(storage, hmm, obs_seq)
+    return storage.α, storage.logL[]
 end
 
"""
    logdensityof(hmm, obs_seq)
-    logdensityof(hmm, obs_seqs, nb_seqs)
 
-Run the forward algorithm to compute the posterior loglikelihood of observations for an HMM.
+Run the forward algorithm to compute the posterior loglikelihood of sequence `obs_seq` for `hmm`.
 
-Whether it is applied on one or multiple sequences, this function returns a number.
+This function returns a number.
 """
-function DensityInterface.logdensityof(
-    hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
-)
-    logαs_and_logLs = forward(hmm, obs_seqs, nb_seqs)
-    return sum(last, logαs_and_logLs)
-end
-
 function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector)
-    return logdensityof(hmm, [obs_seq], 1)
+    _, logL = forward(hmm, obs_seq)
+    return logL
 end
diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl
index 16e802ae..5cfda512 100644
--- a/src/inference/forward_backward.jl
+++ b/src/inference/forward_backward.jl
@@ -7,7 +7,7 @@
 This storage is relative to a single sequence.
 
 # Fields
 
-The only fields useful outside of the algorithm are `γ`, `logL`, `init_count` and `trans_count`.
+The only fields useful outside of the algorithm are `γ`, `ξ` and `logL`; the rest does not belong to the public API.
 
 $(TYPEDFIELDS)
 """
@@ -20,6 +20,8 @@ struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
     β::Matrix{R}
     "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
     γ::Matrix{R}
+    "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
+    ξ::Vector{M}
     "forward message inverse normalizations `c[t] = 1 / sum(α[:,t])`"
     c::Vector{R}
     "observation loglikelihoods `logB[i,t] = ℙ(Y[t] | X[t]=i)`"
@@ -30,39 +32,43 @@ struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
     B::Matrix{R}
     "product `Bβ[i,t] = B[i,t] * β[i,t]`"
     Bβ::Matrix{R}
-    "posterior initialization count"
-    init_count::Vector{R}
-    "posterior transition count"
-    trans_count::M
 end
 
+"""
+    initialize_forward_backward(hmm, obs_seq)
+"""
 function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector)
     N, T = length(hmm), length(obs_seq)
     A = transition_matrix(hmm)
     R = eltype(hmm, obs_seq[1])
+    M = typeof(similar(A, R))
     logL = RefValue{R}(zero(R))
     α = Matrix{R}(undef, N, T)
     β = Matrix{R}(undef, N, T)
     γ = Matrix{R}(undef, N, T)
+    ξ = Vector{M}(undef, T - 1)
+    for t in 1:(T - 1)
+        ξ[t] = similar(A, R)
+    end
     c = Vector{R}(undef, T)
     logB = Matrix{R}(undef, N, T)
     logm = Vector{R}(undef, T)
     B = Matrix{R}(undef, N, T)
     Bβ = Matrix{R}(undef, N, T)
-    init_count = Vector{R}(undef, N)
-    trans_count = similar(A, R)
-    M = typeof(trans_count)
-    return ForwardBackwardStorage{R,M}(
-        logL, α, β, γ, c, logB, logm, B, Bβ, init_count, trans_count
-    )
+    return ForwardBackwardStorage{R,M}(logL, α, β, γ, ξ, c, logB, logm, B, Bβ)
 end
 
+"""
+    forward_backward!(storage, hmm, obs_seq)
+"""
 function forward_backward!(
     storage::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::Vector
 )
     p = initialization(hmm)
     A = transition_matrix(hmm)
     T = length(obs_seq)
-    @unpack α, β, c, γ, logB, logm, B, Bβ, init_count, trans_count = fb
+    @unpack logL, α, β, c, γ, ξ, logB, logm, B, Bβ = storage
 
     # Observation loglikelihoods
     for (logb, obs) in zip(eachcol(logB), obs_seq)
@@ -97,53 +103,33 @@ function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq
     # Marginals
     γ .= α .* β ./ c'
     check_finite(γ)
-
-    # Sufficient stats
-    init_count .= @view γ[:, 1]
-    trans_count .= zero(eltype(trans_count))
-    @views for t in 1:(T - 1)
-        add_mul_rows_cols!(trans_count, α[:, t], A, Bβ[:, t + 1])
+    for t in 1:(T - 1)
+        mul_rows_cols!(ξ[t], view(α, :, t), A, view(Bβ, :, t + 1))
     end
 
     # Loglikelihood
-    fb.logL[] = -sum(log, fb.c) + sum(fb.logm)
+    logL[] = -sum(log, c) + sum(logm)
 
     return nothing
 end
 
-function forward_backward!(
-    fbs::Vector{<:ForwardBackwardStorage},
-    hmm::AbstractHMM,
-    obs_seqs::Vector{<:Vector},
-    nb_seqs::Integer,
-)
-    check_nb_seqs(obs_seqs, nb_seqs)
-    @threads for k in eachindex(fbs, obs_seqs)
-        forward_backward!(fbs[k], hmm, obs_seqs[k])
-    end
-    return nothing
-end
-
 """
     forward_backward(hmm, obs_seq)
-    forward_backward(hmm, obs_seqs, nb_seqs)
 
-Run the forward-backward algorithm to infer the posterior state and transition marginals of an HMM.
+Run the forward-backward algorithm to infer the posterior state and transition marginals of `hmm` on the sequence `obs_seq`.
 
-When applied on a single sequence, this function returns a tuple `(γ, logL)` where
+This function returns a tuple `(γ, ξ, logL)` where
 
-- `γ` is a matrix containing the posterior state marginals `γ[i, t]`
+- `γ` is a matrix containing the posterior state marginals `γ[i,t]`
+- `ξ` is a vector of matrices containing the posterior transition marginals `ξ[t][i,j]`
 - `logL` is the loglikelihood of the sequence
 
-WHen applied on multiple sequences, it returns a vector of tuples.
-"""
-function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
-    check_nb_seqs(obs_seqs, nb_seqs)
-    fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
-    forward_backward!(fbs, hmm, obs_seqs, nb_seqs)
-    return [(fb.γ, fb.logL[]) for fb in fbs]
-end
-
+# See also
+
+- [`ForwardBackwardStorage`](@ref)
+"""
 function forward_backward(hmm::AbstractHMM, obs_seq::Vector)
-    return only(forward_backward(hmm, [obs_seq], 1))
+    storage = initialize_forward_backward(hmm, obs_seq)
+    forward_backward!(storage, hmm, obs_seq)
+    return storage.γ, storage.ξ, storage.logL[]
 end
diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl
index 4e5a4b05..13322f43 100644
--- a/src/inference/viterbi.jl
+++ b/src/inference/viterbi.jl
@@ -7,7 +7,7 @@
 This storage is relative to a single sequence.
 
 # Fields
 
-The only field useful outside of the algorithm is `q`.
+The only field useful outside of the algorithm is `q`; the rest does not belong to the public API.
 
 $(TYPEDFIELDS)
 """
@@ -26,6 +26,9 @@ struct ViterbiStorage{R}
     scratch::Vector{R}
 end
 
+"""
+    initialize_viterbi(hmm, obs_seq)
+"""
 function initialize_viterbi(hmm::AbstractHMM, obs_seq::Vector)
     T, N = length(obs_seq), length(hmm)
     R = eltype(hmm, obs_seq[1])
@@ -39,11 +42,14 @@ function initialize_viterbi(hmm::AbstractHMM, obs_seq::Vector)
     return ViterbiStorage(logb, δ, δ_prev, ψ, q, scratch)
 end
 
-function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector)
+"""
+    viterbi!(storage, hmm, obs_seq)
+"""
+function viterbi!(storage::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector)
     N, T = length(hmm), length(obs_seq)
     p = initialization(hmm)
     A = transition_matrix(hmm)
-    @unpack logb, δ, δ_prev, ψ, q, scratch = v
+    @unpack logb, δ, δ_prev, ψ, q, scratch = storage
 
     obs_logdensities!(logb, hmm, obs_seq[1])
     check_right_finite(logb)
@@ -72,35 +78,15 @@ function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector)
     return nothing
 end
 
-function viterbi!(
-    vs::Vector{<:ViterbiStorage},
-    hmm::AbstractHMM,
-    obs_seqs::Vector{<:Vector},
-    nb_seqs::Integer,
-)
-    check_nb_seqs(obs_seqs, nb_seqs)
-    @threads for k in eachindex(vs, obs_seqs)
-        viterbi!(vs[k], hmm, obs_seqs[k])
-    end
-    return nothing
-end
-
 """
     viterbi(hmm, obs_seq)
-    viterbi(hmm, obs_seqs, nb_seqs)
 
-Apply the Viterbi algorithm to infer the most likely state sequence of an HMM.
+Apply the Viterbi algorithm to infer the most likely state sequence corresponding to `obs_seq` for `hmm`.
 
-When applied on a single sequence, this function returns a vector of integers.
-When applied on multiple sequences, it returns a vector of vectors of integers.
+This function returns a vector of integers.
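+
+For example, a minimal sketch (assuming `obs_seq` was simulated from `hmm`, e.g. with `rand(hmm, T)`):
+
+    q = viterbi(hmm, obs_seq)  # q[t] ∈ 1:length(hmm) is the most likely state at time t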
""" -function viterbi(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) - check_nb_seqs(obs_seqs, nb_seqs) - vs = [initialize_viterbi(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] - viterbi!(vs, hmm, obs_seqs, nb_seqs) - return [v.q for v in vs] -end - function viterbi(hmm::AbstractHMM, obs_seq::Vector) - return only(viterbi(hmm, [obs_seq], 1)) + storage = initialize_viterbi(hmm, obs_seq) + viterbi!(storage, hmm, obs_seq) + return storage.q end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index fa124e01..1df562c8 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -11,8 +11,7 @@ To create your own subtype of `AbstractHiddenMarkovModel`, you need to implement - [`eltype(hmm, obs)`](@ref) - [`initialization(hmm)`](@ref) - [`transition_matrix(hmm)`](@ref) -- [`obs_logdensities!(logb, hmm, obs)`](@ref) -- [`obs_sample(rng, hmm, i)`](@ref) (optional) +- [`obs_distributions(hmm)`](@ref) - [`fit!(hmm, init_count, trans_count, obs_seq, state_marginals)`](@ref) (optional) # Applicable functions diff --git a/src/utils/check.jl b/src/utils/check.jl index 7e4bcff0..efd4c216 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -59,12 +59,6 @@ function check_hmm_sizes(p::AbstractVector, A::AbstractMatrix, d::AbstractVector end end -function check_nb_seqs(obs_seqs::Vector{<:Vector}, nb_seqs::Integer) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("Incoherent sizes provided: `nb_seqs != length(obs_seqs)`")) - end -end - function check_hmm(hmm::AbstractHMM) p = initialization(hmm) A = transition_matrix(hmm) diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 6fafbcd8..4783a4df 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -3,14 +3,14 @@ sum_to_one!(x) = ldiv!(sum(x), x) mynonzeros(x::AbstractArray) = x mynonzeros(x::AbstractSparseArray) = nonzeros(x) -function add_mul_rows_cols!( +function mul_rows_cols!( B::AbstractMatrix, l::AbstractVector, A::AbstractMatrix, r::AbstractVector ) - B .+= l .* A .* r' + B .= l .* A .* r' return nothing end -function add_mul_rows_cols!( +function mul_rows_cols!( B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector ) @assert size(B) == size(A) == (length(l), length(r)) @@ -18,7 +18,7 @@ function add_mul_rows_cols!( for j in axes(B, 2) for k in nzrange(B, j) i = B.rowval[k] - B.nzval[k] += l[i] * A.nzval[k] * r[j] + B.nzval[k] = l[i] * A.nzval[k] * r[j] end end return nothing diff --git a/test/allocations.jl b/test/allocations.jl index e099cd6e..40ba2657 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -11,32 +11,36 @@ function test_allocations(hmm; T) obs_seqs = [rand(hmm, T).obs_seq for _ in 1:nb_seqs] ## Forward - f = HMMs.initialize_forward(hmm, obs_seq) - HiddenMarkovModels.forward!(f, hmm, obs_seq) - allocs = @allocated HiddenMarkovModels.forward!(f, hmm, obs_seq) + f_storage = HMMs.initialize_forward(hmm, obs_seq) + HiddenMarkovModels.forward!(f_storage, hmm, obs_seq) + allocs = @allocated HiddenMarkovModels.forward!(f_storage, hmm, obs_seq) @test allocs == 0 ## Viterbi - v = HMMs.initialize_viterbi(hmm, obs_seq) - HMMs.viterbi!(v, hmm, obs_seq) - allocs = @allocated HMMs.viterbi!(v, hmm, obs_seq) + v_storage = HMMs.initialize_viterbi(hmm, obs_seq) + HMMs.viterbi!(v_storage, hmm, obs_seq) + allocs = @allocated HMMs.viterbi!(v_storage, hmm, obs_seq) @test allocs == 0 ## Forward-backward - fb = HMMs.initialize_forward_backward(hmm, obs_seq) - HMMs.forward_backward!(fb, hmm, obs_seq) - allocs = @allocated HMMs.forward_backward!(fb, hmm, 
+    fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq)
+    HMMs.forward_backward!(fb_storage, hmm, obs_seq)
+    allocs = @allocated HMMs.forward_backward!(fb_storage, hmm, obs_seq)
     @test allocs == 0
 
     ## Baum-Welch
-    fbs = [HMMs.initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
-    bw = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs)
+    fb_storages = [
+        HMMs.initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)
+    ]
+    bw_storage = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs)
     obs_seqs_concat = reduce(vcat, obs_seqs)
-    HMMs.forward_backward!(fbs, hmm, obs_seqs, nb_seqs)
-    HMMs.update_sufficient_statistics!(bw, fbs)
-    fit!(hmm, bw, obs_seqs_concat)
-    allocs1 = @allocated HMMs.update_sufficient_statistics!(bw, fbs)
-    allocs2 = @allocated fit!(hmm, bw, obs_seqs_concat)
+    for k in eachindex(fb_storages, obs_seqs)
+        HMMs.forward_backward!(fb_storages[k], hmm, obs_seqs[k])
+    end
+    HMMs.update_sufficient_statistics!(bw_storage, fb_storages)
+    fit!(hmm, bw_storage, obs_seqs_concat)
+    allocs1 = @allocated HMMs.update_sufficient_statistics!(bw_storage, fb_storages)
+    allocs2 = @allocated fit!(hmm, bw_storage, obs_seqs_concat)
     @test allocs1 == 0
     @test allocs2 == 0
 end
diff --git a/test/arrays.jl b/test/arrays.jl
index 316f019a..a50290c5 100644
--- a/test/arrays.jl
+++ b/test/arrays.jl
@@ -19,14 +19,12 @@
 d_init = [Normal(i + randn(), 1.0) for i in 1:N];
 
 hmm = HMM(p, A, d);
 hmm_init = HMM(p, A, d_init);
-
 obs_seq = rand(hmm, T).obs_seq;
 
-γ, logL = forward_backward(hmm, obs_seq);
-hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq);
-
 @testset "Sparse" begin
-    @test typeof(hmm_est) == typeof(hmm_init)
+    γ, ξ, logL = @inferred forward_backward(hmm, obs_seq)
+    hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq)
+    @test eltype(ξ) == typeof(transition_matrix(hmm))
     @test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm))
 end
@@ -41,9 +39,8 @@
 hmm = HMM(p, A, d);
 hmm_init = HMM(p, A, d_init);
 
 obs_seq = rand(hmm, T).obs_seq;
-γ, logL = forward_backward(hmm, obs_seq);
-hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq);
-
 @testset "Static" begin
-    @test typeof(hmm_est) == typeof(hmm_init)
+    γ, ξ, logL = @inferred forward_backward(hmm, obs_seq)
+    hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq)
+    @test eltype(ξ) == typeof(transition_matrix(hmm))
 end
diff --git a/test/correctness.jl b/test/correctness.jl
index 0aacbcd7..41b7e4f2 100644
--- a/test/correctness.jl
+++ b/test/correctness.jl
@@ -6,59 +6,45 @@ using SimpleUnPack
 using Test
 
 function test_correctness(hmm, hmm_init; T)
-    obs_seq1 = rand(hmm, T).obs_seq
-    obs_seq2 = rand(hmm, T).obs_seq
-    obs_mat1 = collect(reduce(hcat, obs_seq1)')
-    obs_mat2 = collect(reduce(hcat, obs_seq2)')
-
-    nb_seqs = 2
-    obs_seqs = [obs_seq1, obs_seq2]
+    obs_seq = rand(hmm, T).obs_seq
+    obs_mat = collect(reduce(hcat, obs_seq)')
 
     hmm_base = HMMBase.HMM(deepcopy(hmm))
     hmm_init_base = HMMBase.HMM(deepcopy(hmm_init))
 
     @testset "Logdensity" begin
-        logL1_base = HMMBase.forward(hmm_base, obs_mat1)[2]
-        logL2_base = HMMBase.forward(hmm_base, obs_mat2)[2]
-        logL = logdensityof(hmm, obs_seqs, nb_seqs)
-        @test logL ≈ logL1_base + logL2_base
+        logL_base = HMMBase.forward(hmm_base, obs_mat)[2]
+        logL = logdensityof(hmm, obs_seq)
+        @test logL ≈ logL_base
     end
 
     @testset "Forward" begin
-        (α1_base, logL1_base), (α2_base, logL2_base) = [
-            HMMBase.forward(hmm_base, obs_mat1), HMMBase.forward(hmm_base, obs_mat2)
-        ]
-        (α1, logL1), (α2, logL2) = forward(hmm, obs_seqs, nb_seqs)
-        @test isapprox(α1, α1_base[end, :])
-        @test isapprox(α2, α2_base[end, :])
-        @test logL1 ≈ logL1_base
-        @test logL2 ≈ logL2_base
+        α_base, logL_base = HMMBase.forward(hmm_base, obs_mat)
+        α, logL = forward(hmm, obs_seq)
+        @test isapprox(α, α_base[end, :])
+        @test logL ≈ logL_base
     end
 
     @testset "Viterbi" begin
-        q1_base = HMMBase.viterbi(hmm_base, obs_mat1)
-        q2_base = HMMBase.viterbi(hmm_base, obs_mat2)
-        q1, q2 = viterbi(hmm, obs_seqs, nb_seqs)
+        q_base = HMMBase.viterbi(hmm_base, obs_mat)
+        q = viterbi(hmm, obs_seq)
         # Viterbi decoding can vary in case of (infrequent) ties
-        @test mean(q1 .== q1_base) > 0.9
-        @test mean(q2 .== q2_base) > 0.9
+        @test mean(q .== q_base) > 0.9
     end
 
     @testset "Forward-backward" begin
-        γ1_base = HMMBase.posteriors(hmm_base, obs_mat1)
-        γ2_base = HMMBase.posteriors(hmm_base, obs_mat2)
-        (γ1, _), (γ2, _) = forward_backward(hmm, obs_seqs, nb_seqs)
-        @test isapprox(γ1, γ1_base')
-        @test isapprox(γ2, γ2_base')
+        γ_base = HMMBase.posteriors(hmm_base, obs_mat)
+        γ, _, _ = forward_backward(hmm, obs_seq)
+        @test isapprox(γ, γ_base')
     end
 
     @testset "Baum-Welch" begin
         hmm_est_base, hist_base = HMMBase.fit_mle(
-            hmm_init_base, obs_mat1; maxiter=10, tol=-Inf
+            hmm_init_base, obs_mat; maxiter=10, tol=-Inf
         )
         logL_evolution_base = hist_base.logtots
         hmm_est, logL_evolution = baum_welch(
-            hmm_init, [obs_seq1, obs_seq1], 2; max_iterations=10, atol=-Inf
+            hmm_init, [obs_seq, obs_seq], 2; max_iterations=10, atol=-Inf
        )
        @test isapprox(
            logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)]