From 1f6b1e8bc2383ac646f192f4c0307730fb4ab444 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 3 Nov 2023 10:02:17 +0100 Subject: [PATCH 01/14] Revamp --- Project.toml | 3 +- benchmark/utils/hmms.jl | 18 ++-- docs/src/api.md | 19 +--- docs/src/debugging.md | 2 +- docs/src/tuto_custom.md | 6 +- ext/HiddenMarkovModelsChainRulesCoreExt.jl | 10 +- src/HiddenMarkovModels.jl | 14 +-- src/inference/baum_welch.jl | 24 ++--- src/inference/forward_backward.jl | 106 +++++++++--------- src/inference/loglikelihoods.jl | 13 ++- src/types/abstract_hmm.jl | 77 ++++++-------- src/types/abstract_mc.jl | 118 --------------------- src/types/mc.jl | 44 -------- src/types/permuted_hmm.jl | 28 +++++ src/utils/check.jl | 19 +--- test/allocations.jl | 15 +-- test/dna.jl | 23 ++-- test/doctests.jl | 4 - test/formatting.jl | 5 - test/interface.jl | 17 --- test/linting.jl | 4 - test/mc.jl | 22 ---- test/quality.jl | 4 - test/runtests.jl | 19 ++-- 24 files changed, 186 insertions(+), 428 deletions(-) delete mode 100644 src/types/abstract_mc.jl delete mode 100644 src/types/mc.jl create mode 100644 src/types/permuted_hmm.jl delete mode 100644 test/doctests.jl delete mode 100644 test/formatting.jl delete mode 100644 test/interface.jl delete mode 100644 test/linting.jl delete mode 100644 test/mc.jl delete mode 100644 test/quality.jl diff --git a/Project.toml b/Project.toml index 29a61cce..62ccac7a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.3.1" +version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -10,7 +10,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" diff --git a/benchmark/utils/hmms.jl b/benchmark/utils/hmms.jl index 027505c1..972e1981 100644 --- a/benchmark/utils/hmms.jl +++ b/benchmark/utils/hmms.jl @@ -1,5 +1,5 @@ using BenchmarkTools -using HiddenMarkovModels: HMMs +using HiddenMarkovModels using SimpleUnPack function rand_params_hmms(; N, D) @@ -15,9 +15,9 @@ function rand_model_hmms(; N, D) if D == 1 dists = [Normal(μ[n, 1], σ[n, 1]) for n in 1:N] else - dists = [HMMs.LightDiagNormal(μ[n, :], σ[n, :]) for n in 1:N] + dists = [HiddenMarkovModels.LightDiagNormal(μ[n, :], σ[n, :]) for n in 1:N] end - model = HMMs.HMM(p, A, dists) + model = HiddenMarkovModels.HMM(p, A, dists) return model end @@ -30,22 +30,22 @@ function benchmarkables_hmms(; algos, N, D, T, K, I) end benchs = Dict() if "logdensity" in algos - benchs["logdensity"] = @benchmarkable HMMs.logdensityof(model, $obs_seqs, $K) setup = ( - model = rand_model_hmms(; N=$N, D=$D) - ) + benchs["logdensity"] = @benchmarkable HiddenMarkovModels.logdensityof( + model, $obs_seqs, $K + ) setup = (model = rand_model_hmms(; N=$N, D=$D)) end if "viterbi" in algos - benchs["viterbi"] = @benchmarkable HMMs.viterbi(model, $obs_seqs, $K) setup = ( + benchs["viterbi"] = @benchmarkable HiddenMarkovModels.viterbi(model, $obs_seqs, $K) setup = ( model = rand_model_hmms(; N=$N, D=$D) ) end if "forward_backward" in algos - benchs["forward_backward"] = @benchmarkable 
HMMs.forward_backward( + benchs["forward_backward"] = @benchmarkable HiddenMarkovModels.forward_backward( model, $obs_seqs, $K ) setup = (model = rand_model_hmms(; N=$N, D=$D)) end if "baum_welch" in algos - benchs["baum_welch"] = @benchmarkable HMMs.baum_welch( + benchs["baum_welch"] = @benchmarkable HiddenMarkovModels.baum_welch( model, $obs_seqs, $K; max_iterations=$I, atol=-Inf ) setup = (model = rand_model_hmms(; N=$N, D=$D)) end diff --git a/docs/src/api.md b/docs/src/api.md index 3f58652e..1376a27e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -6,22 +6,11 @@ HiddenMarkovModels ## Types -### Markov chains - -```@docs -AbstractMarkovChain -MarkovChain -AbstractMC -MC -``` - -### Hidden Markov Models - ```@docs AbstractHiddenMarkovModel HiddenMarkovModel AbstractHMM -HiddenMarkovModels.PermutedHMM +PermutedHMM HMM ``` @@ -55,9 +44,9 @@ baum_welch ## Internals ```@docs -HMMs.ForwardBackwardStorage -HMMs.fit_element_from_sequence! -HMMs.LightDiagNormal +HiddenMarkovModels.ForwardBackwardStorage +HiddenMarkovModels.fit_element_from_sequence! +HiddenMarkovModels.LightDiagNormal ``` ## Notations diff --git a/docs/src/debugging.md b/docs/src/debugging.md index 88a3e66d..0f759552 100644 --- a/docs/src/debugging.md +++ b/docs/src/debugging.md @@ -9,4 +9,4 @@ This can happen for a variety of reasons, so here are a few leads worth investig * Reduce the number of states (to make every one of them useful) * Add a prior to your transition matrix / observation distributions (to avoid degenerate behavior like zero variance in a Gaussian) * Pick a better initialization (to start closer to the supposed ground truth) -* Use [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl) in strategic places (to guarantee numerical stability). Note that these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distribution. \ No newline at end of file +* Use [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl) in strategic places (to guarantee numerical stability). Note that these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distribution. diff --git a/docs/src/tuto_custom.md b/docs/src/tuto_custom.md index 7a6edc95..61f70416 100644 --- a/docs/src/tuto_custom.md +++ b/docs/src/tuto_custom.md @@ -61,9 +61,9 @@ The interface is only different as far as the initialization is concerned. 
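One way to realize the `markov_equilibrium` helper invoked in the next block is plain power iteration on the transposed transition matrix. A minimal sketch, assuming an ergodic chain (the helper's name and signature come from the tutorial, not from this patch):

```julia
# Stationary distribution of a row-stochastic matrix A: iterate p <- A'p,
# which preserves probability vectors and converges for ergodic chains.
function markov_equilibrium(A::AbstractMatrix; iters=1_000)
    p = fill(1 / size(A, 1), size(A, 1))
    for _ in 1:iters
        p = A' * p
    end
    return p ./ sum(p)  # guard against accumulated round-off
end
```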
```@example tuto Base.length(hmm::EquilibriumHMM) = length(hmm.dists) -HMMs.initial_distribution(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans) # this is new -HMMs.transition_matrix(hmm::EquilibriumHMM) = hmm.trans -HMMs.obs_distribution(hmm::EquilibriumHMM, i::Integer) = hmm.dists[i] +HiddenMarkovModels.initial_distribution(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans) # this is new +HiddenMarkovModels.transition_matrix(hmm::EquilibriumHMM) = hmm.trans +HiddenMarkovModels.obs_distribution(hmm::EquilibriumHMM, i::Integer) = hmm.dists[i] ``` As for fitting, we simply ignore the initialization count and copy the rest of the original code (with a few simplifications): diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl index c5a8c451..134a0ab7 100644 --- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl +++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl @@ -17,16 +17,16 @@ function ChainRulesCore.rrule( rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq ) (p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq) - fb = forward_backward(p, A, logB) + fb = forward_backward(hmm, obs_seq) logL = HiddenMarkovModels.loglikelihood(fb) - @unpack α, β, γ, c, Bscaled, Bβscaled = fb + @unpack α, β, γ, c, B̃β = fb T = length(obs_seq) function logdensityof_hmm_pullback(ΔlogL) - Δp = ΔlogL .* Bβscaled[:, 1] - ΔA = ΔlogL .* α[:, 1] .* Bβscaled[:, 2]' + Δp = ΔlogL .* B̃β[:, 1] + ΔA = ΔlogL .* α[:, 1] .* B̃β[:, 2]' @views for t in 2:(T - 1) - ΔA .+= ΔlogL .* α[:, t] .* Bβscaled[:, t + 1]' + ΔA .+= ΔlogL .* α[:, t] .* B̃β[:, t + 1]' end ΔlogB = ΔlogL .* γ diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index c40acf3e..9babda50 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -7,8 +7,6 @@ The alias `HMMs` is exported for the package name. """ module HiddenMarkovModels -const HMMs = HiddenMarkovModels - using Base.Threads: @threads using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof @@ -21,26 +19,20 @@ using Distributions: MatrixDistribution using LinearAlgebra: Diagonal, dot, mul! using PrecompileTools: @compile_workload, @setup_workload -using Random: AbstractRNG, default_rng -using RequiredInterfaces: @required +using Random: Random, AbstractRNG, default_rng using Requires: @require using SimpleUnPack: @unpack using StatsAPI: StatsAPI, fit, fit! -export HMMs -export AbstractMarkovChain, AbstractMC -export MarkovChain, MC export AbstractHiddenMarkovModel, AbstractHMM, PermutedHMM export HiddenMarkovModel, HMM export rand_prob_vec, rand_trans_mat export initial_distribution, transition_matrix, obs_distribution export logdensityof, viterbi, forward, forward_backward, baum_welch export fit, fit! 
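# The export list above is the package's whole public surface after this
# commit. A minimal orientation sketch (illustrative, not patch content;
# assumes Gaussian observations via Distributions.jl):
using HiddenMarkovModels, Distributions
hmm = HMM([0.6, 0.4], [0.9 0.1; 0.2 0.8], [Normal(0.0), Normal(5.0)])
sim = rand(hmm, 100)                   # named tuple (state_seq, obs_seq)
logL = logdensityof(hmm, sim.obs_seq)  # sequence loglikelihood
best = viterbi(hmm, sim.obs_seq)       # most likely state sequence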
-export LightDiagNormal -include("types/abstract_mc.jl") -include("types/mc.jl") include("types/abstract_hmm.jl") +include("types/permuted_hmm.jl") include("types/hmm.jl") include("utils/check.jl") @@ -67,6 +59,7 @@ if !isdefined(Base, :get_extension) end end +#= @compile_workload begin N, D, T = 5, 3, 100 p = rand_prob_vec(N) @@ -82,5 +75,6 @@ end forward_backward(hmm, obs_seqs, nb_seqs) baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf) end +=# end diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 87bc8ac3..1cc2e1d1 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -2,32 +2,24 @@ function baum_welch!( hmm::AbstractHMM, obs_seqs; atol, max_iterations, check_loglikelihood_increasing ) # Pre-allocate nearly all necessary memory - logB = loglikelihoods(hmm, obs_seqs[1]) - fb = initialize_forward_backward(hmm, logB) - - logBs = Vector{typeof(logB)}(undef, length(obs_seqs)) + fb = initialize_forward_backward(hmm, obs_seqs[1]) fbs = Vector{typeof(fb)}(undef, length(obs_seqs)) - @threads for k in eachindex(obs_seqs) - logBs[k] = loglikelihoods(hmm, obs_seqs[k]) - fbs[k] = forward_backward(hmm, logBs[k]) + @threads for k in eachindex(obs_seqs, fbs) + fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) end init_count, trans_count = initialize_states_stats(fbs) state_marginals_concat = initialize_observations_stats(fbs) obs_seqs_concat = reduce(vcat, obs_seqs) - logL = loglikelihood(fbs) - logL_evolution = [logL] + logL_evolution = eltype(fbs[1])[] for iteration in 1:max_iterations # E step - if iteration > 1 - @threads for k in eachindex(obs_seqs, logBs, fbs) - loglikelihoods!(logBs[k], hmm, obs_seqs[k]) - forward_backward!(fbs[k], hmm, logBs[k]) - end - logL = loglikelihood(fbs) - push!(logL_evolution, logL) + @threads for k in eachindex(obs_seqs, fbs) + forward_backward!(fbs[k], hmm, obs_seqs[k]) end + logL = loglikelihood(fbs) + push!(logL_evolution, logL) # M step update_states_stats!(init_count, trans_count, fbs) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 165b7b54..b23b9ed7 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -15,9 +15,10 @@ The following fields are internals and subject to change: - `α::Matrix{R}`: scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`) - `β::Matrix{R}`: scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`) - `c::Vector{R}`: forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])` +- `logB::Matrix{R}`: observation loglikelihoods `logB[i, t]` - `logm::Vector{R}`: maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])` -- `Bscaled::Matrix{R}`: numerically stabilized observation likelihoods `Bscaled[i,t] = exp.(logB[i,t] - logm[t])` -- `Bβscaled::Matrix{R}`: numerically stabilized product `Bβscaled[i,t] = Bscaled[i,t] * β[i,t]` +- `B̃::Matrix{R}`: numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])` +- `B̃β::Matrix{R}`: numerically stabilized product `B̃β[i,t] = B̃[i,t] * β[i,t]` """ struct ForwardBackwardStorage{R} α::Matrix{R} @@ -25,11 +26,13 @@ struct ForwardBackwardStorage{R} γ::Matrix{R} ξ::Array{R,3} c::Vector{R} + logB::Matrix{R} logm::Vector{R} - Bscaled::Matrix{R} - Bβscaled::Matrix{R} + B̃::Matrix{R} + B̃β::Matrix{R} end +Base.eltype(::ForwardBackwardStorage{R}) where {R} = R Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1) 
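# An identity implied by the scaling conventions documented above (a sketch,
# not code from this patch): every forward column is renormalized by c[t] and
# every likelihood column is shifted by logm[t] before exponentiation, so for
# a filled storage `fb` the sequence loglikelihood is recoverable as
#
#     log ℙ(Y[1:T]) = -sum(log, fb.c) + sum(fb.logm)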
duration(fb::ForwardBackwardStorage) = size(fb.α, 2) @@ -46,39 +49,45 @@ function loglikelihood(fbs::Vector{ForwardBackwardStorage{R}}) where {R} return logL end -function initialize_forward_backward(p, A, logB) - N, T = size(logB) - R = promote_type(eltype(p), eltype(A), eltype(logB)) +function initialize_forward_backward(hmm::AbstractHMM, obs_seq) + p = initial_distribution(hmm) + A = transition_matrix(hmm) + testval = logdensityof(obs_distribution(hmm, 1), obs_seq[1]) + R = promote_type(eltype(p), eltype(A), typeof(testval)) + + N, T = length(hmm), length(obs_seq) α = Matrix{R}(undef, N, T) β = Matrix{R}(undef, N, T) γ = Matrix{R}(undef, N, T) ξ = Array{R,3}(undef, N, N, T - 1) c = Vector{R}(undef, T) + logB = Matrix{R}(undef, N, T) logm = Vector{R}(undef, T) - Bscaled = Matrix{R}(undef, N, T) - Bβscaled = Matrix{R}(undef, N, T) - return ForwardBackwardStorage(α, β, γ, ξ, c, logm, Bscaled, Bβscaled) + B̃ = Matrix{R}(undef, N, T) + B̃β = Matrix{R}(undef, N, T) + return ForwardBackwardStorage(α, β, γ, ξ, c, logB, logm, B̃, B̃β) end -function initialize_forward_backward(hmm::AbstractHMM, logB) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - return initialize_forward_backward(p, A, logB) +function scale_likelihoods!(fb::ForwardBackwardStorage) + @unpack logB, logm, B̃ = fb + maximum!(logm', logB) + B̃ .= exp.(logB .- logm') + return nothing end -function forward!(fb::ForwardBackwardStorage, p, A, logB) - @unpack α, c, logm, Bscaled = fb +function forward!(fb::ForwardBackwardStorage, hmm::AbstractHMM) + p = initial_distribution(hmm) + A = transition_matrix(hmm) + @unpack α, c, B̃ = fb T = size(α, 2) - maximum!(logm', logB) - Bscaled .= exp.(logB .- logm') @views begin - α[:, 1] .= p .* Bscaled[:, 1] + α[:, 1] .= p .* B̃[:, 1] c[1] = inv(sum(α[:, 1])) α[:, 1] .*= c[1] end @views for t in 1:(T - 1) mul!(α[:, t + 1], A', α[:, t]) - α[:, t + 1] .*= Bscaled[:, t + 1] + α[:, t + 1] .*= B̃[:, t + 1] c[t + 1] = inv(sum(α[:, t + 1])) α[:, t + 1] .*= c[t + 1] end @@ -86,57 +95,43 @@ function forward!(fb::ForwardBackwardStorage, p, A, logB) return nothing end -function backward!(fb::ForwardBackwardStorage{R}, A, logB) where {R} - @unpack β, c, Bscaled, Bβscaled = fb +function backward!(fb::ForwardBackwardStorage{R}, hmm::AbstractHMM) where {R} + A = transition_matrix(hmm) + @unpack β, c, B̃, B̃β = fb T = size(β, 2) β[:, T] .= c[T] @views for t in (T - 1):-1:1 - Bβscaled[:, t + 1] .= Bscaled[:, t + 1] .* β[:, t + 1] - mul!(β[:, t], A, Bβscaled[:, t + 1]) + B̃β[:, t + 1] .= B̃[:, t + 1] .* β[:, t + 1] + mul!(β[:, t], A, B̃β[:, t + 1]) β[:, t] .*= c[t] end - @views Bβscaled[:, 1] .= Bscaled[:, 1] .* β[:, 1] + @views B̃β[:, 1] .= B̃[:, 1] .* β[:, 1] check_no_nan(β) return nothing end -function marginals!(fb::ForwardBackwardStorage, A) - @unpack α, β, c, Bβscaled, γ, ξ = fb +function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM) + A = transition_matrix(hmm) + @unpack α, β, c, B̃β, γ, ξ = fb N, T = size(γ) γ .= α .* β ./ c' check_no_nan(γ) @views for t in 1:(T - 1) - ξ[:, :, t] .= α[:, t] .* A .* Bβscaled[:, t + 1]' + ξ[:, :, t] .= α[:, t] .* A .* B̃β[:, t + 1]' end check_no_nan(ξ) return nothing end -function forward_backward!(fb::ForwardBackwardStorage, p, A, logB) - forward!(fb, p, A, logB) - backward!(fb, A, logB) - marginals!(fb, A) +function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) + loglikelihoods!(fb.logB, hmm, obs_seq) + scale_likelihoods!(fb) + forward!(fb, hmm) + backward!(fb, hmm) + marginals!(fb, hmm) return nothing end -function 
forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, logB) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - return forward_backward!(fb, p, A, logB) -end - -function forward_backward(p, A, logB) - fb = initialize_forward_backward(p, A, logB) - forward_backward!(fb, p, A, logB) - return fb -end - -function forward_backward(hmm::AbstractHMM, logB::Matrix) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - return forward_backward(p, A, logB) -end - """ forward_backward(hmm, obs_seq) @@ -145,8 +140,9 @@ Apply the forward-backward algorithm to estimate the posterior state marginals o Return a [`ForwardBackwardStorage`](@ref). """ function forward_backward(hmm::AbstractHMM, obs_seq) - logB = loglikelihoods(hmm, obs_seq) - return forward_backward(hmm, logB) + fb = initialize_forward_backward(hmm, obs_seq) + forward_backward!(fb, hmm, obs_seq) + return fb end """ @@ -163,9 +159,9 @@ function forward_backward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - fb1 = forward_backward(hmm, first(obs_seqs)) - fbs = Vector{typeof(fb1)}(undef, nb_seqs) - fbs[1] = fb1 + fb = forward_backward(hmm, first(obs_seqs)) + fbs = Vector{typeof(fb)}(undef, nb_seqs) + fbs[1] = fb @threads for k in 2:nb_seqs fbs[k] = forward_backward(hmm, obs_seqs[k]) end diff --git a/src/inference/loglikelihoods.jl b/src/inference/loglikelihoods.jl index 26592b73..2c09b764 100644 --- a/src/inference/loglikelihoods.jl +++ b/src/inference/loglikelihoods.jl @@ -10,9 +10,9 @@ function loglikelihoods_vec!(logb, hmm::AbstractHMM, obs) end function loglikelihoods_vec(hmm::AbstractHMM, obs) - logb = [logdensityof(obs_distribution(hmm, i), obs) for i in 1:length(hmm)] - check_no_nan(logb) - check_no_inf(logb) + testval = logdensityof(obs_distribution(hmm, 1), obs) + logb = Vector{typeof(testval)}(undef, length(hmm)) + loglikelihoods_vec!(logb, hmm, obs) return logb end @@ -29,10 +29,9 @@ function loglikelihoods!(logB, hmm::AbstractHMM, obs_seq) end function loglikelihoods(hmm::AbstractHMM, obs_seq) + testval = logdensityof(obs_distribution(hmm, 1), obs_seq[1]) T, N = length(obs_seq), length(hmm) - dists = obs_distribution.(Ref(hmm), 1:N) - logB = [logdensityof(dists[i], obs_seq[t]) for i in 1:N, t in 1:T] - check_no_nan(logB) - check_no_inf(logB) + logB = Matrix{typeof(testval)}(undef, N, T) + loglikelihoods!(logB, hmm, obs_seq) return logB end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 18c2f558..2ea35f5b 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -1,5 +1,5 @@ """ - AbstractHiddenMarkovModel <: AbstractMarkovChain + AbstractHiddenMarkovModel Abstract supertype for an HMM amenable to simulation, inference and learning. @@ -19,7 +19,7 @@ Abstract supertype for an HMM amenable to simulation, inference and learning. 
- `forward_backward(hmm, obs_seq)` / `forward_backward(hmm, obs_seqs, nb_seqs)` - `baum_welch(hmm, obs_seq)` / `baum_welch(hmm, obs_seqs, nb_seqs)` if `fit!` is implemented """ -abstract type AbstractHiddenMarkovModel <: AbstractMarkovChain end +abstract type AbstractHiddenMarkovModel end """ AbstractHMM @@ -30,12 +30,28 @@ const AbstractHMM = AbstractHiddenMarkovModel @inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity() -@required AbstractHMM begin - Base.length(::AbstractHMM) - initial_distribution(::AbstractHMM) - transition_matrix(::AbstractHMM) - obs_distribution(::AbstractHMM, ::Integer) -end +## Interface + +""" + length(hmm::AbstractHMM) + +Return the number of states of `hmm`. +""" +Base.length + +""" + initial_distribution(hmm::AbstractHMM) + +Return the initial state probabilities of `hmm`. +""" +function initial_distribution end + +""" + transition_matrix(hmm::AbstractHMM) + +Return the state transition probabilities of `hmm`. +""" +function transition_matrix end """ obs_distribution(hmm::AbstractHMM, i) @@ -48,9 +64,17 @@ The returned object `dist` must implement """ function obs_distribution end +## Sampling + function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) - mc = MarkovChain(hmm) - state_seq = rand(rng, mc, T) + init = initial_distribution(hmm) + trans = transition_matrix(hmm) + first_state = rand(rng, Categorical(init; check_args=false)) + state_seq = Vector{Int}(undef, T) + state_seq[1] = first_state + @views for t in 2:T + state_seq[t] = rand(rng, Categorical(trans[state_seq[t - 1], :]; check_args=false)) + end first_obs = rand(rng, obs_distribution(hmm, first(state_seq))) obs_seq = Vector{typeof(first_obs)}(undef, T) obs_seq[1] = first_obs @@ -60,35 +84,4 @@ function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) return (; state_seq=state_seq, obs_seq=obs_seq) end -function MarkovChain(hmm::AbstractHMM) - return MarkovChain(initial_distribution(hmm), transition_matrix(hmm)) -end - -""" - PermutedHMM{H<:AbstractHMM} - -Wrapper around an `AbstractHMM` that permutes its states. - -This is computationally inefficient and mostly useful for evaluation. - -# Fields - -- `hmm:H`: the old HMM -- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old. -""" -struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM - hmm::H - perm::Vector{Int} -end - -Base.length(p::PermutedHMM) = length(p.hmm) - -HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm] - -function HMMs.transition_matrix(p::PermutedHMM) - return transition_matrix(p.hmm)[p.perm, :][:, p.perm] -end - -function HMMs.obs_distribution(p::PermutedHMM, i::Integer) - return obs_distribution(p.hmm, p.perm[i]) -end +Base.rand(hmm::AbstractHMM, T::Integer) = rand(default_rng(), hmm, T) diff --git a/src/types/abstract_mc.jl b/src/types/abstract_mc.jl deleted file mode 100644 index afebdc54..00000000 --- a/src/types/abstract_mc.jl +++ /dev/null @@ -1,118 +0,0 @@ -""" - AbstractMarkovChain - -Abstract supertype for a Markov chain amenable to simulation, inference and learning. - -# Required interface - -- `initial_distribution(mc)` -- `transition_matrix(mc)` -- `fit!(mc, init_count, trans_count)` (optional) - -# Applicable methods - -- `rand([rng,] mc, T)` -- `logdensityof(mc, state_seq)` -- `fit(mc, state_seq_or_seqs)` (if `fit!` is implemented) -""" -abstract type AbstractMarkovChain end - -""" - AbstractMC - -Alias for the type `AbstractMarkovChain`. 
-""" -const AbstractMC = AbstractMarkovChain - -@inline DensityInterface.DensityKind(::AbstractMC) = HasDensity() - -@required AbstractMC begin - Base.length(::AbstractMC) - initial_distribution(::AbstractMC) - transition_matrix(::AbstractMC) -end - -""" - length(mc::AbstractMarkovChain) - -Return the number of states of `model`. -""" -Base.length - -""" - initial_distribution(mc::AbstractMarkovChain) - -Return the initial state probabilities of `mc`. -""" -function initial_distribution end - -""" - transition_matrix(mc::AbstractMarkovChain) - -Return the state transition probabilities of `mc`. -""" -function transition_matrix end - -""" - rand([rng=default_rng(),] mc::AbstractMarkovChain, T) - -Simulate `mc` for `T` time steps with a specified `rng`. -""" -Base.rand(mc::AbstractMarkovChain, T::Integer) = rand(default_rng(), mc, T) - -function StatsAPI.fit!(mc::AbstractMC, state_seq::Vector{<:Integer}) - return fit!(mc, [state_seq]) -end - -function StatsAPI.fit!(mc::AbstractMC, state_seqs::Vector{<:Vector{<:Integer}}) - N = length(mc) - init_count = zeros(Int, N) - trans_count = zeros(Int, N, N) - for state_seq in state_seqs - init_count[first(state_seq)] += 1 - for t in 1:(length(state_seq) - 1) - trans_count[state_seq[t], state_seq[t + 1]] += 1 - end - end - return fit!(mc, init_count, trans_count) -end - -""" - fit(mc, state_seq_or_seqs) - -Fit a Markov chain of the same type as `mc` to one or several state sequence(s). - -Beware that `mc` must be an actual object of type `MarkovChain`, and not the type itself as is usually done eg. in Distributions.jl. -""" -function StatsAPI.fit(mc::AbstractMC, state_seq_or_seqs) - mc_est = deepcopy(mc) - fit!(mc_est, state_seq_or_seqs) - return mc_est -end - -function Base.rand(rng::AbstractRNG, mc::AbstractMC, T::Integer) - init = initial_distribution(mc) - trans = transition_matrix(mc) - first_state = rand(rng, Categorical(init; check_args=false)) - state_seq = Vector{Int}(undef, T) - state_seq[1] = first_state - @views for t in 2:T - state_seq[t] = rand(rng, Categorical(trans[state_seq[t - 1], :]; check_args=false)) - end - return state_seq -end - -""" - logdensityof(mc, state_seq) - -Compute the loglikelihood of a single state sequence for a Markov chain. -""" -function DensityInterface.logdensityof(mc::AbstractMC, state_seq::Vector{<:Integer}) - init = initial_distribution(mc) - trans = transition_matrix(mc) - logL = log(init[first(state_seq)]) - for t in 1:(length(state_seq) - 1) - logL += log(trans[state_seq[t], state_seq[t + 1]]) - end - return logL -end diff --git a/src/types/mc.jl b/src/types/mc.jl deleted file mode 100644 index 9c35bd56..00000000 --- a/src/types/mc.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - MarkovChain <: AbstractMarkovChain - -Basic implementation of a discrete-state Markov chain. - -# Fields - -- `init::AbstractVector`: initial state probabilities -- `trans::AbstractMatrix`: state transition matrix -""" -struct MarkovChain{U<:AbstractVector,M<:AbstractMatrix} <: AbstractMarkovChain - init::U - trans::M - - function MarkovChain(init::U, trans::M) where {U<:AbstractVector,M<:AbstractMatrix} - mc = new{U,M}(init, trans) - check_mc(mc) - return mc - end -end - -""" - MC - -Alias for the type `MarkovChain`. -""" -const MC = MarkovChain - -Base.length(mc::MC) = length(mc.init) -initial_distribution(mc::MC) = mc.init -transition_matrix(mc::MC) = mc.trans - -""" - fit!(mc::MC, init_count, trans_count) - -Update `mc` in-place based on information generated from a state sequence. 
-""" -function StatsAPI.fit!(mc::MC, init_count, trans_count) - mc.init .= init_count - sum_to_one!(mc.init) - mc.trans .= trans_count - foreach(sum_to_one!, eachrow(mc.trans)) - return nothing -end diff --git a/src/types/permuted_hmm.jl b/src/types/permuted_hmm.jl new file mode 100644 index 00000000..63f8033a --- /dev/null +++ b/src/types/permuted_hmm.jl @@ -0,0 +1,28 @@ +""" + PermutedHMM{H<:AbstractHMM} + +Wrapper around an `AbstractHMM` that permutes its states. + +This is computationally inefficient and mostly useful for evaluation. + +# Fields + +- `hmm:H`: the old HMM +- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old. +""" +struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM + hmm::H + perm::Vector{Int} +end + +Base.length(p::PermutedHMM) = length(p.hmm) + +initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm] + +function transition_matrix(p::PermutedHMM) + return transition_matrix(p.hmm)[p.perm, :][:, p.perm] +end + +function obs_distribution(p::PermutedHMM, i::Integer) + return obs_distribution(p.hmm, p.perm[i]) +end diff --git a/src/utils/check.jl b/src/utils/check.jl index c72bad29..47487d0e 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -48,24 +48,15 @@ function check_dists(dists) end end -function check_mc(mc::MarkovChain) - init = initial_distribution(mc) - trans = transition_matrix(mc) - if !(length(init) == size(trans, 1) == size(trans, 2)) - throw(DimensionMismatch("Incoherent sizes")) - end - check_prob_vec(init) - check_trans_mat(trans) - return nothing -end - function check_hmm(hmm::AbstractHMM) - mc = MarkovChain(hmm) + init = initial_distribution(hmm) + trans = transition_matrix(hmm) dists = [obs_distribution(hmm, i) for i in 1:length(hmm)] - if length(mc) != length(dists) + if !all(==(length(hmm)), [length(init), size(trans, 1), size(trans, 2), length(dists)]) throw(DimensionMismatch("Incoherent sizes")) end - check_mc(mc) + check_prob_vec(init) + check_trans_mat(trans) check_dists(dists) return nothing end diff --git a/test/allocations.jl b/test/allocations.jl index 5e75177f..caf6fcba 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -11,28 +11,29 @@ function test_allocations(hmm; T) @unpack state_seq, obs_seq = rand(hmm, T) ## Forward - logb = HMMs.loglikelihoods_vec(hmm, obs_seq[1]) + logb = HiddenMarkovModels.loglikelihoods_vec(hmm, obs_seq[1]) αₜ = zeros(N) αₜ₊₁ = zeros(N) - allocs = @ballocated HMMs.forward!($αₜ, $αₜ₊₁, $logb, $p, $A, $hmm, $obs_seq) + allocs = @ballocated HiddenMarkovModels.forward!( + $αₜ, $αₜ₊₁, $logb, $p, $A, $hmm, $obs_seq + ) @test allocs == 0 ## Viterbi - logb = HMMs.loglikelihoods_vec(hmm, obs_seq[1]) + logb = HiddenMarkovModels.loglikelihoods_vec(hmm, obs_seq[1]) δₜ = zeros(N) δₜ₋₁ = zeros(N) δA_tmp = zeros(N) ψ = zeros(Int, N, T) q = zeros(Int, T) - allocs = @ballocated HMMs.viterbi!( + allocs = @ballocated HiddenMarkovModels.viterbi!( $q, $δₜ, $δₜ₋₁, $δA_tmp, $ψ, $logb, $p, $A, $hmm, $obs_seq ) @test allocs == 0 ## Forward-backward - logB = HMMs.loglikelihoods(hmm, obs_seq) - fb = HMMs.initialize_forward_backward(p, A, logB) - allocs = @ballocated HMMs.forward_backward!($fb, $p, $A, $logB) + fb = HiddenMarkovModels.initialize_forward_backward(hmm, obs_seq) + allocs = @ballocated HiddenMarkovModels.forward_backward!($fb, $hmm, $obs_seq) @test allocs == 0 end diff --git a/test/dna.jl b/test/dna.jl index b6956ed5..6ed79f78 100644 --- a/test/dna.jl +++ b/test/dna.jl @@ -1,7 +1,6 @@ using DensityInterface using HiddenMarkovModels 
using Random: AbstractRNG -using RequiredInterfaces: check_interface_implemented using SimpleUnPack using StatsAPI using Test @@ -74,11 +73,11 @@ get_state(coding, nucleotide) = 4(coding - 1) + nucleotide Base.length(dchmm::DNACodingHMM) = 8 -function HMMs.initial_distribution(dchmm::DNACodingHMM) +function HiddenMarkovModels.initial_distribution(dchmm::DNACodingHMM) return repeat(dchmm.cod_init; inner=4) .* repeat(dchmm.nuc_init; outer=2) end -function HMMs.transition_matrix(dchmm::DNACodingHMM) +function HiddenMarkovModels.transition_matrix(dchmm::DNACodingHMM) @unpack cod_trans, nuc_trans = dchmm A = Matrix{Float64}(undef, 8, 8) for c1 in 1:2, n1 in 1:4, c2 in 1:2, n2 in 1:4 @@ -88,7 +87,7 @@ function HMMs.transition_matrix(dchmm::DNACodingHMM) return A end -function HMMs.obs_distribution(::DNACodingHMM, s::Integer) +function HiddenMarkovModels.obs_distribution(::DNACodingHMM, s::Integer) return Dirac(get_nucleotide(s)) end @@ -102,8 +101,8 @@ function StatsAPI.fit!( for n in 1:4 dchmm.nuc_init[n] = sum(init_count[get_state(c, n)] for c in 1:2) end - HMMs.sum_to_one!(dchmm.cod_init) - HMMs.sum_to_one!(dchmm.nuc_init) + HiddenMarkovModels.sum_to_one!(dchmm.cod_init) + HiddenMarkovModels.sum_to_one!(dchmm.nuc_init) # Transitions for c1 in 1:2, c2 in 1:2 @@ -116,15 +115,13 @@ function StatsAPI.fit!( trans_count[get_state(c1, n1), get_state(c2, n2)] for c2 in 1:2 ) end - foreach(HMMs.sum_to_one!, eachrow(dchmm.cod_trans)) - foreach(HMMs.sum_to_one!, eachrow(@view dchmm.nuc_trans[1, :, :])) - foreach(HMMs.sum_to_one!, eachrow(@view dchmm.nuc_trans[2, :, :])) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(dchmm.cod_trans)) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[1, :, :])) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[2, :, :])) return nothing end -@test check_interface_implemented(AbstractHMM, DNACodingHMM) - dchmm = DNACodingHMM(; cod_init=rand_prob_vec(2), nuc_init=rand_prob_vec(4), @@ -136,14 +133,10 @@ dchmm = DNACodingHMM(; most_likely_coding_seq = get_coding.(viterbi(dchmm, obs_seq)); -mc = MarkovChain(ones(4) ./ 4, rand_trans_mat(4)) -fit!(mc, obs_seq) - dchmm_init = DNACodingHMM(; cod_init=rand(2), nuc_init=rand(4), cod_trans=rand_trans_mat(2), - # using transition_matrix(mc) as initialization below seems to be worse nuc_trans=permutedims(cat(rand_trans_mat(4), rand_trans_mat(4); dims=3), (3, 1, 2)), ); diff --git a/test/doctests.jl b/test/doctests.jl deleted file mode 100644 index 15594db4..00000000 --- a/test/doctests.jl +++ /dev/null @@ -1,4 +0,0 @@ -using Documenter -using HiddenMarkovModels - -doctest(HiddenMarkovModels) diff --git a/test/formatting.jl b/test/formatting.jl deleted file mode 100644 index b3fe015d..00000000 --- a/test/formatting.jl +++ /dev/null @@ -1,5 +0,0 @@ -using HiddenMarkovModels -using JuliaFormatter: format -using Test - -@test format(HiddenMarkovModels; verbose=false, overwrite=false) diff --git a/test/interface.jl b/test/interface.jl deleted file mode 100644 index 0681ca94..00000000 --- a/test/interface.jl +++ /dev/null @@ -1,17 +0,0 @@ -using HiddenMarkovModels -using RequiredInterfaces: check_interface_implemented -using Suppressor -using Test - -struct Empty end - -@test check_interface_implemented(AbstractMC, MarkovChain) -@test check_interface_implemented(AbstractHMM, HMM) -@test check_interface_implemented(AbstractMC, HMM) -@test check_interface_implemented(AbstractHMM, PermutedHMM) - -@suppress begin - @test check_interface_implemented(AbstractMC, Empty) != true - @test 
check_interface_implemented(AbstractHMM, Empty) != true - @test check_interface_implemented(AbstractHMM, MarkovChain) != true -end diff --git a/test/linting.jl b/test/linting.jl deleted file mode 100644 index aaa7986a..00000000 --- a/test/linting.jl +++ /dev/null @@ -1,4 +0,0 @@ -using HiddenMarkovModels -using JET - -JET.test_package(HiddenMarkovModels; target_defined_modules=true) diff --git a/test/mc.jl b/test/mc.jl deleted file mode 100644 index 2773e087..00000000 --- a/test/mc.jl +++ /dev/null @@ -1,22 +0,0 @@ -using HiddenMarkovModels -using Test - -N = 5 -T = 100 - -p = rand_prob_vec(N) -p_rand = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_rand = rand_trans_mat(N) - -mc = MC(p, A) -mc_rand = MC(p_rand, A_rand) - -state_seq = rand(mc, T) - -mc_est = fit(mc_rand, state_seq) - -@test logdensityof(mc_est, state_seq) > - logdensityof(mc, state_seq) > - logdensityof(mc_rand, state_seq) diff --git a/test/quality.jl b/test/quality.jl deleted file mode 100644 index 0a6a547a..00000000 --- a/test/quality.jl +++ /dev/null @@ -1,4 +0,0 @@ -using Aqua: test_all -using HiddenMarkovModels - -test_all(HiddenMarkovModels; ambiguities=false) diff --git a/test/runtests.jl b/test/runtests.jl index 31c4c16a..6635dd91 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,22 @@ +using Aqua: Aqua +using Documenter: Documenter +using HiddenMarkovModels +using JuliaFormatter: JuliaFormatter +using JET: JET using Test @testset verbose = true "HiddenMarkovModels.jl" begin @testset "Code formatting" begin - include("formatting.jl") + @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) end if VERSION >= v"1.9" @testset "Code quality" begin - include("quality.jl") + Aqua.test_all(HiddenMarkovModels; ambiguities=false) end @testset "Code linting" begin - include("linting.jl") + JET.test_package(HiddenMarkovModels; target_defined_modules=true) end @testset verbose = true "Type stability" begin @@ -24,11 +29,7 @@ using Test end @testset "Interface" begin - include("interface.jl") - end - - @testset "Markov chain" begin - include("mc.jl") + nothing end @testset verbose = true "Correctness" begin @@ -60,6 +61,6 @@ using Test end @testset "Doctests" begin - include("doctests.jl") + Documenter.doctest(HiddenMarkovModels) end end From cf7640252e4df68d23352d6f7906b30503937bbf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 3 Nov 2023 13:26:03 +0100 Subject: [PATCH 02/14] More changes --- Project.toml | 1 + docs/src/api.md | 4 +- docs/src/tuto_custom.md | 2 +- ext/HiddenMarkovModelsChainRulesCoreExt.jl | 2 +- ext/HiddenMarkovModelsHMMBaseExt.jl | 2 +- src/HiddenMarkovModels.jl | 10 +-- src/inference/baum_welch.jl | 36 ++++++++--- src/inference/forward.jl | 56 ++++++++++++---- src/inference/forward_backward.jl | 52 ++++++++------- src/inference/loglikelihoods.jl | 37 ----------- src/inference/sufficient_stats.jl | 40 ------------ src/inference/viterbi.jl | 74 +++++++++++++++------- src/types/abstract_hmm.jl | 47 ++++++++------ src/types/hmm.jl | 29 +++++---- src/types/permuted_hmm.jl | 13 ++-- src/utils/check.jl | 25 +++++--- src/utils/lightdiagnormal.jl | 9 ++- test/allocations.jl | 2 +- test/correctness.jl | 2 +- test/dna.jl | 2 +- test/permuted.jl | 2 +- 21 files changed, 243 insertions(+), 204 deletions(-) delete mode 100644 src/inference/loglikelihoods.jl delete mode 100644 src/inference/sufficient_stats.jl diff --git a/Project.toml b/Project.toml index 62ccac7a..41874a56 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 
@@ version = "0.4.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/docs/src/api.md b/docs/src/api.md index 1376a27e..3488f111 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,7 +19,7 @@ HMM ```@docs rand length -initial_distribution +initialization transition_matrix obs_distribution ``` @@ -60,7 +60,7 @@ HiddenMarkovModels.LightDiagNormal ### Models and simulations -- `p` or `init`: initial_distribution (vector of state probabilities) +- `p` or `init`: initialization (vector of state probabilities) - `A` or `trans`: transition_matrix (matrix of transition probabilities) - `dists`: observation distribution (vector of `rand`-able and `logdensityof`-able objects) - `state_seq`: a sequence of states (vector of integers) diff --git a/docs/src/tuto_custom.md b/docs/src/tuto_custom.md index 61f70416..a4a67b92 100644 --- a/docs/src/tuto_custom.md +++ b/docs/src/tuto_custom.md @@ -61,7 +61,7 @@ The interface is only different as far as the initialization is concerned. ```@example tuto Base.length(hmm::EquilibriumHMM) = length(hmm.dists) -HiddenMarkovModels.initial_distribution(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans) # this is new +HiddenMarkovModels.initialization(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans) # this is new HiddenMarkovModels.transition_matrix(hmm::EquilibriumHMM) = hmm.trans HiddenMarkovModels.obs_distribution(hmm::EquilibriumHMM, i::Integer) = hmm.dists[i] ``` diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl index 134a0ab7..0aab915c 100644 --- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl +++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl @@ -7,7 +7,7 @@ using HiddenMarkovModels using SimpleUnPack function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq) - p = initial_distribution(hmm) + p = initialization(hmm) A = transition_matrix(hmm) logB = HiddenMarkovModels.loglikelihoods(hmm, obs_seq) return p, A, logB diff --git a/ext/HiddenMarkovModelsHMMBaseExt.jl b/ext/HiddenMarkovModelsHMMBaseExt.jl index 397ba3bc..ea559a32 100644 --- a/ext/HiddenMarkovModelsHMMBaseExt.jl +++ b/ext/HiddenMarkovModelsHMMBaseExt.jl @@ -11,7 +11,7 @@ function HiddenMarkovModels.HMM(hmm_base::HMMBase.HMM) end function HMMBase.HMM(hmm::HiddenMarkovModels.HMM) - a = deepcopy(initial_distribution(hmm)) + a = deepcopy(initialization(hmm)) A = deepcopy(transition_matrix(hmm)) B = deepcopy(obs_distribution.(Ref(hmm), 1:length(hmm))) return HMMBase.HMM(a, A, B) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 9babda50..ab71cfb1 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -3,7 +3,9 @@ A Julia package for HMM modeling, simulation, inference and learning. -The alias `HMMs` is exported for the package name. +# Exports + +$(EXPORTS) """ module HiddenMarkovModels @@ -17,6 +19,7 @@ using Distributions: UnivariateDistribution, MultivariateDistribution, MatrixDistribution +using DocStringExtensions using LinearAlgebra: Diagonal, dot, mul! 
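# Layout used by the Baum-Welch restructuring above (a sketch, not patch
# content): the per-sequence posteriors are stacked side by side, in the
# same order as the concatenated observations, i.e. equivalent to
#
#     state_marginals_concat == hcat((fb.γ for fb in fbs)...)  # size (N, ΣTₖ)
#
# so that `fit!` can reweight every observation in a single pass.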
using PrecompileTools: @compile_workload, @setup_workload using Random: Random, AbstractRNG, default_rng @@ -24,10 +27,10 @@ using Requires: @require using SimpleUnPack: @unpack using StatsAPI: StatsAPI, fit, fit! -export AbstractHiddenMarkovModel, AbstractHMM, PermutedHMM +export AbstractHiddenMarkovModel, AbstractHMM export HiddenMarkovModel, HMM export rand_prob_vec, rand_trans_mat -export initial_distribution, transition_matrix, obs_distribution +export initialization, transition_matrix, obs_distribution export logdensityof, viterbi, forward, forward_backward, baum_welch export fit, fit! @@ -45,7 +48,6 @@ include("inference/loglikelihoods.jl") include("inference/forward.jl") include("inference/viterbi.jl") include("inference/forward_backward.jl") -include("inference/sufficient_stats.jl") include("inference/baum_welch.jl") if !isdefined(Base, :get_extension) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 1cc2e1d1..45b4ff9c 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -1,16 +1,37 @@ +function initialize_state_marginals(fbs::Vector{ForwardBackwardStorage{R}}) where {R} + N = length(first(fbs)) + T_total = sum(duration, fbs) + state_marginals_concat = Matrix{R}(undef, N, T_total) + return state_marginals_concat +end + +function update_state_marginals!( + state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}} +) where {R} + T = 1 + for k in eachindex(fbs) + Tk = duration(fbs[k]) + @views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ + T += Tk + end + return nothing +end + function baum_welch!( hmm::AbstractHMM, obs_seqs; atol, max_iterations, check_loglikelihood_increasing ) # Pre-allocate nearly all necessary memory - fb = initialize_forward_backward(hmm, obs_seqs[1]) - fbs = Vector{typeof(fb)}(undef, length(obs_seqs)) + fb1 = initialize_forward_backward(hmm, obs_seqs[1]) + fbs = Vector{typeof(fb1)}(undef, length(obs_seqs)) + fbs[1] = fb1 @threads for k in eachindex(obs_seqs, fbs) - fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) + if k > 2 + fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) + end end - init_count, trans_count = initialize_states_stats(fbs) - state_marginals_concat = initialize_observations_stats(fbs) obs_seqs_concat = reduce(vcat, obs_seqs) + state_marginals_concat = initialize_state_marginals(fbs) logL_evolution = eltype(fbs[1])[] for iteration in 1:max_iterations @@ -18,13 +39,12 @@ function baum_welch!( @threads for k in eachindex(obs_seqs, fbs) forward_backward!(fbs[k], hmm, obs_seqs[k]) end + update_state_marginals!(state_marginals_concat, fbs) logL = loglikelihood(fbs) push!(logL_evolution, logL) # M step - update_states_stats!(init_count, trans_count, fbs) - update_observations_stats!(state_marginals_concat, fbs) - fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat) + fit!(hmm, fbs, obs_seqs_concat, state_marginals_concat) # Stopping criterion if iteration > 1 diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 3815c77d..1c23fd09 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -1,13 +1,50 @@ -function forward!(αₜ, αₜ₊₁, logb, p, A, hmm::AbstractHMM, obs_seq) +""" +$(TYPEDEF) + +Store forward quantities with element type `R`. + +# Fields + +Let `X` denote the vector of hidden states and `Y` denote the vector of observations. 
+ +$(TYPEDFIELDS) +""" +struct ForwardStorage{R} + "vector of observation loglikelihoods `logb[i]`" + logb::Vector{R} + "scaled forward variables `α[t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" + αₜ::Vector{R} + "scaled forward variables `α[t+1]`" + αₜ₊₁::Vector{R} +end + +function initialize_forward(hmm::AbstractHMM, obs_seq) + N = length(hmm) + p = initialization(hmm) + A = transition_matrix(hmm) + d = obs_distributions(hmm) + logb = logdensityof.(d, Ref(obs_seq[1])) + + R = promote_type(eltype(p), eltype(A), eltype(logb)) + αₜ = Vector{R}(undef, N) + αₜ₊₁ = Vector{R}(undef, N) + f = ForwardStorage(logb, αₜ, αₜ₊₁) + return f +end + +function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq) + @unpack logb, αₜ, αₜ₊₁ = f + p = initialization(hmm) + A = transition_matrix(hmm) T = length(obs_seq) - loglikelihoods_vec!(logb, hmm, obs_seq[1]) + update_loglikelihoods!(logb, hmm, obs_seq[1]) logm = maximum(logb) αₜ .= p .* exp.(logb .- logm) c = inv(sum(αₜ)) αₜ .*= c logL = -log(c) + logm for t in 1:(T - 1) - loglikelihoods_vec!(logb, hmm, obs_seq[t + 1]) + update_loglikelihoods!(logb, hmm, obs_seq[t + 1]) logm = maximum(logb) mul!(αₜ₊₁, A', αₜ) αₜ₊₁ .*= exp.(logb .- logm) @@ -30,16 +67,9 @@ Return a tuple `(α, logL)` where - `α[i]` is the posterior probability of state `i` at the end of the sequence. """ function forward(hmm::AbstractHMM, obs_seq) - N = length(hmm) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - logb = loglikelihoods_vec(hmm, obs_seq[1]) - - R = promote_type(eltype(p), eltype(A), eltype(logb)) - αₜ = Vector{R}(undef, N) - αₜ₊₁ = Vector{R}(undef, N) - logL = forward!(αₜ, αₜ₊₁, logb, p, A, hmm, obs_seq) - return αₜ, logL + f = initialize_forward(hmm, obs_seq) + logL = forward!(f, hmm, obs_seq) + return f.αₜ, logL end """ diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index b23b9ed7..32a18ba1 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -1,34 +1,34 @@ """ - ForwardBackwardStorage{R} +$(TYPEDEF) Store forward-backward quantities with element type `R`. # Fields -Let `X` denote the vector of hidden states and `Y` denote the vector of observations. The following fields are part of the API: +Let `X` denote the vector of hidden states and `Y` denote the vector of observations. -- `γ::Matrix{R}`: posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])` -- `ξ::Array{R,3}`: posterior two-state marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])` +$(TYPEDFIELDS) -The following fields are internals and subject to change: - -- `α::Matrix{R}`: scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`) -- `β::Matrix{R}`: scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`) -- `c::Vector{R}`: forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])` -- `logB::Matrix{R}`: observation loglikelihoods `logB[i, t]` -- `logm::Vector{R}`: maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])` -- `B̃::Matrix{R}`: numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])` -- `B̃β::Matrix{R}`: numerically stabilized product `B̃β[i,t] = B̃[i,t] * β[i,t]` +Only the `γ` and `ξ` fields are part of the public API. 
""" struct ForwardBackwardStorage{R} + "scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" α::Matrix{R} + "scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" β::Matrix{R} + "posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" γ::Matrix{R} + "posterior two-state marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" ξ::Array{R,3} + "forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])`" c::Vector{R} + "observation loglikelihoods `logB[i, t]`" logB::Matrix{R} + "maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])`" logm::Vector{R} + "numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])`" B̃::Matrix{R} + "numerically stabilized product `B̃β[i,t] = B̃[i,t] * β[i,t]`" B̃β::Matrix{R} end @@ -50,7 +50,7 @@ function loglikelihood(fbs::Vector{ForwardBackwardStorage{R}}) where {R} end function initialize_forward_backward(hmm::AbstractHMM, obs_seq) - p = initial_distribution(hmm) + p = initialization(hmm) A = transition_matrix(hmm) testval = logdensityof(obs_distribution(hmm, 1), obs_seq[1]) R = promote_type(eltype(p), eltype(A), typeof(testval)) @@ -65,18 +65,23 @@ function initialize_forward_backward(hmm::AbstractHMM, obs_seq) logm = Vector{R}(undef, T) B̃ = Matrix{R}(undef, N, T) B̃β = Matrix{R}(undef, N, T) + return ForwardBackwardStorage(α, β, γ, ξ, c, logB, logm, B̃, B̃β) end -function scale_likelihoods!(fb::ForwardBackwardStorage) +function update_likelihoods!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) @unpack logB, logm, B̃ = fb + d = obs_distributions(hmm) + for (logb, obs) in zip(eachcol(logB), obs_seq) + logb .= logdensityof(d, Ref(obs)) + end maximum!(logm', logB) B̃ .= exp.(logB .- logm') return nothing end function forward!(fb::ForwardBackwardStorage, hmm::AbstractHMM) - p = initial_distribution(hmm) + p = initialization(hmm) A = transition_matrix(hmm) @unpack α, c, B̃ = fb T = size(α, 2) @@ -124,8 +129,7 @@ function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM) end function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) - loglikelihoods!(fb.logB, hmm, obs_seq) - scale_likelihoods!(fb) + update_likelihoods!(fb, hmm, obs_seq) forward!(fb, hmm) backward!(fb, hmm) marginals!(fb, hmm) @@ -159,11 +163,13 @@ function forward_backward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - fb = forward_backward(hmm, first(obs_seqs)) - fbs = Vector{typeof(fb)}(undef, nb_seqs) - fbs[1] = fb - @threads for k in 2:nb_seqs - fbs[k] = forward_backward(hmm, obs_seqs[k]) + fb1 = forward_backward(hmm, first(obs_seqs)) + fbs = Vector{typeof(fb1)}(undef, nb_seqs) + fbs[1] = fb1 + @threads for k in eachindex(obs_seqs, fbs) + if k > 2 + fbs[k] = forward_backward(hmm, obs_seqs[k]) + end end return fbs end diff --git a/src/inference/loglikelihoods.jl b/src/inference/loglikelihoods.jl deleted file mode 100644 index 2c09b764..00000000 --- a/src/inference/loglikelihoods.jl +++ /dev/null @@ -1,37 +0,0 @@ -## Vector - -function loglikelihoods_vec!(logb, hmm::AbstractHMM, obs) - for i in 1:length(hmm) - logb[i] = logdensityof(obs_distribution(hmm, i), obs) - end - check_no_nan(logb) - check_no_inf(logb) - return nothing -end - -function loglikelihoods_vec(hmm::AbstractHMM, obs) - testval = logdensityof(obs_distribution(hmm, 1), obs) - logb = Vector{typeof(testval)}(undef, length(hmm)) - loglikelihoods_vec!(logb, hmm, obs) - 
return logb -end - -## Matrix - -function loglikelihoods!(logB, hmm::AbstractHMM, obs_seq) - T, N = length(obs_seq), length(hmm) - for t in 1:T, i in 1:N - logB[i, t] = logdensityof(obs_distribution(hmm, i), obs_seq[t]) - end - check_no_nan(logB) - check_no_inf(logB) - return nothing -end - -function loglikelihoods(hmm::AbstractHMM, obs_seq) - testval = logdensityof(obs_distribution(hmm, 1), obs_seq[1]) - T, N = length(obs_seq), length(hmm) - logB = Matrix{typeof(testval)}(undef, N, T) - loglikelihoods!(logB, hmm, obs_seq) - return logB -end diff --git a/src/inference/sufficient_stats.jl b/src/inference/sufficient_stats.jl deleted file mode 100644 index 9bf55c58..00000000 --- a/src/inference/sufficient_stats.jl +++ /dev/null @@ -1,40 +0,0 @@ - -function initialize_states_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - N = length(first(fbs)) - init_count = Vector{R}(undef, N) - trans_count = Matrix{R}(undef, N, N) - return init_count, trans_count -end - -function initialize_observations_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - N = length(first(fbs)) - T_total = sum(duration, fbs) - state_marginals_concat = Matrix{R}(undef, N, T_total) - return state_marginals_concat -end - -function update_states_stats!( - init_count, trans_count, fbs::Vector{ForwardBackwardStorage{R}} -) where {R} - init_count .= zero(R) - for k in eachindex(fbs) - @views init_count .+= fbs[k].γ[:, 1] - end - trans_count .= zero(R) - for k in eachindex(fbs) - sum!(trans_count, fbs[k].ξ; init=false) - end - return nothing -end - -function update_observations_stats!( - state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}} -) where {R} - T = 1 - for k in eachindex(fbs) - Tk = duration(fbs[k]) - @views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ - T += Tk - end - return nothing -end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 06b744cf..43d049be 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -1,22 +1,63 @@ -function viterbi!(q, δₜ, δₜ₋₁, δA_tmp, ψ, logb, p, A, hmm::AbstractHMM, obs_seq) +""" +$(TYPEDEF) + +Store Viterbi quantities with element type `R`. + +# Fields + +Let `X` denote the vector of hidden states and `Y` denote the vector of observations. 
+ +$(TYPEDFIELDS) +""" +struct ViterbiStorage{R} + logb::Vector{R} + δₜ::Vector{R} + δₜ₋₁::Vector{R} + δₜ₋₁Aⱼ::Vector{R} + ψ::Matrix{Int} + q::Vector{Int} +end + +function initialize_viterbi(hmm::AbstractHMM, obs_seq) + T, N = length(obs_seq), length(hmm) + p = initialization(hmm) + A = transition_matrix(hmm) + d = obs_distributions(hmm) + logb = logdensityof.(d, Ref(obs_seq[1])) + + R = promote_type(eltype(p), eltype(A), eltype(logb)) + δₜ = Vector{R}(undef, N) + δₜ₋₁ = Vector{R}(undef, N) + δₜ₋₁Aⱼ = Vector{R}(undef, N) + ψ = Matrix{Int}(undef, N, T) + q = Vector{Int}(undef, T) + return ViterbiStorage(logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q) +end + +function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq) + @unpack logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q = v + p = initialization(hmm) + A = transition_matrix(hmm) + d = obs_distributions(hmm) N, T = length(hmm), length(obs_seq) - loglikelihoods_vec!(logb, hmm, obs_seq[1]) + + logb .= logdensityof(d, Ref(obs_seq[1])) logm = maximum(logb) δₜ .= p .* exp.(logb .- logm) δₜ₋₁ .= δₜ @views ψ[:, 1] .= zero(eltype(ψ)) for t in 2:T - loglikelihoods_vec!(logb, hmm, obs_seq[t]) + logb .= logdensityof(d, Ref(obs_seq[t])) logm = maximum(logb) for j in 1:N - @views δA_tmp .= δₜ₋₁ .* A[:, j] - i_max = argmax(δA_tmp) + @views δₜ₋₁Aⱼ .= δₜ₋₁ .* A[:, j] + i_max = argmax(δₜ₋₁Aⱼ) ψ[j, t] = i_max - δₜ[j] = δA_tmp[i_max] * exp(logb[j] - logm) + δₜ[j] = δₜ₋₁Aⱼ[i_max] * exp(logb[j] - logm) end δₜ₋₁ .= δₜ end - @views q[T] = argmax(δₜ) + q[T] = argmax(δₜ) for t in (T - 1):-1:1 q[t] = ψ[q[t + 1], t + 1] end @@ -31,20 +72,9 @@ Apply the Viterbi algorithm to compute the most likely state sequence of an HMM. Return a vector of integers. """ function viterbi(hmm::AbstractHMM, obs_seq) - T, N = length(obs_seq), length(hmm) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - logb = loglikelihoods_vec(hmm, obs_seq[1]) - - R = promote_type(eltype(p), eltype(A), eltype(logb)) - δₜ = Vector{R}(undef, N) - δₜ₋₁ = Vector{R}(undef, N) - δA_tmp = Vector{R}(undef, N) - ψ = Matrix{Int}(undef, N, T) - q = Vector{Int}(undef, T) - - viterbi!(q, δₜ, δₜ₋₁, δA_tmp, ψ, logb, p, A, hmm, obs_seq) - return q + v = initialize_viterbi(hmm, obs_seq) + viterbi!(v, hmm, obs_seq) + return v.q end """ @@ -62,7 +92,7 @@ function viterbi(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end qs = Vector{Vector{Int}}(undef, nb_seqs) - @threads for k in 1:nb_seqs + @threads for k in eachindex(qs, obs_seqs) qs[k] = viterbi(hmm, obs_seqs[k]) end return qs diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 2ea35f5b..d80e8c8b 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -3,12 +3,12 @@ Abstract supertype for an HMM amenable to simulation, inference and learning. -# Required interface +# Interface -- `initial_distribution(hmm)` -- `transition_matrix(hmm)` -- `obs_distribution(hmm, i)` -- `fit!(hmm, init_count, trans_count, obs_seq, state_marginals)` (optional) +- [`initialization`](@ref) +- [`transition_matrix`](@ref) +- [`obs_distributions`](@ref) +- [`fit!](@ref) # Applicable methods @@ -33,53 +33,60 @@ const AbstractHMM = AbstractHiddenMarkovModel ## Interface """ - length(hmm::AbstractHMM) + length(hmm::AbstractHMM) Return the number of states of `hmm`. """ Base.length """ - initial_distribution(hmm::AbstractHMM) + initialization(hmm::AbstractHMM) -Return the initial state probabilities of `hmm`. +Return the vector of initial state probabilities for `hmm`. 
""" -function initial_distribution end +function initialization end """ transition_matrix(hmm::AbstractHMM) -Return the state transition probabilities of `hmm`. +Return the matrix of state transition probabilities for `hmm`. """ function transition_matrix end """ - obs_distribution(hmm::AbstractHMM, i) + obs_distributions(hmm::AbstractHMM) -Return the observation distribution of `hmm` associated with state `i`. +Return a vector of observation distributions for `hmm`. -The returned object `dist` must implement +Each element `dist` of this vector must implement - `rand(rng, dist)` -- `DensityInterface.logdensityof(dist, x)` +- `DensityInterface.logdensityof(dist, obs)` """ -function obs_distribution end +function obs_distributions end ## Sampling +""" + rand([rng,] hmm, T) + +Simulate `hmm` for `T` time steps. +""" function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) - init = initial_distribution(hmm) - trans = transition_matrix(hmm) - first_state = rand(rng, Categorical(init; check_args=false)) + p = initialization(hmm) + A = transition_matrix(hmm) + d = obs_distributions(hmm) + + first_state = rand(rng, Categorical(p; check_args=false)) state_seq = Vector{Int}(undef, T) state_seq[1] = first_state @views for t in 2:T - state_seq[t] = rand(rng, Categorical(trans[state_seq[t - 1], :]; check_args=false)) + state_seq[t] = rand(rng, Categorical(A[state_seq[t - 1], :]; check_args=false)) end first_obs = rand(rng, obs_distribution(hmm, first(state_seq))) obs_seq = Vector{typeof(first_obs)}(undef, T) obs_seq[1] = first_obs for t in 2:T - obs_seq[t] = rand(rng, obs_distribution(hmm, state_seq[t])) + obs_seq[t] = rand(rng, d[state_seq[t]]) end return (; state_seq=state_seq, obs_seq=obs_seq) end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 17845bb9..ddc0718c 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -1,18 +1,19 @@ """ - HiddenMarkovModel{D} <: AbstractHiddenMarkovModel +$(TYPEDEF) Basic implementation of an HMM. # Fields -- `init::AbstractVector`: initial state probabilities -- `trans::AbstractMatrix`: state transition matrix -- `dists::AbstractVector{D}`: observation distributions +$(TYPEDFIELDS) """ struct HiddenMarkovModel{D,U<:AbstractVector,M<:AbstractMatrix,V<:AbstractVector{D}} <: AbstractHMM + "initial state probabilities" init::U + "state transition matrix" trans::M + "observation distributions" dists::V function HiddenMarkovModel( @@ -36,22 +37,28 @@ function Base.copy(hmm::HMM) end Base.length(hmm::HMM) = length(hmm.init) -initial_distribution(hmm::HMM) = hmm.init +initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans -obs_distribution(hmm::HMM, i::Integer) = hmm.dists[i] +obs_distributions(hmm::HMM) = hmm.dists """ - fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals) + fit!(hmm::HMM, obs_seq, fbs, obs_seqs_concat, state_marginals_concat) Update `hmm` in-place based on information generated during forward-backward. 
""" -function StatsAPI.fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals) - hmm.init .= init_count +function StatsAPI.fit!(hmm::HMM, fbs, obs_seqs_concat, state_marginals_concat) + hmm.init .= zero(eltype(hmm.init)) + for k in eachindex(fbs) + @views hmm.init .+= fbs[k].γ[:, 1] + end sum_to_one!(hmm.init) - hmm.trans .= trans_count + hmm.trans .= zero(eltype(hmm.trans)) + for k in eachindex(fbs) + sum!(hmm.trans, fbs[k].ξ; init=false) + end foreach(sum_to_one!, eachrow(hmm.trans)) @views for i in eachindex(hmm.dists) - fit_element_from_sequence!(hmm.dists, i, obs_seq, state_marginals[i, :]) + fit_element_from_sequence!(hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :]) end check_hmm(hmm) return nothing diff --git a/src/types/permuted_hmm.jl b/src/types/permuted_hmm.jl index 63f8033a..4d1f53de 100644 --- a/src/types/permuted_hmm.jl +++ b/src/types/permuted_hmm.jl @@ -1,5 +1,5 @@ """ - PermutedHMM{H<:AbstractHMM} +$(TYPEDEF) Wrapper around an `AbstractHMM` that permutes its states. @@ -7,22 +7,23 @@ This is computationally inefficient and mostly useful for evaluation. # Fields -- `hmm:H`: the old HMM -- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old. +$(TYPEDFIELDS) """ struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM + "the old HMM" hmm::H + "a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old" perm::Vector{Int} end Base.length(p::PermutedHMM) = length(p.hmm) -initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm] +initialization(p::PermutedHMM) = initialization(p.hmm)[p.perm] function transition_matrix(p::PermutedHMM) return transition_matrix(p.hmm)[p.perm, :][:, p.perm] end -function obs_distribution(p::PermutedHMM, i::Integer) - return obs_distribution(p.hmm, p.perm[i]) +function obs_distributions(p::PermutedHMM) + return obs_distributions(p.hmm)[p.perm] end diff --git a/src/utils/check.jl b/src/utils/check.jl index 47487d0e..76438d9f 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -40,23 +40,28 @@ function check_coherent_sizes(p::AbstractVector, A::AbstractMatrix) end end -function check_dists(dists) - for i in eachindex(dists) - if DensityKind(dists[i]) == NoDensity() +function check_dists(d) + for i in eachindex(d) + if DensityKind(d[i]) == NoDensity() throw(ArgumentError("Observation is not a density")) end end end +""" + check_hmm(hmm::AbstractHMM) + +Verify that `hmm` satisfies basic assumptions. +""" function check_hmm(hmm::AbstractHMM) - init = initial_distribution(hmm) - trans = transition_matrix(hmm) - dists = [obs_distribution(hmm, i) for i in 1:length(hmm)] - if !all(==(length(hmm)), [length(init), size(trans, 1), size(trans, 2), length(dists)]) + p = initialization(hmm) + A = transition_matrix(hmm) + d = [obs_distribution(hmm, i) for i in 1:length(hmm)] + if !all(==(length(hmm)), [length(p), size(A, 1), size(A, 2), length(d)]) throw(DimensionMismatch("Incoherent sizes")) end - check_prob_vec(init) - check_trans_mat(trans) - check_dists(dists) + check_prob_vec(p) + check_trans_mat(A) + check_dists(d) return nothing end diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 850f0089..0974b69a 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -1,15 +1,22 @@ """ - LightDiagNormal +$(TYPEDEF) An HMMs-compatible implementation of a multivariate normal distribution with diagonal covariance, enabling allocation-free estimation. 
This is not part of the public API and is expected to change. + +# Fields + +$(TYPEDFIELDS) """ struct LightDiagNormal{ T1,T2,T3,V1<:AbstractVector{T1},V2<:AbstractVector{T2},V3<:AbstractVector{T3} } + "vector of means" μ::V1 + "vector of standard deviations" σ::V2 + "vector of log standard deviations" logσ::V3 end diff --git a/test/allocations.jl b/test/allocations.jl index caf6fcba..c0f6e8a5 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -6,7 +6,7 @@ using SimpleUnPack using Test function test_allocations(hmm; T) - p = initial_distribution(hmm) + p = initialization(hmm) A = transition_matrix(hmm) @unpack state_seq, obs_seq = rand(hmm, T) diff --git a/test/correctness.jl b/test/correctness.jl index 1fd9fa7e..03185530 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -48,7 +48,7 @@ function test_correctness(hmm, hmm_init; T) @test isapprox( logL_evolution[(begin + 1):end], logL_evolution_base[begin:(end - 1)] ) - @test isapprox(initial_distribution(hmm_est), hmm_est_base.a) + @test isapprox(initialization(hmm_est), hmm_est_base.a) @test isapprox(transition_matrix(hmm_est), hmm_est_base.A) for (dist, dist_base) in zip(hmm.dists, hmm_base.B) diff --git a/test/dna.jl b/test/dna.jl index 6ed79f78..08932e8a 100644 --- a/test/dna.jl +++ b/test/dna.jl @@ -73,7 +73,7 @@ get_state(coding, nucleotide) = 4(coding - 1) + nucleotide Base.length(dchmm::DNACodingHMM) = 8 -function HiddenMarkovModels.initial_distribution(dchmm::DNACodingHMM) +function HiddenMarkovModels.initialization(dchmm::DNACodingHMM) return repeat(dchmm.cod_init; inner=4) .* repeat(dchmm.nuc_init; outer=2) end diff --git a/test/permuted.jl b/test/permuted.jl index b8056704..053e2441 100644 --- a/test/permuted.jl +++ b/test/permuted.jl @@ -10,7 +10,7 @@ hmm = HMM(p, A, dists) perm = [3, 1, 2] hmm_perm = PermutedHMM(hmm, perm) -p_perm = initial_distribution(hmm_perm) +p_perm = initialization(hmm_perm) A_perm = transition_matrix(hmm_perm) dists_perm = [obs_distribution(hmm_perm, i) for i in 1:3] From 75ca6f12a9ddd95f2cb1a9f55548e6f01534d774 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 3 Nov 2023 14:23:26 +0100 Subject: [PATCH 03/14] Correctness and type stability --- ext/HiddenMarkovModelsHMMBaseExt.jl | 2 +- src/HiddenMarkovModels.jl | 3 +-- src/inference/baum_welch.jl | 28 +++++++++++----------------- src/inference/forward.jl | 27 +++++++++++++++++---------- src/inference/forward_backward.jl | 11 ++++++++--- src/inference/viterbi.jl | 13 +++++++------ src/types/abstract_hmm.jl | 2 +- src/types/hmm.jl | 4 +++- src/utils/check.jl | 2 +- test/logarithmic.jl | 1 + test/permuted.jl | 2 +- 11 files changed, 52 insertions(+), 43 deletions(-) diff --git a/ext/HiddenMarkovModelsHMMBaseExt.jl b/ext/HiddenMarkovModelsHMMBaseExt.jl index ea559a32..9488107f 100644 --- a/ext/HiddenMarkovModelsHMMBaseExt.jl +++ b/ext/HiddenMarkovModelsHMMBaseExt.jl @@ -13,7 +13,7 @@ end function HMMBase.HMM(hmm::HiddenMarkovModels.HMM) a = deepcopy(initialization(hmm)) A = deepcopy(transition_matrix(hmm)) - B = deepcopy(obs_distribution.(Ref(hmm), 1:length(hmm))) + B = deepcopy(obs_distributions(hmm)) return HMMBase.HMM(a, A, B) end diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index ab71cfb1..88796d65 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -30,7 +30,7 @@ using StatsAPI: StatsAPI, fit, fit! 
export AbstractHiddenMarkovModel, AbstractHMM export HiddenMarkovModel, HMM export rand_prob_vec, rand_trans_mat -export initialization, transition_matrix, obs_distribution +export initialization, transition_matrix, obs_distributions export logdensityof, viterbi, forward, forward_backward, baum_welch export fit, fit! @@ -44,7 +44,6 @@ include("utils/transmat.jl") include("utils/fit.jl") include("utils/lightdiagnormal.jl") -include("inference/loglikelihoods.jl") include("inference/forward.jl") include("inference/viterbi.jl") include("inference/forward_backward.jl") diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 45b4ff9c..912f08cc 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -1,10 +1,3 @@ -function initialize_state_marginals(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - N = length(first(fbs)) - T_total = sum(duration, fbs) - state_marginals_concat = Matrix{R}(undef, N, T_total) - return state_marginals_concat -end - function update_state_marginals!( state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}} ) where {R} @@ -21,27 +14,28 @@ function baum_welch!( hmm::AbstractHMM, obs_seqs; atol, max_iterations, check_loglikelihood_increasing ) # Pre-allocate nearly all necessary memory - fb1 = initialize_forward_backward(hmm, obs_seqs[1]) + fb1 = forward_backward(hmm, obs_seqs[1]) fbs = Vector{typeof(fb1)}(undef, length(obs_seqs)) fbs[1] = fb1 @threads for k in eachindex(obs_seqs, fbs) - if k > 2 - fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) + if k > 1 + fbs[k] = forward_backward(hmm, obs_seqs[k]) end end obs_seqs_concat = reduce(vcat, obs_seqs) - state_marginals_concat = initialize_state_marginals(fbs) - logL_evolution = eltype(fbs[1])[] + state_marginals_concat = reduce(hcat, fb.γ for fb in fbs) + logL_evolution = [loglikelihood(fbs)] for iteration in 1:max_iterations # E step - @threads for k in eachindex(obs_seqs, fbs) - forward_backward!(fbs[k], hmm, obs_seqs[k]) + if iteration > 1 + @threads for k in eachindex(obs_seqs, fbs) + forward_backward!(fbs[k], hmm, obs_seqs[k]) + end + update_state_marginals!(state_marginals_concat, fbs) + push!(logL_evolution, loglikelihood(fbs)) end - update_state_marginals!(state_marginals_concat, fbs) - logL = loglikelihood(fbs) - push!(logL_evolution, logL) # M step fit!(hmm, fbs, obs_seqs_concat, state_marginals_concat) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 1c23fd09..f4f068ff 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -23,9 +23,10 @@ function initialize_forward(hmm::AbstractHMM, obs_seq) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - logb = logdensityof.(d, Ref(obs_seq[1])) + testval = logdensityof(d[1], obs_seq[1]) + R = promote_type(eltype(p), eltype(A), typeof(testval)) - R = promote_type(eltype(p), eltype(A), eltype(logb)) + logb = Vector{R}(undef, N) αₜ = Vector{R}(undef, N) αₜ₊₁ = Vector{R}(undef, N) f = ForwardStorage(logb, αₜ, αₜ₊₁) @@ -33,18 +34,20 @@ function initialize_forward(hmm::AbstractHMM, obs_seq) end function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq) - @unpack logb, αₜ, αₜ₊₁ = f + T = length(obs_seq) p = initialization(hmm) A = transition_matrix(hmm) - T = length(obs_seq) - update_loglikelihoods!(logb, hmm, obs_seq[1]) + d = obs_distributions(hmm) + @unpack logb, αₜ, αₜ₊₁ = f + + logb .= logdensityof.(d, Ref(obs_seq[1])) logm = maximum(logb) αₜ .= p .* exp.(logb .- logm) c = inv(sum(αₜ)) αₜ .*= c logL = -log(c) + logm for t in 1:(T - 1) - 
update_loglikelihoods!(logb, hmm, obs_seq[t + 1]) + logb .= logdensityof.(d, Ref(obs_seq[t + 1])) logm = maximum(logb) mul!(αₜ₊₁, A', αₜ) αₜ₊₁ .*= exp.(logb .- logm) @@ -92,8 +95,10 @@ function forward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) f1 = forward(hmm, first(obs_seqs)) fs = Vector{typeof(f1)}(undef, nb_seqs) fs[1] = f1 - @threads for k in 2:nb_seqs - fs[k] = forward(hmm, obs_seqs[k]) + @threads for k in eachindex(fs, obs_seqs) + if k > 1 + fs[k] = forward(hmm, obs_seqs[k]) + end end return fs end @@ -126,8 +131,10 @@ function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seqs, nb_seqs::Inte logL1 = logdensityof(hmm, first(obs_seqs)) logLs = Vector{typeof(logL1)}(undef, nb_seqs) logLs[1] = logL1 - @threads for k in 2:nb_seqs - logLs[k] = logdensityof(hmm, obs_seqs[k]) + @threads for k in eachindex(logLs, obs_seqs) + if k > 1 + logLs[k] = logdensityof(hmm, obs_seqs[k]) + end end return sum(logLs) end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 32a18ba1..80462d05 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -52,7 +52,8 @@ end function initialize_forward_backward(hmm::AbstractHMM, obs_seq) p = initialization(hmm) A = transition_matrix(hmm) - testval = logdensityof(obs_distribution(hmm, 1), obs_seq[1]) + d = obs_distributions(hmm) + testval = logdensityof(d[1], obs_seq[1]) R = promote_type(eltype(p), eltype(A), typeof(testval)) N, T = length(hmm), length(obs_seq) @@ -70,10 +71,11 @@ function initialize_forward_backward(hmm::AbstractHMM, obs_seq) end function update_likelihoods!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) - @unpack logB, logm, B̃ = fb d = obs_distributions(hmm) + @unpack logB, logm, B̃ = fb + for (logb, obs) in zip(eachcol(logB), obs_seq) - logb .= logdensityof(d, Ref(obs)) + logb .= logdensityof.(d, Ref(obs)) end maximum!(logm', logB) B̃ .= exp.(logB .- logm') @@ -85,6 +87,7 @@ function forward!(fb::ForwardBackwardStorage, hmm::AbstractHMM) A = transition_matrix(hmm) @unpack α, c, B̃ = fb T = size(α, 2) + @views begin α[:, 1] .= p .* B̃[:, 1] c[1] = inv(sum(α[:, 1])) @@ -104,6 +107,7 @@ function backward!(fb::ForwardBackwardStorage{R}, hmm::AbstractHMM) where {R} A = transition_matrix(hmm) @unpack β, c, B̃, B̃β = fb T = size(β, 2) + β[:, T] .= c[T] @views for t in (T - 1):-1:1 B̃β[:, t + 1] .= B̃[:, t + 1] .* β[:, t + 1] @@ -119,6 +123,7 @@ function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM) A = transition_matrix(hmm) @unpack α, β, c, B̃β, γ, ξ = fb N, T = size(γ) + γ .= α .* β ./ c' check_no_nan(γ) @views for t in 1:(T - 1) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 43d049be..30e9b910 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -23,9 +23,10 @@ function initialize_viterbi(hmm::AbstractHMM, obs_seq) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - logb = logdensityof.(d, Ref(obs_seq[1])) + testval = logdensityof(d[1], obs_seq[1]) + R = promote_type(eltype(p), eltype(A), typeof(testval)) - R = promote_type(eltype(p), eltype(A), eltype(logb)) + logb = Vector{R}(undef, N) δₜ = Vector{R}(undef, N) δₜ₋₁ = Vector{R}(undef, N) δₜ₋₁Aⱼ = Vector{R}(undef, N) @@ -35,19 +36,19 @@ function initialize_viterbi(hmm::AbstractHMM, obs_seq) end function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq) - @unpack logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q = v + N, T = length(hmm), length(obs_seq) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - N, T = 
length(hmm), length(obs_seq) + @unpack logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q = v - logb .= logdensityof(d, Ref(obs_seq[1])) + logb .= logdensityof.(d, Ref(obs_seq[1])) logm = maximum(logb) δₜ .= p .* exp.(logb .- logm) δₜ₋₁ .= δₜ @views ψ[:, 1] .= zero(eltype(ψ)) for t in 2:T - logb .= logdensityof(d, Ref(obs_seq[t])) + logb .= logdensityof.(d, Ref(obs_seq[t])) logm = maximum(logb) for j in 1:N @views δₜ₋₁Aⱼ .= δₜ₋₁ .* A[:, j] diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index d80e8c8b..bc06449a 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -82,7 +82,7 @@ function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) @views for t in 2:T state_seq[t] = rand(rng, Categorical(A[state_seq[t - 1], :]; check_args=false)) end - first_obs = rand(rng, obs_distribution(hmm, first(state_seq))) + first_obs = rand(rng, d[state_seq[1]]) obs_seq = Vector{typeof(first_obs)}(undef, T) obs_seq[1] = first_obs for t in 2:T diff --git a/src/types/hmm.jl b/src/types/hmm.jl index ddc0718c..51967fee 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -58,7 +58,9 @@ function StatsAPI.fit!(hmm::HMM, fbs, obs_seqs_concat, state_marginals_concat) end foreach(sum_to_one!, eachrow(hmm.trans)) @views for i in eachindex(hmm.dists) - fit_element_from_sequence!(hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :]) + fit_element_from_sequence!( + hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :] + ) end check_hmm(hmm) return nothing diff --git a/src/utils/check.jl b/src/utils/check.jl index 76438d9f..2497ea56 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -56,7 +56,7 @@ Verify that `hmm` satisfies basic assumptions. function check_hmm(hmm::AbstractHMM) p = initialization(hmm) A = transition_matrix(hmm) - d = [obs_distribution(hmm, i) for i in 1:length(hmm)] + d = obs_distributions(hmm) if !all(==(length(hmm)), [length(p), size(A, 1), size(A, 2), length(d)]) throw(DimensionMismatch("Incoherent sizes")) end diff --git a/test/logarithmic.jl b/test/logarithmic.jl index 4129e355..bfef42dc 100644 --- a/test/logarithmic.jl +++ b/test/logarithmic.jl @@ -1,5 +1,6 @@ using Distributions using HiddenMarkovModels +using HiddenMarkovModels: LightDiagNormal using LinearAlgebra using LogarithmicNumbers using SimpleUnPack diff --git a/test/permuted.jl b/test/permuted.jl index 053e2441..6f088910 100644 --- a/test/permuted.jl +++ b/test/permuted.jl @@ -12,7 +12,7 @@ perm = [3, 1, 2] hmm_perm = PermutedHMM(hmm, perm) p_perm = initialization(hmm_perm) A_perm = transition_matrix(hmm_perm) -dists_perm = [obs_distribution(hmm_perm, i) for i in 1:3] +dists_perm = obs_distributions(hmm_perm) for i in 1:3 @test p_perm[i] ≈ p[perm[i]] From d8e78e31b8677af4d166dd0f5973d4c69cbd5505 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 3 Nov 2023 14:42:59 +0100 Subject: [PATCH 04/14] Allocs --- .github/workflows/docs-benchmark.yml | 3 +-- Project.toml | 1 + docs/make.jl | 9 ++++++++- ext/HiddenMarkovModelsChainRulesCoreExt.jl | 3 ++- test/allocations.jl | 19 ++++--------------- 5 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.github/workflows/docs-benchmark.yml b/.github/workflows/docs-benchmark.yml index 8f654c37..07435238 100644 --- a/.github/workflows/docs-benchmark.yml +++ b/.github/workflows/docs-benchmark.yml @@ -26,8 +26,7 @@ jobs: using Pkg Pkg.instantiate() include("benchmark/run_benchmarks.jl") - include("benchmark/process_benchmarks.jl") - ENV["HMM_BENCHMARKS_DONE"] = true' + 
include("benchmark/process_benchmarks.jl")' - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index 41874a56..4c752a6e 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ HiddenMarkovModelsHMMBaseExt = "HMMBase" ChainRulesCore = "1.16" DensityInterface = "0.4" Distributions = "0.25" +DocStringExtensions = "0.9" LinearAlgebra = "1.6" PrecompileTools = "1.1" Random = "1.6" diff --git a/docs/make.jl b/docs/make.jl index e738e8bc..1fff0700 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,7 +22,14 @@ open(joinpath(joinpath(@__DIR__, "src"), "index.md"), "w") do io end end -alt_pages = if get(ENV, "HMM_BENCHMARKS_DONE", false) +benchmarks_done = ( + length(readdir(joinpath(@__DIR__, "src", "assets", "benchmark", "plots"))) > + 1 & # plots present + length(readdir(joinpath(@__DIR__, "src", "assets", "benchmark", "results"))) > + 1 # results present +) + +alt_pages = if benchmarks_done ["Features" => "alt_features.md", "Performance" => "alt_performance.md"] else ["Features" => "alt_features.md"] diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl index 0aab915c..7feaebb9 100644 --- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl +++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl @@ -9,7 +9,8 @@ using SimpleUnPack function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq) p = initialization(hmm) A = transition_matrix(hmm) - logB = HiddenMarkovModels.loglikelihoods(hmm, obs_seq) + d = obs_distributions(hmm) + logB = reduce(hcat, logdensityof.(d, Ref(obs)) for obs in obs_seq) return p, A, logB end diff --git a/test/allocations.jl b/test/allocations.jl index c0f6e8a5..6a21e2ca 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -11,24 +11,13 @@ function test_allocations(hmm; T) @unpack state_seq, obs_seq = rand(hmm, T) ## Forward - logb = HiddenMarkovModels.loglikelihoods_vec(hmm, obs_seq[1]) - αₜ = zeros(N) - αₜ₊₁ = zeros(N) - allocs = @ballocated HiddenMarkovModels.forward!( - $αₜ, $αₜ₊₁, $logb, $p, $A, $hmm, $obs_seq - ) + f = HiddenMarkovModels.initialize_forward(hmm, obs_seq) + allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) @test allocs == 0 ## Viterbi - logb = HiddenMarkovModels.loglikelihoods_vec(hmm, obs_seq[1]) - δₜ = zeros(N) - δₜ₋₁ = zeros(N) - δA_tmp = zeros(N) - ψ = zeros(Int, N, T) - q = zeros(Int, T) - allocs = @ballocated HiddenMarkovModels.viterbi!( - $q, $δₜ, $δₜ₋₁, $δA_tmp, $ψ, $logb, $p, $A, $hmm, $obs_seq - ) + v = HiddenMarkovModels.initialize_viterbi(hmm, obs_seq) + allocs = @ballocated HiddenMarkovModels.viterbi!($v, $hmm, $obs_seq) @test allocs == 0 ## Forward-backward From 02a6a3799966d33203c32749e11eff81736d8982 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 3 Nov 2023 18:48:03 +0100 Subject: [PATCH 05/14] BaumWelch storage --- docs/src/api.md | 3 + src/HiddenMarkovModels.jl | 1 + src/inference/baum_welch.jl | 120 ++++++++++++++++-------------- src/inference/forward.jl | 54 +++++++------- src/inference/forward_backward.jl | 38 ++++------ src/inference/viterbi.jl | 10 +-- src/types/abstract_hmm.jl | 17 ++--- src/types/hmm.jl | 5 +- 8 files changed, 123 insertions(+), 125 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0581ab5c..aa0fef2a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -45,7 +45,10 @@ baum_welch ## Internals ```@docs +HiddenMarkovModels.ForwardStorage +HiddenMarkovModels.ViterbiStorage 
HiddenMarkovModels.ForwardBackwardStorage +HiddenMarkovModels.BaumWelchStorage HiddenMarkovModels.fit_element_from_sequence! HiddenMarkovModels.LightDiagNormal ``` diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index d8e626d3..7b540bf1 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -9,6 +9,7 @@ $(EXPORTS) """ module HiddenMarkovModels +using Base: RefValue using Base.Threads: @threads using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 912f08cc..56c11ba5 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -1,48 +1,62 @@ -function update_state_marginals!( - state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}} -) where {R} - T = 1 - for k in eachindex(fbs) - Tk = duration(fbs[k]) - @views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ - T += Tk - end - return nothing +""" +$(TYPEDEF) + +Store Baum-Welch quantities with element type `R` and observation type `O`. + +# Fields + +$(TYPEDFIELDS) +""" +struct BaumWelchStorage{R,O} + fbs::Vector{ForwardBackwardStorage{R}} + logL_evolution::Vector{R} + state_marginals_concat::Matrix{R} + obs_seqs_concat::Vector{O} + limits::Vector{Int} end -function baum_welch!( - hmm::AbstractHMM, obs_seqs; atol, max_iterations, check_loglikelihood_increasing -) - # Pre-allocate nearly all necessary memory - fb1 = forward_backward(hmm, obs_seqs[1]) - fbs = Vector{typeof(fb1)}(undef, length(obs_seqs)) - fbs[1] = fb1 +function initialize_baum_welch(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}; max_iterations) + N, T = length(hmm), sum(length, obs_seqs) + R = eltype(hmm, obs_seqs[1][1]) + fbs = Vector{ForwardBackwardStorage{R}}(undef, length(obs_seqs)) @threads for k in eachindex(obs_seqs, fbs) - if k > 1 - fbs[k] = forward_backward(hmm, obs_seqs[k]) - end + fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) end - + logL_evolution = Vector{R}(undef, max_iterations) + state_marginals_concat = Matrix{R}(undef, N, T) obs_seqs_concat = reduce(vcat, obs_seqs) - state_marginals_concat = reduce(hcat, fb.γ for fb in fbs) - logL_evolution = [loglikelihood(fbs)] + limits = vcat(0, cumsum(length.(obs_seqs))) + return BaumWelchStorage( + fbs, logL_evolution, state_marginals_concat, obs_seqs_concat, limits + ) +end + +function baum_welch!( + hmm::AbstractHMM, + bw::BaumWelchStorage, + obs_seqs::Vector{<:Vector}; + atol::Real, + max_iterations::Integer, + check_loglikelihood_increasing::Bool, +) + @unpack fbs, logL_evolution, state_marginals_concat, obs_seqs_concat, limits = bw - for iteration in 1:max_iterations + iteration = 0 + while iteration < max_iterations # E step - if iteration > 1 - @threads for k in eachindex(obs_seqs, fbs) - forward_backward!(fbs[k], hmm, obs_seqs[k]) - end - update_state_marginals!(state_marginals_concat, fbs) - push!(logL_evolution, loglikelihood(fbs)) + @threads for k in eachindex(obs_seqs, fbs) + forward_backward!(fbs[k], hmm, obs_seqs[k]) + @views state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= fbs[k].γ end + logL_evolution[iteration] = sum(fb.logL[] for fb in fbs) # M step - fit!(hmm, fbs, obs_seqs_concat, state_marginals_concat) + fit!(hmm, bw, obs_seqs) # Stopping criterion + iteration += 1 if iteration > 1 - progress = logL_evolution[end] - logL_evolution[end - 1] + progress = logL_evolution[iteration] - logL_evolution[iteration - 1] if check_loglikelihood_increasing && progress < 0 error("Loglikelihood decreased in 
Baum-Welch") elseif progress < atol @@ -51,19 +65,23 @@ function baum_welch!( end end - return logL_evolution + logL_evolution = logL_evolution[1:iteration] + return nothing end """ baum_welch( - hmm_init, obs_seq; + hmm_init, obs_seqs, nb_seqs; atol, max_iterations, check_loglikelihood_increasing ) -Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`. +Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on `nb_seqs` observation sequences. Return a tuple `(hmm_est, logL_evolution)`. +!!! warning "Multithreading" + This function is parallelized across sequences. + # Keyword arguments - `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) @@ -72,31 +90,31 @@ Return a tuple `(hmm_est, logL_evolution)`. """ function baum_welch( hmm_init::AbstractHMM, - obs_seq; + obs_seqs::Vector{<:Vector}, + nb_seqs::Integer; atol=1e-5, max_iterations=100, check_loglikelihood_increasing=true, ) + if nb_seqs != length(obs_seqs) + throw(ArgumentError("nb_seqs != length(obs_seqs)")) + end hmm = deepcopy(hmm_init) - logL_evolution = baum_welch!( - hmm, [obs_seq]; atol, max_iterations, check_loglikelihood_increasing - ) - return hmm, logL_evolution + bw = initialize_baum_welch(hmm, obs_seqs; max_iterations) + baum_welch!(hmm, bw, obs_seqs; atol, max_iterations, check_loglikelihood_increasing) + return hmm, bw.logL_evolution end """ baum_welch( - hmm_init, obs_seqs, nb_seqs; + hmm_init, obs_seq; atol, max_iterations, check_loglikelihood_increasing ) -Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on `nb_seqs` observation sequences. +Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`. Return a tuple `(hmm_est, logL_evolution)`. -!!! warning "Multithreading" - This function is parallelized across sequences. - # Keyword arguments - `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) @@ -105,18 +123,12 @@ Return a tuple `(hmm_est, logL_evolution)`. """ function baum_welch( hmm_init::AbstractHMM, - obs_seqs, - nb_seqs::Integer; + obs_seq; atol=1e-5, max_iterations=100, check_loglikelihood_increasing=true, ) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) - end - hmm = deepcopy(hmm_init) - logL_evolution = baum_welch!( - hmm, obs_seqs; atol, max_iterations, check_loglikelihood_increasing + return baum_welch( + hmm_init, [obs_seq]; atol, max_iterations, check_loglikelihood_increasing ) - return hmm, logL_evolution end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index d7118f51..131c073a 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -5,31 +5,32 @@ Store forward quantities with element type `R`. # Fields -Let `X` denote the vector of hidden states and `Y` denote the vector of observations. 
- $(TYPEDFIELDS) """ struct ForwardStorage{R} + "total loglikelihood" + logL::RefValue{R} "vector of observation loglikelihoods `logb[i]`" logb::Vector{R} - "scaled forward variables `α[t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" + "scaled forward variables `α[t]`" αₜ::Vector{R} "scaled forward variables `α[t+1]`" αₜ₊₁::Vector{R} end -function initialize_forward(hmm::AbstractHMM, obs_seq) +function initialize_forward(hmm::AbstractHMM, obs_seq::Vector) N = length(hmm) R = eltype(hmm, obs_seq[1]) + logL = RefValue{R}(zero(R)) logb = Vector{R}(undef, N) αₜ = Vector{R}(undef, N) αₜ₊₁ = Vector{R}(undef, N) - f = ForwardStorage(logb, αₜ, αₜ₊₁) + f = ForwardStorage(logL, logb, αₜ, αₜ₊₁) return f end -function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq) +function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector) T = length(obs_seq) p = initialization(hmm) A = transition_matrix(hmm) @@ -41,7 +42,7 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq) αₜ .= p .* exp.(logb .- logm) c = inv(sum(αₜ)) αₜ .*= c - logL = -log(c) + logm + logL[] = -log(c) + logm for t in 1:(T - 1) logb .= logdensityof.(d, (obs_seq[t + 1],)) logm = maximum(logb) @@ -50,9 +51,9 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq) c = inv(sum(αₜ₊₁)) αₜ₊₁ .*= c αₜ .= αₜ₊₁ - logL += -log(c) + logm + logL[] += -log(c) + logm end - return logL + return nothing end """ @@ -65,10 +66,10 @@ Return a tuple `(α, logL)` where - `logL` is the loglikelihood of the sequence - `α[i]` is the posterior probability of state `i` at the end of the sequence. """ -function forward(hmm::AbstractHMM, obs_seq) +function forward(hmm::AbstractHMM, obs_seq::Vector) f = initialize_forward(hmm, obs_seq) - logL = forward!(f, hmm, obs_seq) - return f.αₜ, logL + forward!(f, hmm, obs_seq) + return f.αₜ, f.logL[] end """ @@ -84,17 +85,14 @@ Return a vector of tuples `(αₖ, logLₖ)`, where !!! warning "Multithreading" This function is parallelized across sequences. """ -function forward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) +function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - f1 = forward(hmm, first(obs_seqs)) - fs = Vector{typeof(f1)}(undef, nb_seqs) - fs[1] = f1 + R = eltype(hmm, obs_seqs[1][1]) + fs = Vector{ForwardStorage{R}}(undef, nb_seqs) @threads for k in eachindex(fs, obs_seqs) - if k > 1 - fs[k] = forward(hmm, obs_seqs[k]) - end + fs[k] = forward(hmm, obs_seqs[k]) end return fs end @@ -106,8 +104,9 @@ Apply the forward algorithm to compute the loglikelihood of a single observation Return a number. """ -function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq) - return last(forward(hmm, obs_seq)) +function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector) + _, logL = forward(hmm, obs_seq) + return logL end """ @@ -120,17 +119,16 @@ Return a number. !!! warning "Multithreading" This function is parallelized across sequences. 
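+
+For instance (a sketch, with `hmm` an `AbstractHMM` and `obs_seqs` a vector of observation sequences):
+
+```julia
+logL = logdensityof(hmm, obs_seqs, length(obs_seqs))  # total loglikelihood, summed over sequences
+```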
""" -function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) +function DensityInterface.logdensityof( + hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer +) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - logL1 = logdensityof(hmm, first(obs_seqs)) - logLs = Vector{typeof(logL1)}(undef, nb_seqs) - logLs[1] = logL1 + R = eltype(hmm, obs_seqs[1][1]) + logLs = Vector{R}(undef, nb_seqs) @threads for k in eachindex(logLs, obs_seqs) - if k > 1 - logLs[k] = logdensityof(hmm, obs_seqs[k]) - end + logLs[k] = logdensityof(hmm, obs_seqs[k]) end return sum(logLs) end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 29770579..b4e1b8b9 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -12,6 +12,8 @@ $(TYPEDFIELDS) Only the `γ` and `ξ` fields are part of the public API. """ struct ForwardBackwardStorage{R} + "total loglikelihood" + logL::RefValue{R} "scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" α::Matrix{R} "scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" @@ -36,23 +38,11 @@ Base.eltype(::ForwardBackwardStorage{R}) where {R} = R Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1) duration(fb::ForwardBackwardStorage) = size(fb.α, 2) -function loglikelihood(fb::ForwardBackwardStorage{R}) where {R} - logL = -sum(log, fb.c) + sum(fb.logm) - return logL -end - -function loglikelihood(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - logL = zero(R) - for fb in fbs - logL += loglikelihood(fb) - end - return logL -end - -function initialize_forward_backward(hmm::AbstractHMM, obs_seq) +function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector) N, T = length(hmm), length(obs_seq) R = eltype(hmm, obs_seq[1]) + logL = RefValue{R}(zero(R)) α = Matrix{R}(undef, N, T) β = Matrix{R}(undef, N, T) γ = Matrix{R}(undef, N, T) @@ -63,10 +53,10 @@ function initialize_forward_backward(hmm::AbstractHMM, obs_seq) B̃ = Matrix{R}(undef, N, T) B̃β = Matrix{R}(undef, N, T) - return ForwardBackwardStorage(α, β, γ, ξ, c, logB, logm, B̃, B̃β) + return ForwardBackwardStorage(logL, α, β, γ, ξ, c, logB, logm, B̃, B̃β) end -function update_likelihoods!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) +function update_likelihoods!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::Vector) d = obs_distributions(hmm) @unpack logB, logm, B̃ = fb @@ -96,6 +86,7 @@ function forward!(fb::ForwardBackwardStorage, hmm::AbstractHMM) α[:, t + 1] .*= c[t + 1] end check_no_nan(α) + fb.logL[] = -sum(log, fb.c) + sum(fb.logm) return nothing end @@ -129,7 +120,7 @@ function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM) return nothing end -function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq) +function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::Vector) update_likelihoods!(fb, hmm, obs_seq) forward!(fb, hmm) backward!(fb, hmm) @@ -144,7 +135,7 @@ Apply the forward-backward algorithm to estimate the posterior state marginals o Return a [`ForwardBackwardStorage`](@ref). """ -function forward_backward(hmm::AbstractHMM, obs_seq) +function forward_backward(hmm::AbstractHMM, obs_seq::Vector) fb = initialize_forward_backward(hmm, obs_seq) forward_backward!(fb, hmm, obs_seq) return fb @@ -160,17 +151,14 @@ Return a vector of [`ForwardBackwardStorage`](@ref) objects. !!! 
warning "Multithreading" This function is parallelized across sequences. """ -function forward_backward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) +function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - fb1 = forward_backward(hmm, first(obs_seqs)) - fbs = Vector{typeof(fb1)}(undef, nb_seqs) - fbs[1] = fb1 + R = eltype(hmm, obs_seqs[1][1]) + fbs = Vector{ForwardBackwardStorage{R}}(undef, nb_seqs) @threads for k in eachindex(obs_seqs, fbs) - if k > 2 - fbs[k] = forward_backward(hmm, obs_seqs[k]) - end + fbs[k] = forward_backward(hmm, obs_seqs[k]) end return fbs end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 74393487..fc73388d 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -5,8 +5,6 @@ Store Viterbi quantities with element type `R`. # Fields -Let `X` denote the vector of hidden states and `Y` denote the vector of observations. - $(TYPEDFIELDS) """ struct ViterbiStorage{R} @@ -18,7 +16,7 @@ struct ViterbiStorage{R} q::Vector{Int} end -function initialize_viterbi(hmm::AbstractHMM, obs_seq) +function initialize_viterbi(hmm::AbstractHMM, obs_seq::Vector) T, N = length(obs_seq), length(hmm) R = eltype(hmm, obs_seq[1]) @@ -31,7 +29,7 @@ function initialize_viterbi(hmm::AbstractHMM, obs_seq) return ViterbiStorage(logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q) end -function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq) +function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector) N, T = length(hmm), length(obs_seq) p = initialization(hmm) A = transition_matrix(hmm) @@ -68,7 +66,7 @@ Apply the Viterbi algorithm to compute the most likely state sequence of an HMM. Return a vector of integers. """ -function viterbi(hmm::AbstractHMM, obs_seq) +function viterbi(hmm::AbstractHMM, obs_seq::Vector) v = initialize_viterbi(hmm, obs_seq) viterbi!(v, hmm, obs_seq) return v.q @@ -84,7 +82,7 @@ Return a vector of vectors of integers. !!! warning "Multithreading" This function is parallelized across sequences. """ -function viterbi(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer) +function viterbi(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index a98bed18..c6cbd077 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -35,14 +35,14 @@ const AbstractHMM = AbstractHiddenMarkovModel ## Interface """ - length(hmm::AbstractHMM) + length(hmm) Return the number of states of `hmm`. """ Base.length """ - eltype(hmm::AbstractHMM, obs) + eltype(hmm, obs) Return a type that can accommodate forward-backward computations on observations similar to `obs`. It is typicall a promotion between the element type of the initialization, the element type of the transition matrix, and the type of an observation logdensity evaluated at `obs`. @@ -55,21 +55,21 @@ function Base.eltype(hmm::AbstractHMM, obs) end """ - initialization(hmm::AbstractHMM) + initialization(hmm) Return the vector of initial state probabilities for `hmm`. """ function initialization end """ - transition_matrix(hmm::AbstractHMM) + transition_matrix(hmm) Return the matrix of state transition probabilities for `hmm`. """ function transition_matrix end """ - obs_distributions(hmm::AbstractHMM) + obs_distributions(hmm) Return a vector of observation distributions for `hmm`. 
@@ -80,12 +80,7 @@ Each element `dist` of this vector must implement function obs_distributions end """ - fit!( - hmm::AbstractHMM, - fbs::Vector{<:ForwardBackwardStorage}, - obs_seqs_concat, - state_marginals_concat::Matrix - ) + fit!(hmm, bw::BaumWelchStorage, obs_seqs) Update `hmm` in-place based on information generated during forward-backward. """ diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 2d771a5c..44bb41d1 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -50,7 +50,10 @@ transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists function StatsAPI.fit!( - hmm::HMM, fbs::Vector{<:ForwardBackwardStorage}, obs_seqs_concat, state_marginals_concat + hmm::HMM, + fbs::Vector{<:ForwardBackwardStorage}, + obs_seqs_concat::Vector, + state_marginals_concat::Matrix, ) # Initialization hmm.init .= zero(eltype(hmm.init)) From 2d343e27d28484fab25e2016c6518a91751d9f62 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:21:35 +0100 Subject: [PATCH 06/14] Finish BW storage --- ext/HiddenMarkovModelsChainRulesCoreExt.jl | 7 +-- src/inference/baum_welch.jl | 52 +++++++++++++------- src/inference/forward.jl | 39 ++++++++------- src/inference/forward_backward.jl | 42 +++++++++++----- src/inference/viterbi.jl | 23 +++++++-- src/types/hmm.jl | 14 ++---- src/utils/check.jl | 2 +- test/allocations.jl | 33 +++++++++---- test/autodiff.jl | 2 +- test/correctness.jl | 25 +++++----- test/dna.jl | 4 +- test/logarithmic.jl | 4 +- test/sparse.jl | 4 +- test/static.jl | 2 +- test/type_stability.jl | 57 +++++++++++++--------- 15 files changed, 193 insertions(+), 117 deletions(-) diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl index 9f13cd2a..a0bfd1b1 100644 --- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl +++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl @@ -4,6 +4,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, RuleConfig, rrule_via_ad, @not_implemented using DensityInterface: logdensityof using HiddenMarkovModels +import HiddenMarkovModels as HMMs using SimpleUnPack function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq) @@ -18,8 +19,8 @@ function ChainRulesCore.rrule( rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq ) (p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq) - fb = forward_backward(hmm, obs_seq) - logL = HiddenMarkovModels.loglikelihood(fb) + fb = HMMs.initialize_forward_backward(hmm, obs_seq) + HMMs.forward_backward!(fb, hmm, obs_seq) @unpack α, β, γ, c, B̃β = fb T = length(obs_seq) @@ -36,7 +37,7 @@ function ChainRulesCore.rrule( return Δlogdensityof, Δhmm, Δobs_seq end - return logL, logdensityof_hmm_pullback + return fb.logL[], logdensityof_hmm_pullback end end diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 56c11ba5..eb94fd43 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -3,31 +3,46 @@ $(TYPEDEF) Store Baum-Welch quantities with element type `R` and observation type `O`. +Unlike the other storage types, this one is relative to multiple sequences. 
+ # Fields $(TYPEDFIELDS) """ struct BaumWelchStorage{R,O} + "one `ForwardBackwardStorage` for each observation sequence" fbs::Vector{ForwardBackwardStorage{R}} + "number of iterations performed" + iteration::RefValue{Int} + "history of total loglikelihood values throughout the algorithm" logL_evolution::Vector{R} + "concatenation of `γ` matrices for all observation sequences (useful to avoid allocations in fitting)" state_marginals_concat::Matrix{R} + "concatenation of observation sequences (useful to avoid allocations in fitting)" obs_seqs_concat::Vector{O} + "temporal limits of each observation sequence in the concatenations" limits::Vector{Int} end -function initialize_baum_welch(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}; max_iterations) +function initialize_baum_welch( + hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer +) + if nb_seqs != length(obs_seqs) + throw(ArgumentError("nb_seqs != length(obs_seqs)")) + end N, T = length(hmm), sum(length, obs_seqs) R = eltype(hmm, obs_seqs[1][1]) fbs = Vector{ForwardBackwardStorage{R}}(undef, length(obs_seqs)) @threads for k in eachindex(obs_seqs, fbs) fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) end + iteration = Ref(0) logL_evolution = Vector{R}(undef, max_iterations) state_marginals_concat = Matrix{R}(undef, N, T) obs_seqs_concat = reduce(vcat, obs_seqs) limits = vcat(0, cumsum(length.(obs_seqs))) return BaumWelchStorage( - fbs, logL_evolution, state_marginals_concat, obs_seqs_concat, limits + fbs, iteration, logL_evolution, state_marginals_concat, obs_seqs_concat, limits ) end @@ -39,24 +54,28 @@ function baum_welch!( max_iterations::Integer, check_loglikelihood_increasing::Bool, ) - @unpack fbs, logL_evolution, state_marginals_concat, obs_seqs_concat, limits = bw + @unpack ( + fbs, iteration, logL_evolution, state_marginals_concat, obs_seqs_concat, limits + ) = bw + iteration[] = 0 - iteration = 0 - while iteration < max_iterations + while iteration[] < max_iterations # E step @threads for k in eachindex(obs_seqs, fbs) forward_backward!(fbs[k], hmm, obs_seqs[k]) @views state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= fbs[k].γ end - logL_evolution[iteration] = sum(fb.logL[] for fb in fbs) # M step - fit!(hmm, bw, obs_seqs) + fit!(hmm, bw) + + # # Record likelihood + iteration[] += 1 + logL_evolution[iteration[]] = sum(fb.logL[] for fb in fbs) - # Stopping criterion - iteration += 1 - if iteration > 1 - progress = logL_evolution[iteration] - logL_evolution[iteration - 1] + # # Stopping criterion + if iteration[] > 1 + progress = logL_evolution[iteration[]] - logL_evolution[iteration[] - 1] if check_loglikelihood_increasing && progress < 0 error("Loglikelihood decreased in Baum-Welch") elseif progress < atol @@ -64,9 +83,6 @@ function baum_welch!( end end end - - logL_evolution = logL_evolution[1:iteration] - return nothing end """ @@ -100,9 +116,9 @@ function baum_welch( throw(ArgumentError("nb_seqs != length(obs_seqs)")) end hmm = deepcopy(hmm_init) - bw = initialize_baum_welch(hmm, obs_seqs; max_iterations) + bw = initialize_baum_welch(hmm, obs_seqs, nb_seqs; max_iterations) baum_welch!(hmm, bw, obs_seqs; atol, max_iterations, check_loglikelihood_increasing) - return hmm, bw.logL_evolution + return hmm, bw.logL_evolution[1:bw.iteration[]] end """ @@ -123,12 +139,12 @@ Return a tuple `(hmm_est, logL_evolution)`. 
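+
+A typical call looks like this (a sketch; `hmm_init` is the starting guess and `obs_seq` a single observation sequence):
+
+```julia
+hmm_est, logL_evolution = baum_welch(hmm_init, obs_seq; max_iterations=100)
+```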
""" function baum_welch( hmm_init::AbstractHMM, - obs_seq; + obs_seq::Vector; atol=1e-5, max_iterations=100, check_loglikelihood_increasing=true, ) return baum_welch( - hmm_init, [obs_seq]; atol, max_iterations, check_loglikelihood_increasing + hmm_init, [obs_seq], 1; atol, max_iterations, check_loglikelihood_increasing ) end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 131c073a..c5017bb5 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -3,8 +3,12 @@ $(TYPEDEF) Store forward quantities with element type `R`. +This storage is relative to a single sequence. + # Fields +These fields are not part of the public API. + $(TYPEDFIELDS) """ struct ForwardStorage{R} @@ -35,7 +39,7 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - @unpack logb, αₜ, αₜ₊₁ = f + @unpack logL, logb, αₜ, αₜ₊₁ = f logb .= logdensityof.(d, (obs_seq[1],)) logm = maximum(logb) @@ -56,6 +60,15 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector) return nothing end +function forward!( + fs::Vector{<:ForwardStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector} +) + @threads for k in eachindex(fs, obs_seqs) + forward!(fs[k], hmm, obs_seqs[k]) + end + return nothing +end + """ forward(hmm, obs_seq) @@ -63,8 +76,8 @@ Apply the forward algorithm to an HMM. Return a tuple `(α, logL)` where +- `α[i]` is the posterior probability of state `i` at the end of the sequence - `logL` is the loglikelihood of the sequence -- `α[i]` is the posterior probability of state `i` at the end of the sequence. """ function forward(hmm::AbstractHMM, obs_seq::Vector) f = initialize_forward(hmm, obs_seq) @@ -79,8 +92,8 @@ Apply the forward algorithm to an HMM, based on multiple observation sequences. Return a vector of tuples `(αₖ, logLₖ)`, where -- `logLₖ` is the loglikelihood of sequence `k` - `αₖ[i]` is the posterior probability of state `i` at the end of sequence `k` +- `logLₖ` is the loglikelihood of sequence `k` !!! warning "Multithreading" This function is parallelized across sequences. @@ -89,12 +102,9 @@ function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - R = eltype(hmm, obs_seqs[1][1]) - fs = Vector{ForwardStorage{R}}(undef, nb_seqs) - @threads for k in eachindex(fs, obs_seqs) - fs[k] = forward(hmm, obs_seqs[k]) - end - return fs + fs = [initialize_forward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + forward!(fs, hmm, obs_seqs) + return [(f.αₜ, f.logL[]) for f in fs] end """ @@ -122,13 +132,6 @@ Return a number. function DensityInterface.logdensityof( hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer ) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) - end - R = eltype(hmm, obs_seqs[1][1]) - logLs = Vector{R}(undef, nb_seqs) - @threads for k in eachindex(logLs, obs_seqs) - logLs[k] = logdensityof(hmm, obs_seqs[k]) - end - return sum(logLs) + logαs_and_logLs = forward(hmm, obs_seqs, nb_seqs) + return sum(last, logαs_and_logLs) end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index b4e1b8b9..371a0ad7 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -3,13 +3,15 @@ $(TYPEDEF) Store forward-backward quantities with element type `R`. +This storage is relative to a single sequence. 
+ # Fields +These fields are not part of the public API. + Let `X` denote the vector of hidden states and `Y` denote the vector of observations. $(TYPEDFIELDS) - -Only the `γ` and `ξ` fields are part of the public API. """ struct ForwardBackwardStorage{R} "total loglikelihood" @@ -18,9 +20,9 @@ struct ForwardBackwardStorage{R} α::Matrix{R} "scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" β::Matrix{R} - "posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" + "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" γ::Matrix{R} - "posterior two-state marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" + "posterior transition marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" ξ::Array{R,3} "forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])`" c::Vector{R} @@ -128,17 +130,30 @@ function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq return nothing end +function forward_backward!( + fbs::Vector{<:ForwardBackwardStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector} +) + @threads for k in eachindex(fbs, obs_seqs) + forward_backward!(fbs[k], hmm, obs_seqs[k]) + end + return nothing +end + """ forward_backward(hmm, obs_seq) Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM. -Return a [`ForwardBackwardStorage`](@ref). +Return a tuple `(γ, ξ, logL)` where + +- `γ` is a matrix containing the posterior state marginals `γ[i, t]` +- `ξ` is a 3-tensor containing the posterior transition marginals `ξ[i, j, t]` +- `logL` is the loglikelihood of the sequence """ function forward_backward(hmm::AbstractHMM, obs_seq::Vector) fb = initialize_forward_backward(hmm, obs_seq) forward_backward!(fb, hmm, obs_seq) - return fb + return (fb.γ, fb.ξ, fb.logL[]) end """ @@ -146,7 +161,11 @@ end Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM, based on multiple observation sequences. -Return a vector of [`ForwardBackwardStorage`](@ref) objects. +Return a vector of tuples `(γₖ, ξₖ, logLₖ)` where + +- `γₖ` is a matrix containing the posterior state marginals `γₖ[i, t]` for sequence `k` +- `ξₖ` is a 3-tensor containing the posterior transition marginals `ξ[i, j, t]` for sequence `k` +- `logLₖ` is the loglikelihood of sequence `k` !!! warning "Multithreading" This function is parallelized across sequences. @@ -155,10 +174,7 @@ function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs: if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - R = eltype(hmm, obs_seqs[1][1]) - fbs = Vector{ForwardBackwardStorage{R}}(undef, nb_seqs) - @threads for k in eachindex(obs_seqs, fbs) - fbs[k] = forward_backward(hmm, obs_seqs[k]) - end - return fbs + fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + forward_backward!(fbs, hmm, obs_seqs) + return [(fb.γ, fb.ξ, fb.logL[]) for fb in fbs] end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index fc73388d..b4ca33bc 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -3,16 +3,22 @@ $(TYPEDEF) Store Viterbi quantities with element type `R`. +This storage is relative to a single sequence. + # Fields +These fields are not part of the public API. 
+ $(TYPEDFIELDS) """ struct ViterbiStorage{R} + "vector of loglikelihood values for each state" logb::Vector{R} δₜ::Vector{R} δₜ₋₁::Vector{R} δₜ₋₁Aⱼ::Vector{R} ψ::Matrix{Int} + "vector of most likely state at each time" q::Vector{Int} end @@ -59,6 +65,15 @@ function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector) return nothing end +function viterbi!( + vs::Vector{<:ViterbiStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector} +) + @threads for k in eachindex(vs, obs_seqs) + viterbi!(vs[k], hmm, obs_seqs[k]) + end + return nothing +end + """ viterbi(hmm, obs_seq) @@ -86,9 +101,7 @@ function viterbi(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - qs = Vector{Vector{Int}}(undef, nb_seqs) - @threads for k in eachindex(qs, obs_seqs) - qs[k] = viterbi(hmm, obs_seqs[k]) - end - return qs + vs = [initialize_viterbi(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + viterbi!(vs, hmm, obs_seqs) + return [v.q for v in vs] end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 44bb41d1..31a9fb1f 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -49,16 +49,12 @@ initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists -function StatsAPI.fit!( - hmm::HMM, - fbs::Vector{<:ForwardBackwardStorage}, - obs_seqs_concat::Vector, - state_marginals_concat::Matrix, -) +function StatsAPI.fit!(hmm::HMM, bw::BaumWelchStorage) + @unpack fbs, state_marginals_concat, obs_seqs_concat = bw # Initialization hmm.init .= zero(eltype(hmm.init)) for k in eachindex(fbs) - @views hmm.init .+= fbs[k].γ[:, 1] + hmm.init .+= view(fbs[k].γ, :, 1) end sum_to_one!(hmm.init) # Transition matrix @@ -68,9 +64,9 @@ function StatsAPI.fit!( end foreach(sum_to_one!, eachrow(hmm.trans)) # Observation distributions - @views for i in eachindex(hmm.dists) + for i in eachindex(hmm.dists) fit_element_from_sequence!( - hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :] + hmm.dists, i, obs_seqs_concat, view(state_marginals_concat, i, :) ) end check_hmm(hmm) diff --git a/src/utils/check.jl b/src/utils/check.jl index 2497ea56..2e0a0377 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -57,7 +57,7 @@ function check_hmm(hmm::AbstractHMM) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - if !all(==(length(hmm)), [length(p), size(A, 1), size(A, 2), length(d)]) + if !all(==(length(hmm)), (length(p), size(A, 1), size(A, 2), length(d))) throw(DimensionMismatch("Incoherent sizes")) end check_prob_vec(p) diff --git a/test/allocations.jl b/test/allocations.jl index f39d6de9..63566aa5 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -1,30 +1,42 @@ using BenchmarkTools using Distributions using HiddenMarkovModels +import HiddenMarkovModels as HMMs using HiddenMarkovModels: LightDiagNormal using SimpleUnPack using Test function test_allocations(hmm; T) - @unpack state_seq, obs_seq = rand(hmm, T) + obs_seq = rand(hmm, T).obs_seq ## Forward - f = HiddenMarkovModels.initialize_forward(hmm, obs_seq) - HiddenMarkovModels.forward!(f, hmm, obs_seq) - allocs = @allocated HiddenMarkovModels.forward!(f, hmm, obs_seq) + f = HMMs.initialize_forward(hmm, obs_seq) + allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) @test allocs == 0 ## Viterbi - v = HiddenMarkovModels.initialize_viterbi(hmm, obs_seq) - HiddenMarkovModels.viterbi!(v, hmm, obs_seq) - allocs = @allocated HiddenMarkovModels.viterbi!(v, hmm, obs_seq) 
+ v = HMMs.initialize_viterbi(hmm, obs_seq) + allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) @test allocs == 0 ## Forward-backward - fb = HiddenMarkovModels.initialize_forward_backward(hmm, obs_seq) - HiddenMarkovModels.forward_backward!(fb, hmm, obs_seq) - allocs = @allocated HiddenMarkovModels.forward_backward!(fb, hmm, obs_seq) + fb = HMMs.initialize_forward_backward(hmm, obs_seq) + allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) @test allocs == 0 + + ## Baum-Welch + nb_seqs = 2 + obs_seqs = [obs_seq for _ in 1:nb_seqs] + bw = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2) + allocs = @ballocated HMMs.baum_welch!( + $hmm, + $bw, + $obs_seqs; + atol=-Inf, + max_iterations=2, + check_loglikelihood_increasing=false, + ) + @test_broken allocs == 0 # @threads introduces type instability, see https://discourse.julialang.org/t/type-instability-because-of-threads-boxing-variables/78395/ end N = 5 @@ -36,5 +48,6 @@ A = rand_trans_mat(N) dists = [LightDiagNormal(randn(2), ones(2)) for i in 1:N] hmm = HMM(p, A, dists) +obs_seq = rand(hmm, T).obs_seq test_allocations(hmm; T) diff --git a/test/autodiff.jl b/test/autodiff.jl index e7ba7f7d..cccaf133 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -15,7 +15,7 @@ A = rand_trans_mat(N) dists = [Normal(μ[i], 1.0) for i in 1:N] hmm = HMM(p, A, dists) -@unpack state_seq, obs_seq = rand(hmm, T); +obs_seq = rand(hmm, T).obs_seq; function f_init(_p) hmm = HMM(_p, A, dists) diff --git a/test/correctness.jl b/test/correctness.jl index 03185530..aad72e6a 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -6,35 +6,38 @@ using SimpleUnPack using Test function test_correctness(hmm, hmm_init; T) - @unpack state_seq, obs_seq = rand(hmm, T) + obs_seq = rand(hmm, T).obs_seq obs_mat = collect(reduce(hcat, obs_seq)') + nb_seqs = 2 + obs_seqs = [obs_seq for _ in 1:nb_seqs] + hmm_base = HMMBase.HMM(deepcopy(hmm)) hmm_init_base = HMMBase.HMM(deepcopy(hmm_init)) @testset "Logdensity" begin _, logL_base = HMMBase.forward(hmm_base, obs_mat) - logL = @inferred logdensityof(hmm, obs_seq) - @test logL ≈ logL_base + logL = @inferred logdensityof(hmm, obs_seqs, nb_seqs) + @test logL ≈ logL_base * nb_seqs end @testset "Forward" begin α_base, logL_base = HMMBase.forward(hmm_base, obs_mat) - α, logL = @inferred forward(hmm, obs_seq) + α, logL = @inferred first(forward(hmm, obs_seqs, nb_seqs)) @test isapprox(α, α_base[end, :]) @test logL ≈ logL_base end @testset "Viterbi" begin - best_state_seq_base = HMMBase.viterbi(hmm_base, obs_mat) - best_state_seq = @inferred viterbi(hmm, obs_seq) - @test isequal(best_state_seq, best_state_seq_base) + q_base = HMMBase.viterbi(hmm_base, obs_mat) + q = @inferred first(viterbi(hmm, obs_seqs, nb_seqs)) + @test isequal(q, q_base) end @testset "Forward-backward" begin γ_base = HMMBase.posteriors(hmm_base, obs_mat) - fb = @inferred forward_backward(hmm, obs_seq) - @test isapprox(fb.γ, γ_base') + γ, ξ = @inferred first(forward_backward(hmm, obs_seqs, nb_seqs)) + @test isapprox(γ, γ_base') end @testset "Baum-Welch" begin @@ -43,10 +46,10 @@ function test_correctness(hmm, hmm_init; T) ) logL_evolution_base = hist_base.logtots hmm_est, logL_evolution = @inferred baum_welch( - hmm_init, obs_seq; max_iterations=10, atol=-Inf + hmm_init, obs_seqs, nb_seqs; max_iterations=10, atol=-Inf ) @test isapprox( - logL_evolution[(begin + 1):end], logL_evolution_base[begin:(end - 1)] + logL_evolution[(begin + 1):end], logL_evolution_base[begin:(end - 1)] .* nb_seqs ) @test isapprox(initialization(hmm_est), 
hmm_est_base.a) @test isapprox(transition_matrix(hmm_est), hmm_est_base.A) diff --git a/test/dna.jl b/test/dna.jl index 8d04efe4..9fa5cff5 100644 --- a/test/dna.jl +++ b/test/dna.jl @@ -91,7 +91,9 @@ function HiddenMarkovModels.obs_distributions(hmm::DNACodingHMM) return [Dirac(get_nucleotide(s)) for s in 1:length(hmm)] end -function StatsAPI.fit!(dchmm::DNACodingHMM, fbs, obs_seqs_concat, state_marginals_concat) +function StatsAPI.fit!(dchmm::DNACodingHMM, bw::HiddenMarkovModels.BaumWelchStorage) + @unpack fbs, obs_seqs_concat, state_marginals_concat = bw + init_count = zeros(eltype(initialization(dchmm)), 8) for k in eachindex(fbs) @views init_count .+= fbs[k].γ[:, 1] diff --git a/test/logarithmic.jl b/test/logarithmic.jl index bfef42dc..6fa297b5 100644 --- a/test/logarithmic.jl +++ b/test/logarithmic.jl @@ -8,7 +8,7 @@ using Test N = 3 D = 2 -T = 100 +T = 1000 p = ones(N) / N A = rand_trans_mat(N) @@ -17,7 +17,7 @@ dists_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; dists_init_log = [LightDiagNormal(randn(D), LogFloat64.(ones(D))) for i in 1:N]; hmm = HMM(p, A, dists); -@unpack state_seq, obs_seq = rand(hmm, T); +obs_seq = rand(hmm, T).obs_seq; hmm_init = HMM(LogFloat64.(p), A, dists_init); hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); diff --git a/test/sparse.jl b/test/sparse.jl index da6e14df..dec543fa 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -6,7 +6,7 @@ using SparseArrays using SimpleUnPack using Test -N = 4 +N = 3 T = 2000 p = ones(N) / N @@ -18,7 +18,7 @@ dists_init = [Normal(i + randn(), 1) for i in 1:N] hmm = HMM(p, A, dists) hmm_init = HMM(p, A, dists_init) -@unpack state_seq, obs_seq = rand(hmm, T) +obs_seq = rand(hmm, T).obs_seq hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) @test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/static.jl b/test/static.jl index 0a30d67c..23976aa2 100644 --- a/test/static.jl +++ b/test/static.jl @@ -16,7 +16,7 @@ dists_init = MVector{N}([Normal(randn(), 1.0) for i in 1:N]) hmm = HMM(p, A, dists) hmm_init = HMM(p, A, dists_init) -@unpack state_seq, obs_seq = rand(hmm, T) +obs_seq = rand(hmm, T).obs_seq hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) @test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/type_stability.jl b/test/type_stability.jl index ed730606..4aecc646 100644 --- a/test/type_stability.jl +++ b/test/type_stability.jl @@ -6,44 +6,57 @@ using JET using SimpleUnPack using Test -function test_type_stability(hmm, hmm_init; T, K) - obs_seqs = [rand(hmm, T).obs_seq for k in 1:K] +function test_type_stability(hmm, hmm_init; T) + obs_seq = rand(hmm, T).obs_seq + nb_seqs = 2 + obs_seqs = [obs_seq for _ in 1:nb_seqs] @testset "Logdensity" begin - @inferred logdensityof(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) logdensityof(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) logdensityof(hmm, obs_seqs, K) + @inferred logdensityof(hmm, obs_seqs, nb_seqs) + @test_opt target_modules = (HiddenMarkovModels,) logdensityof( + hmm, obs_seqs, nb_seqs + ) + @test_call target_modules = (HiddenMarkovModels,) logdensityof( + hmm, obs_seqs, nb_seqs + ) end @testset "Forward" begin - @inferred forward(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, K) + @inferred forward(hmm, obs_seqs, nb_seqs) + @test_opt target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, nb_seqs) + @test_call 
target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, nb_seqs) end @testset "Viterbi" begin - @inferred viterbi(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, K) + @inferred viterbi(hmm, obs_seqs, nb_seqs) + @test_opt target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, nb_seqs) + @test_call target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, nb_seqs) end @testset "Forward-backward" begin - @inferred forward_backward(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) forward_backward(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) forward_backward(hmm, obs_seqs, K) + @inferred forward_backward(hmm, obs_seqs, nb_seqs) + @test_opt target_modules = (HiddenMarkovModels,) forward_backward( + hmm, obs_seqs, nb_seqs + ) + @test_call target_modules = (HiddenMarkovModels,) forward_backward( + hmm, obs_seqs, nb_seqs + ) end @testset "Baum-Welch" begin - @inferred baum_welch(hmm_init, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) baum_welch(hmm_init, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) baum_welch(hmm_init, obs_seqs, K) + @inferred baum_welch(hmm_init, obs_seqs, nb_seqs) + @test_opt target_modules = (HiddenMarkovModels,) baum_welch( + hmm_init, obs_seqs, nb_seqs + ) + @test_call target_modules = (HiddenMarkovModels,) baum_welch( + hmm_init, obs_seqs, nb_seqs + ) end end -N = 5 +N = 2 D = 3 T = 100 -K = 4 p = rand_prob_vec(N) p_init = rand_prob_vec(N) @@ -60,7 +73,7 @@ hmm_norm = HMM(p, A, dists_norm) hmm_norm_init = HMM(p_init, A_init, dists_norm_init) @testset "Normal" begin - test_type_stability(hmm_norm, hmm_norm_init; T, K) + test_type_stability(hmm_norm, hmm_norm_init; T) end # DiagNormal @@ -72,7 +85,7 @@ hmm_diagnorm = HMM(p, A, dists_diagnorm) hmm_diagnorm_init = HMM(p, A, dists_diagnorm_init) @testset "DiagNormal" begin - test_type_stability(hmm_diagnorm, hmm_diagnorm_init; T, K) + test_type_stability(hmm_diagnorm, hmm_diagnorm_init; T) end ## LightDiagNormal @@ -84,5 +97,5 @@ hmm_lightdiagnorm = HMM(p, A, dists_lightdiagnorm) hmm_lightdiagnorm_init = HMM(p, A, dists_lightdiagnorm_init) @testset "LightDiagNormal" begin - test_type_stability(hmm_lightdiagnorm, hmm_lightdiagnorm_init; T, K) + test_type_stability(hmm_lightdiagnorm, hmm_lightdiagnorm_init; T) end From 17710fd1d3ea0973d12932b543df30b51bb8389b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:42:43 +0100 Subject: [PATCH 07/14] Fix docs --- docs/make.jl | 14 ++-- docs/src/api.md | 11 ++- docs/src/benchmarks.md | 4 -- docs/src/builtin.md | 8 +-- docs/src/custom.md | 135 ------------------------------------ src/HiddenMarkovModels.jl | 5 +- src/inference/baum_welch.jl | 5 +- src/types/abstract_hmm.jl | 11 ++- src/types/hmm.jl | 8 --- src/utils/probvec.jl | 5 ++ src/utils/transmat.jl | 5 ++ 11 files changed, 38 insertions(+), 173 deletions(-) delete mode 100644 docs/src/custom.md diff --git a/docs/make.jl b/docs/make.jl index b9e0ea3a..959e5c1d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,12 +30,12 @@ benchmarks_done = ( pages = [ "Home" => "index.md", "Essentials" => ["Background" => "background.md", "API reference" => "api.md"], - "Tutorials" => [ - "Built-in HMM" => "builtin.md", - "Custom HMM" => "custom.md", - "Debugging" => "debugging.md", - ], - "Alternatives" => ["Features" => "features.md", "Benchmarks" => "benchmarks.md"], 
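+    # the Benchmarks page is linked only when benchmark results exist, to avoid broken links in the docs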
+ "Tutorials" => ["Built-in HMM" => "builtin.md", "Debugging" => "debugging.md"], + "Alternatives" => if benchmarks_done + ["Features" => "features.md", "Benchmarks" => "benchmarks.md"] + else + ["Features" => "features.md"] + end, "Advanced" => ["Formulas" => "formulas.md", "Roadmap" => "roadmap.md"], ] @@ -53,7 +53,7 @@ makedocs(; format=fmt, pages=pages, plugins=[bib], - warnonly=!benchmarks_done, + pagesonly=true, ) deploydocs(; repo="github.com/gdalle/HiddenMarkovModels.jl", devbranch="main") diff --git a/docs/src/api.md b/docs/src/api.md index aa0fef2a..4675127b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -10,7 +10,6 @@ HiddenMarkovModels AbstractHiddenMarkovModel HiddenMarkovModel AbstractHMM -PermutedHMM HMM ``` @@ -38,10 +37,17 @@ forward_backward ```@docs fit! -fit baum_welch ``` +## Misc + +```@docs +check_hmm +rand_prob_vec +rand_trans_mat +``` + ## Internals ```@docs @@ -51,6 +57,7 @@ HiddenMarkovModels.ForwardBackwardStorage HiddenMarkovModels.BaumWelchStorage HiddenMarkovModels.fit_element_from_sequence! HiddenMarkovModels.LightDiagNormal +HiddenMarkovModels.PermutedHMM ``` ## Notations diff --git a/docs/src/benchmarks.md b/docs/src/benchmarks.md index 88165e1b..bcf618e3 100644 --- a/docs/src/benchmarks.md +++ b/docs/src/benchmarks.md @@ -15,10 +15,6 @@ The test case is an HMM with diagonal multivariate normal observations. - `K`: number of trajectories - `I`: number of Baum-Welch iterations -!!! danger "Missing benchmarks?" - The benchmarks are computationally expensive and we only run them once for each new release. - If you don't see any plots below and the links are broken, you are probably on the dev documentation: go to the [stable documentation](https://gdalle.github.io/HiddenMarkovModels.jl/stable/) instead. - ## Reproducibility These benchmarks were generated in the following environment: [`setup.txt`](./assets/benchmark/results/setup.txt). diff --git a/docs/src/builtin.md b/docs/src/builtin.md index d8801142..4ccf748d 100644 --- a/docs/src/builtin.md +++ b/docs/src/builtin.md @@ -29,7 +29,7 @@ transition_matrix(hmm) ``` ```@example tuto -[obs_distribution(hmm, i) for i in 1:N] +obs_distributions(hmm) ``` Simulating a sequence: @@ -70,15 +70,15 @@ first(logL_evolution), last(logL_evolution) Correcting state order because we know observation means are increasing in the true model: ```@example tuto -[obs_distribution(hmm_est, i) for i in 1:N] +d_est = obs_distributions(hmm_est) ``` ```@example tuto -perm = sortperm(1:3, by=i->obs_distribution(hmm_est, i).μ) +perm = sortperm(1:3, by=i->d_est[i].μ) ``` ```@example tuto -hmm_est = PermutedHMM(hmm_est, perm) +hmm_est = HiddenMarkovModels.PermutedHMM(hmm_est, perm) ``` Evaluating errors: diff --git a/docs/src/custom.md b/docs/src/custom.md deleted file mode 100644 index a4a67b92..00000000 --- a/docs/src/custom.md +++ /dev/null @@ -1,135 +0,0 @@ -# Tutorial - custom HMM - -The built-in HMM is perfect when the initial state distribution, transition matrix and emission distributions can be updated independently with Maximum Likelihood Estimation. -But some of these parameters might be correlated, or fixed. -Or they might come with a prior, which forces you to use Maximum A Posteriori instead. - -In such cases, it is necessary to implement a new subtype of [`AbstractHMM`](@ref) with all its required methods. 
- -```@example tuto -using Distributions -using HiddenMarkovModels -using LinearAlgebra -using RequiredInterfaces -using StatsAPI - -using Random; Random.seed!(63) -``` - -## Interface - -To ascertain that a type indeed satisfies the interface, you can use [RequiredInterfaces.jl](https://github.com/Seelengrab/RequiredInterfaces.jl). - -```@example tuto -RequiredInterfaces.check_interface_implemented(AbstractHMM, HMM) -``` - -If your implementation is insufficient, the test will list missing methods. - -```@example tuto -struct EmptyHMM end -RequiredInterfaces.check_interface_implemented(AbstractHMM, EmptyHMM) -``` - -Note that this test does not check the `StatsAPI.fit!` method. -Since it is only used in the Baum-Welch algorithm, it is an optional part of the `AbstractHMM` interface. - -## Example - -We show how to implement an HMM whose initial distribution is always the equilibrium distribution of the underlying Markov chain. -The code that follows is not efficient (it leads to a lot of allocations), but it would be fairly easy to optimize if needed. - -The equilibrium distribution of a Markov chain is the (only) left eigenvector associated with the left eigenvalue $1$. - -```@example tuto -function markov_equilibrium(A) - p = real.(eigvecs(A')[:, end]) - return p ./ sum(p) -end -``` - -We now define our custom HMM by taking inspiration from `src/types/hmm.jl` and making a few modifications: - -```@example tuto -struct EquilibriumHMM{R,D} <: AbstractHMM - trans::Matrix{R} - dists::Vector{D} -end -``` - -The interface is only different as far as the initialization is concerned. - -```@example tuto -Base.length(hmm::EquilibriumHMM) = length(hmm.dists) -HiddenMarkovModels.initialization(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans) # this is new -HiddenMarkovModels.transition_matrix(hmm::EquilibriumHMM) = hmm.trans -HiddenMarkovModels.obs_distribution(hmm::EquilibriumHMM, i::Integer) = hmm.dists[i] -``` - -As for fitting, we simply ignore the initialization count and copy the rest of the original code (with a few simplifications): - -```@example tuto -function StatsAPI.fit!( - hmm::EquilibriumHMM{R,D}, init_count, trans_count, obs_seq, state_marginals -) where {R,D} - hmm.trans .= trans_count ./ sum(trans_count, dims=2) - for i in 1:N - hmm.dists[i] = fit(D, obs_seq, state_marginals[i, :]) - end -end -``` - -Let's take our new model for a spin: - -```@example tuto -function gaussian_equilibrium_hmm(N; noise=0) - A = rand_trans_mat(N) - dists = [Normal(i + noise * randn(), 0.5) for i in 1:N] - return EquilibriumHMM(A, dists) -end; -``` - -```@example tuto -N = 3 -hmm = gaussian_equilibrium_hmm(N); -transition_matrix(hmm) -``` - -```@example tuto -[obs_distribution(hmm, i) for i in 1:N] -``` - -We can estimate parameters based on several observation sequences. -Note that as soon as we tamper with the re-estimation procedure, the loglikelihood is no longer guaranteed to increase during Baum-Welch, which is why we turn off the corresponding check. 
- -```@example tuto -T = 1000 -nb_seqs = 10 -obs_seqs = [rand(hmm, T).obs_seq for _ in 1:nb_seqs] - -hmm_init = gaussian_equilibrium_hmm(N; noise=1) -hmm_est, logL_evolution = baum_welch( - hmm_init, obs_seqs, nb_seqs; check_loglikelihood_increasing=false -); -first(logL_evolution), last(logL_evolution) -``` - -Let's correct the state order: - -```@example tuto -[obs_distribution(hmm_est, i) for i in 1:N] -``` - -```@example tuto -perm = sortperm(1:3, by=i->obs_distribution(hmm_est, i).μ) -``` - -```@example tuto -hmm_est = PermutedHMM(hmm_est, perm) -``` - -And finally evaluate the errors: - -```@example tuto -cat(transition_matrix(hmm_est), transition_matrix(hmm), dims=3) -``` \ No newline at end of file diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 7b540bf1..ee0af5ba 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -33,7 +33,8 @@ export HiddenMarkovModel, HMM export rand_prob_vec, rand_trans_mat export initialization, transition_matrix, obs_distributions export logdensityof, viterbi, forward, forward_backward, baum_welch -export fit, fit! +export fit! +export check_hmm include("types/abstract_hmm.jl") include("types/permuted_hmm.jl") @@ -62,7 +63,6 @@ if !isdefined(Base, :get_extension) end end -#= @compile_workload begin N, D, T = 5, 3, 100 p = rand_prob_vec(N) @@ -78,6 +78,5 @@ end forward_backward(hmm, obs_seqs, nb_seqs) baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf) end -=# end diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index eb94fd43..7dc0aeda 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -32,10 +32,7 @@ function initialize_baum_welch( end N, T = length(hmm), sum(length, obs_seqs) R = eltype(hmm, obs_seqs[1][1]) - fbs = Vector{ForwardBackwardStorage{R}}(undef, length(obs_seqs)) - @threads for k in eachindex(obs_seqs, fbs) - fbs[k] = initialize_forward_backward(hmm, obs_seqs[k]) - end + fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] iteration = Ref(0) logL_evolution = Vector{R}(undef, max_iterations) state_marginals_concat = Matrix{R}(undef, N, T) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index c6cbd077..a5f240bd 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -5,12 +5,11 @@ Abstract supertype for an HMM amenable to simulation, inference and learning. # Interface -- [`length`](@ref) -- [`eltype`](@ref) -- [`initialization`](@ref) -- [`transition_matrix`](@ref) -- [`obs_distributions`](@ref) -- [`fit!](@ref) +- `length` +- `initialization` +- `transition_matrix` +- `obs_distributions` +- `fit` # Applicable methods diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 31a9fb1f..a59d1fa5 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -37,14 +37,6 @@ function Base.copy(hmm::HMM) end Base.length(hmm::HMM) = length(hmm.init) - -function Base.eltype(hmm::HMM, obs) - init_type = eltype(hmm.init) - trans_type = eltype(hmm.trans) - logdensity_type = typeof(logdensityof(hmm.dists[1], obs)) - return promote_type(init_type, trans_type, logdensity_type) -end - initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists diff --git a/src/utils/probvec.jl b/src/utils/probvec.jl index 50c6d368..7456558f 100644 --- a/src/utils/probvec.jl +++ b/src/utils/probvec.jl @@ -4,6 +4,11 @@ end sum_to_one!(x) = x ./= sum(x) +""" + rand_prob_vec([rng,] N) + +Generate a random probability distribution of size `N`. 
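+
+# Example
+
+A minimal sketch (the exact values are random):
+
+```julia
+p = rand_prob_vec(3)
+sum(p) ≈ 1  # nonnegative entries normalized to sum to one
+```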
+""" function rand_prob_vec(rng::AbstractRNG, N) p = rand(rng, N) sum_to_one!(p) diff --git a/src/utils/transmat.jl b/src/utils/transmat.jl index 71cb86d5..1e40831f 100644 --- a/src/utils/transmat.jl +++ b/src/utils/transmat.jl @@ -15,6 +15,11 @@ function is_trans_mat(A::AbstractMatrix; atol=1e-2) end end +""" + rand_trans_mat([rng,] N) + +Generate a random transition matrix of size `(N, N)`. +""" function rand_trans_mat(rng::AbstractRNG, N) A = rand(rng, N, N) foreach(sum_to_one!, eachrow(A)) From 51fca9c1f751b079a95e34150b9fe275d79815a1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 6 Nov 2023 18:22:18 +0100 Subject: [PATCH 08/14] Separate single sequence case --- docs/src/api.md | 10 +- src/HiddenMarkovModels.jl | 41 ++++--- src/inference/baum_welch.jl | 188 +++++++++++++----------------- src/inference/forward.jl | 37 ++---- src/inference/forward_backward.jl | 62 +++++----- src/inference/viterbi.jl | 20 +--- src/types/abstract_hmm.jl | 24 ++-- src/types/hmm.jl | 13 +-- test/allocations.jl | 25 ++-- 9 files changed, 191 insertions(+), 229 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 4675127b..6c2d8aab 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -48,13 +48,17 @@ rand_prob_vec rand_trans_mat ``` -## Internals +## Storage ```@docs +HiddenMarkovModels.ForwardBackwardStorage HiddenMarkovModels.ForwardStorage HiddenMarkovModels.ViterbiStorage -HiddenMarkovModels.ForwardBackwardStorage -HiddenMarkovModels.BaumWelchStorage +``` + +## Internals + +```@docs HiddenMarkovModels.fit_element_from_sequence! HiddenMarkovModels.LightDiagNormal HiddenMarkovModels.PermutedHMM diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index ee0af5ba..7d6b7ee2 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -63,20 +63,31 @@ if !isdefined(Base, :get_extension) end end -@compile_workload begin - N, D, T = 5, 3, 100 - p = rand_prob_vec(N) - A = rand_trans_mat(N) - dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] - hmm = HMM(p, A, dists) - - obs_seqs = [last(rand(hmm, T)) for _ in 1:3] - nb_seqs = 3 - logdensityof(hmm, obs_seqs, nb_seqs) - forward(hmm, obs_seqs, nb_seqs) - viterbi(hmm, obs_seqs, nb_seqs) - forward_backward(hmm, obs_seqs, nb_seqs) - baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf) -end +# @compile_workload begin +# N, D, T = 5, 3, 100 +# p = rand_prob_vec(N) +# A = rand_trans_mat(N) +# dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] +# hmm = HMM(p, A, dists) + +# obs_seq = rand(hmm, T).obs_seq +# obs_seqs = [rand(hmm, T).obs_seq for _ in 1:3] +# nb_seqs = 3 + +# logdensityof(hmm, obs_seq) +# logdensityof(hmm, obs_seqs, nb_seqs) + +# forward(hmm, obs_seq) +# forward(hmm, obs_seqs, nb_seqs) + +# viterbi(hmm, obs_seq) +# viterbi(hmm, obs_seqs, nb_seqs) + +# forward_backward(hmm, obs_seq) +# forward_backward(hmm, obs_seqs, nb_seqs) + +# baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf) +# baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf) +# end end diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 7dc0aeda..d0cd2131 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -1,147 +1,123 @@ -""" -$(TYPEDEF) - -Store Baum-Welch quantities with element type `R` and observation type `O`. - -Unlike the other storage types, this one is relative to multiple sequences. 
- -# Fields - -$(TYPEDFIELDS) -""" -struct BaumWelchStorage{R,O} - "one `ForwardBackwardStorage` for each observation sequence" - fbs::Vector{ForwardBackwardStorage{R}} - "number of iterations performed" - iteration::RefValue{Int} - "history of total loglikelihood values throughout the algorithm" - logL_evolution::Vector{R} - "concatenation of `γ` matrices for all observation sequences (useful to avoid allocations in fitting)" - state_marginals_concat::Matrix{R} - "concatenation of observation sequences (useful to avoid allocations in fitting)" - obs_seqs_concat::Vector{O} - "temporal limits of each observation sequence in the concatenations" - limits::Vector{Int} +function baum_welch_has_converged( + logL_evolution::Vector; atol::Real, loglikelihood_increasing::Bool +) + if length(logL_evolution) >= 2 + logL, logL_prev = logL_evolution[end], logL_evolution[end - 1] + progress = logL - logL_prev + if loglikelihood_increasing && progress < 0 + error("Loglikelihood decreased in Baum-Welch") + elseif progress < atol + return true + end + end + return false end -function initialize_baum_welch( - hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer +function baum_welch!( + fb::ForwardBackwardStorage, + logL_evolution::Vector, + hmm::AbstractHMM, + obs_seq::Vector; + atol::Real, + max_iterations::Integer, + loglikelihood_increasing::Bool, ) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) + for _ in 1:max_iterations + forward_backward!(fb, hmm, obs_seq) + push!(logL_evolution, fb.logL[]) + fit!(hmm, (fb,), (obs_seq,), fb, obs_seq) + if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) + break + end end - N, T = length(hmm), sum(length, obs_seqs) - R = eltype(hmm, obs_seqs[1][1]) - fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] - iteration = Ref(0) - logL_evolution = Vector{R}(undef, max_iterations) - state_marginals_concat = Matrix{R}(undef, N, T) - obs_seqs_concat = reduce(vcat, obs_seqs) - limits = vcat(0, cumsum(length.(obs_seqs))) - return BaumWelchStorage( - fbs, iteration, logL_evolution, state_marginals_concat, obs_seqs_concat, limits - ) + return nothing end function baum_welch!( + fbs::Vector{<:ForwardBackwardStorage}, + fb_concat::ForwardBackwardStorage, + logL_evolution::Vector, hmm::AbstractHMM, - bw::BaumWelchStorage, - obs_seqs::Vector{<:Vector}; + obs_seqs::Vector{<:Vector}, + obs_seqs_concat::Vector; atol::Real, max_iterations::Integer, - check_loglikelihood_increasing::Bool, + loglikelihood_increasing::Bool, ) - @unpack ( - fbs, iteration, logL_evolution, state_marginals_concat, obs_seqs_concat, limits - ) = bw - iteration[] = 0 - - while iteration[] < max_iterations - # E step + for _ in 1:max_iterations @threads for k in eachindex(obs_seqs, fbs) forward_backward!(fbs[k], hmm, obs_seqs[k]) - @views state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= fbs[k].γ end - - # M step - fit!(hmm, bw) - - # # Record likelihood - iteration[] += 1 - logL_evolution[iteration[]] = sum(fb.logL[] for fb in fbs) - - # # Stopping criterion - if iteration[] > 1 - progress = logL_evolution[iteration[]] - logL_evolution[iteration[] - 1] - if check_loglikelihood_increasing && progress < 0 - error("Loglikelihood decreased in Baum-Welch") - elseif progress < atol - break - end + push!(logL_evolution, sum(fb.logL[] for fb in fbs)) + fit!(hmm, fbs, obs_seqs, fb_concat, obs_seqs_concat) + if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) + 
break end end + return nothing end """ - baum_welch( - hmm_init, obs_seqs, nb_seqs; - atol, max_iterations, check_loglikelihood_increasing - ) + baum_welch(hmm_init, obs_seq; kwargs...) + baum_welch(hmm_init, obs_seqs, nb_seqs; kwargs...) -Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on `nb_seqs` observation sequences. +Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on one or several observation sequences. Return a tuple `(hmm_est, logL_evolution)`. -!!! warning "Multithreading" - This function is parallelized across sequences. - # Keyword arguments -- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) -- `max_iterations`: Maximum number of iterations of the algorithm -- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases +- `atol`: minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) +- `max_iterations`: maximum number of iterations of the algorithm +- `loglikelihood_increasing`: whether to throw an error if the loglikelihood decreases """ function baum_welch( hmm_init::AbstractHMM, - obs_seqs::Vector{<:Vector}, - nb_seqs::Integer; + obs_seq::Vector; atol=1e-5, max_iterations=100, - check_loglikelihood_increasing=true, + loglikelihood_increasing=true, ) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) - end hmm = deepcopy(hmm_init) - bw = initialize_baum_welch(hmm, obs_seqs, nb_seqs; max_iterations) - baum_welch!(hmm, bw, obs_seqs; atol, max_iterations, check_loglikelihood_increasing) - return hmm, bw.logL_evolution[1:bw.iteration[]] -end - -""" - baum_welch( - hmm_init, obs_seq; - atol, max_iterations, check_loglikelihood_increasing + fb = initialize_forward_backward(hmm, obs_seq) + R = eltype(hmm, obs_seq[1]) + logL_evolution = R[] + sizehint!(logL_evolution, max_iterations) + baum_welch!( + fb, logL_evolution, hmm, obs_seq; atol, max_iterations, loglikelihood_increasing ) + return hmm, logL_evolution +end -Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`. - -Return a tuple `(hmm_est, logL_evolution)`. 
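+
+For instance, on a single sequence (the two-state Gaussian model below is purely illustrative):
+
+```julia
+using Distributions
+hmm_init = HMM([0.5, 0.5], [0.8 0.2; 0.3 0.7], [Normal(-1.0), Normal(2.0)])
+obs_seq = rand(hmm_init, 1000).obs_seq
+hmm_est, logL_evolution = baum_welch(hmm_init, obs_seq)
+```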
- -# Keyword arguments - -- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged) -- `max_iterations`: Maximum number of iterations of the algorithm -- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases -""" function baum_welch( hmm_init::AbstractHMM, - obs_seq::Vector; + obs_seqs::Vector{<:Vector}, + nb_seqs::Integer; atol=1e-5, max_iterations=100, - check_loglikelihood_increasing=true, + loglikelihood_increasing=true, ) - return baum_welch( - hmm_init, [obs_seq], 1; atol, max_iterations, check_loglikelihood_increasing + if nb_seqs != length(obs_seqs) + throw(ArgumentError("nb_seqs != length(obs_seqs)")) + end + hmm = deepcopy(hmm_init) + limits = vcat(0, cumsum(length.(obs_seqs))) + obs_seqs_concat = reduce(vcat, obs_seqs) + fb_concat = initialize_forward_backward(hmm, obs_seqs_concat) + fbs = [view(fb_concat, (limits[k] + 1):limits[k + 1]) for k in eachindex(obs_seqs)] + R = eltype(hmm, obs_seqs[1][1]) + logL_evolution = R[] + sizehint!(logL_evolution, max_iterations) + baum_welch!( + fbs, + fb_concat, + logL_evolution, + hmm, + obs_seqs, + obs_seqs_concat; + atol, + max_iterations, + loglikelihood_increasing, ) + return hmm, logL_evolution end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index c5017bb5..eff7e06b 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -7,8 +7,6 @@ This storage is relative to a single sequence. # Fields -These fields are not part of the public API. - $(TYPEDFIELDS) """ struct ForwardStorage{R} @@ -71,13 +69,16 @@ end """ forward(hmm, obs_seq) + forward(hmm, obs_seqs, nb_seqs) -Apply the forward algorithm to an HMM. +Run the forward algorithm to infer the current state of an HMM. -Return a tuple `(α, logL)` where +When applied on a single sequence, this function returns a tuple `(α, logL)` where - `α[i]` is the posterior probability of state `i` at the end of the sequence - `logL` is the loglikelihood of the sequence + +When applied on multiple sequences, this function returns a vector of tuples. """ function forward(hmm::AbstractHMM, obs_seq::Vector) f = initialize_forward(hmm, obs_seq) @@ -85,19 +86,6 @@ function forward(hmm::AbstractHMM, obs_seq::Vector) return f.αₜ, f.logL[] end -""" - forward(hmm, obs_seqs, nb_seqs) - -Apply the forward algorithm to an HMM, based on multiple observation sequences. - -Return a vector of tuples `(αₖ, logLₖ)`, where - -- `αₖ[i]` is the posterior probability of state `i` at the end of sequence `k` -- `logLₖ` is the loglikelihood of sequence `k` - -!!! warning "Multithreading" - This function is parallelized across sequences. -""" function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) @@ -109,26 +97,17 @@ end """ logdensityof(hmm, obs_seq) + logdensityof(hmm, obs_seqs, nb_seqs) -Apply the forward algorithm to compute the loglikelihood of a single observation sequence for an HMM. +Run the forward algorithm to compute the posterior loglikelihood of observations for an HMM. -Return a number. +Whether it is applied on one or multiple sequences, this function returns a number. """ function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector) _, logL = forward(hmm, obs_seq) return logL end -""" - logdensityof(hmm, obs_seqs, nb_seqs) - -Apply the forward algorithm to compute the total loglikelihood of multiple observation sequences for an HMM. - -Return a number. 
- -!!! warning "Multithreading" - This function is parallelized across sequences. -""" function DensityInterface.logdensityof( hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer ) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 371a0ad7..5066828f 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -7,39 +7,53 @@ This storage is relative to a single sequence. # Fields -These fields are not part of the public API. - Let `X` denote the vector of hidden states and `Y` denote the vector of observations. $(TYPEDFIELDS) """ -struct ForwardBackwardStorage{R} +struct ForwardBackwardStorage{ + R,V<:AbstractVector{R},M<:AbstractMatrix{R},A3<:AbstractArray{R,3} +} "total loglikelihood" logL::RefValue{R} "scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" - α::Matrix{R} + α::M "scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" - β::Matrix{R} + β::M "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" - γ::Matrix{R} + γ::M "posterior transition marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" - ξ::Array{R,3} + ξ::A3 "forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])`" - c::Vector{R} + c::V "observation loglikelihoods `logB[i, t]`" - logB::Matrix{R} + logB::M "maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])`" - logm::Vector{R} + logm::V "numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])`" - B̃::Matrix{R} + B̃::M "numerically stabilized product `B̃β[i,t] = B̃[i,t] * β[i,t]`" - B̃β::Matrix{R} + B̃β::M end Base.eltype(::ForwardBackwardStorage{R}) where {R} = R Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1) duration(fb::ForwardBackwardStorage) = size(fb.α, 2) +function Base.view(fb::ForwardBackwardStorage{R}, r::AbstractUnitRange) where {R} + logL = Ref(zero(R)) + α = view(fb.α, :, r) + β = view(fb.β, :, r) + γ = view(fb.γ, :, r) + ξ = view(fb.ξ, :, :, r) + c = view(fb.c, r) + logB = view(fb.logB, :, r) + logm = view(fb.logm, r) + B̃ = view(fb.B̃, :, r) + B̃β = view(fb.B̃β, :, r) + return ForwardBackwardStorage(logL, α, β, γ, ξ, c, logB, logm, B̃, B̃β) +end + function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector) N, T = length(hmm), length(obs_seq) R = eltype(hmm, obs_seq[1]) @@ -48,7 +62,7 @@ function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector) α = Matrix{R}(undef, N, T) β = Matrix{R}(undef, N, T) γ = Matrix{R}(undef, N, T) - ξ = Array{R,3}(undef, N, N, T - 1) + ξ = Array{R,3}(undef, N, N, T) c = Vector{R}(undef, T) logB = Matrix{R}(undef, N, T) logm = Vector{R}(undef, T) @@ -118,6 +132,7 @@ function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM) @views for t in 1:(T - 1) ξ[:, :, t] .= α[:, t] .* A .* B̃β[:, t + 1]' end + ξ[:, :, T] .= zero(eltype(ξ)) check_no_nan(ξ) return nothing end @@ -141,14 +156,17 @@ end """ forward_backward(hmm, obs_seq) + forward_backward(hmm, obs_seqs, nb_seqs) -Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM. +Run the forward-backward algorithm to infer the posterior state and transition marginals of an HMM. 
-Return a tuple `(γ, ξ, logL)` where +When applied on a single sequence, this function returns a tuple `(γ, ξ, logL)` where - `γ` is a matrix containing the posterior state marginals `γ[i, t]` - `ξ` is a 3-tensor containing the posterior transition marginals `ξ[i, j, t]` - `logL` is the loglikelihood of the sequence + +WHen applied on multiple sequences, it returns a vector of tuples. """ function forward_backward(hmm::AbstractHMM, obs_seq::Vector) fb = initialize_forward_backward(hmm, obs_seq) @@ -156,20 +174,6 @@ function forward_backward(hmm::AbstractHMM, obs_seq::Vector) return (fb.γ, fb.ξ, fb.logL[]) end -""" - forward_backward(hmm, obs_seqs, nb_seqs) - -Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM, based on multiple observation sequences. - -Return a vector of tuples `(γₖ, ξₖ, logLₖ)` where - -- `γₖ` is a matrix containing the posterior state marginals `γₖ[i, t]` for sequence `k` -- `ξₖ` is a 3-tensor containing the posterior transition marginals `ξ[i, j, t]` for sequence `k` -- `logLₖ` is the loglikelihood of sequence `k` - -!!! warning "Multithreading" - This function is parallelized across sequences. -""" function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index b4ca33bc..ec4fde91 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -7,12 +7,10 @@ This storage is relative to a single sequence. # Fields -These fields are not part of the public API. - $(TYPEDFIELDS) """ struct ViterbiStorage{R} - "vector of loglikelihood values for each state" + "vector of observation loglikelihoods `logb[i]`" logb::Vector{R} δₜ::Vector{R} δₜ₋₁::Vector{R} @@ -76,10 +74,12 @@ end """ viterbi(hmm, obs_seq) + viterbi(hmm, obs_seqs, nb_seqs) -Apply the Viterbi algorithm to compute the most likely state sequence of an HMM. +Apply the Viterbi algorithm to infer the most likely state sequence of an HMM. -Return a vector of integers. +When applied on a single sequence, this function returns a vector of integers. +When applied on multiple sequences, it returns a vector of vectors of integers. """ function viterbi(hmm::AbstractHMM, obs_seq::Vector) v = initialize_viterbi(hmm, obs_seq) @@ -87,16 +87,6 @@ function viterbi(hmm::AbstractHMM, obs_seq::Vector) return v.q end -""" - viterbi(hmm, obs_seqs, nb_seqs) - -Apply the Viterbi algorithm to compute the most likely state sequences of an HMM, based on multiple observation sequences. - -Return a vector of vectors of integers. - -!!! warning "Multithreading" - This function is parallelized across sequences. -""" function viterbi(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) if nb_seqs != length(obs_seqs) throw(ArgumentError("nb_seqs != length(obs_seqs)")) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index a5f240bd..133ecb90 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -5,20 +5,20 @@ Abstract supertype for an HMM amenable to simulation, inference and learning. 
# Interface -- `length` -- `initialization` -- `transition_matrix` -- `obs_distributions` -- `fit` +- `length(hmm)` +- `initialization(hmm)` +- `transition_matrix(hmm)` +- `obs_distributions(hmm)` +- `fit!(hmm, fbs, obs_seqs, fb_concat, obs_seqs_concat)` # Applicable methods - `rand([rng,] hmm, T)` -- `logdensityof(hmm, obs_seq)` / `logdensityof(hmm, obs_seqs, nb_seqs)` -- `forward(hmm, obs_seq)` / `forward(hmm, obs_seqs, nb_seqs)` -- `viterbi(hmm, obs_seq)` / `viterbi(hmm, obs_seqs, nb_seqs)` -- `forward_backward(hmm, obs_seq)` / `forward_backward(hmm, obs_seqs, nb_seqs)` -- `baum_welch(hmm, obs_seq)` / `baum_welch(hmm, obs_seqs, nb_seqs)` if `fit!` is implemented +- `logdensityof(hmm, obs_seq)` +- `forward(hmm, obs_seq)` +- `viterbi(hmm, obs_seq)` +- `forward_backward(hmm, obs_seq)` +- `baum_welch(hmm, obs_seq)` if `fit!` is implemented """ abstract type AbstractHiddenMarkovModel end @@ -79,9 +79,9 @@ Each element `dist` of this vector must implement function obs_distributions end """ - fit!(hmm, bw::BaumWelchStorage, obs_seqs) + fit!(hmm, obs_seqs, fbs) -Update `hmm` in-place based on information generated during forward-backward. +Update `hmm` in-place based on several observation sequences `obs_seqs` as well as information `fbs` generated during forward-backward. """ StatsAPI.fit! diff --git a/src/types/hmm.jl b/src/types/hmm.jl index a59d1fa5..a2c05f6d 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -41,8 +41,9 @@ initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists -function StatsAPI.fit!(hmm::HMM, bw::BaumWelchStorage) - @unpack fbs, state_marginals_concat, obs_seqs_concat = bw +function StatsAPI.fit!( + hmm::HMM, fbs, obs_seqs, fb_concat::ForwardBackwardStorage, obs_seqs_concat::Vector +) # Initialization hmm.init .= zero(eltype(hmm.init)) for k in eachindex(fbs) @@ -51,15 +52,11 @@ function StatsAPI.fit!(hmm::HMM, bw::BaumWelchStorage) sum_to_one!(hmm.init) # Transition matrix hmm.trans .= zero(eltype(hmm.trans)) - for k in eachindex(fbs) - sum!(hmm.trans, fbs[k].ξ; init=false) - end + sum!(hmm.trans, fb_concat.ξ; init=false) foreach(sum_to_one!, eachrow(hmm.trans)) # Observation distributions for i in eachindex(hmm.dists) - fit_element_from_sequence!( - hmm.dists, i, obs_seqs_concat, view(state_marginals_concat, i, :) - ) + fit_element_from_sequence!(hmm.dists, i, obs_seqs_concat, view(fb_concat.γ, i, :)) end check_hmm(hmm) return nothing diff --git a/test/allocations.jl b/test/allocations.jl index 63566aa5..d6fa829c 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -11,32 +11,34 @@ function test_allocations(hmm; T) ## Forward f = HMMs.initialize_forward(hmm, obs_seq) - allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) + allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) samples = 2 @test allocs == 0 ## Viterbi v = HMMs.initialize_viterbi(hmm, obs_seq) - allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) + allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) samples = 2 @test allocs == 0 ## Forward-backward fb = HMMs.initialize_forward_backward(hmm, obs_seq) - allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) + allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) samples = 2 @test allocs == 0 ## Baum-Welch - nb_seqs = 2 - obs_seqs = [obs_seq for _ in 1:nb_seqs] - bw = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2) + fb = HMMs.initialize_forward_backward(hmm, obs_seq) + R = eltype(hmm, obs_seq[1]) + 
logL_evolution = R[] + sizehint!(logL_evolution, 2) allocs = @ballocated HMMs.baum_welch!( + $fb, + $logL_evolution, $hmm, - $bw, - $obs_seqs; + $obs_seq; atol=-Inf, max_iterations=2, - check_loglikelihood_increasing=false, - ) - @test_broken allocs == 0 # @threads introduces type instability, see https://discourse.julialang.org/t/type-instability-because-of-threads-boxing-variables/78395/ + loglikelihood_increasing=false, + ) samples = 2 + @test allocs == 0 end N = 5 @@ -48,6 +50,5 @@ A = rand_trans_mat(N) dists = [LightDiagNormal(randn(2), ones(2)) for i in 1:N] hmm = HMM(p, A, dists) -obs_seq = rand(hmm, T).obs_seq test_allocations(hmm; T) From 2c582ceae188e8878abdccb7f6819cf469f17424 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:18:05 +0100 Subject: [PATCH 09/14] Finish revamp except custom tuto --- Project.toml | 1 + docs/make.jl | 6 +- docs/src/api.md | 26 ++---- docs/src/builtin.md | 4 +- docs/src/custom.md | 13 +++ src/HMMTest.jl | 49 +++++++++++ src/HiddenMarkovModels.jl | 4 + src/inference/baum_welch.jl | 141 +++++++++++++++++++----------- src/inference/forward.jl | 69 ++++++++------- src/inference/forward_backward.jl | 101 +++++++++------------ src/inference/viterbi.jl | 62 +++++++------ src/types/abstract_hmm.jl | 52 +++++++---- src/types/hmm.jl | 30 +++---- src/utils/check.jl | 6 ++ src/utils/fit.jl | 8 +- src/utils/lightdiagnormal.jl | 4 + src/utils/mul.jl | 20 +++++ src/utils/probvec.jl | 3 +- src/utils/transmat.jl | 3 +- test/allocations.jl | 53 ++++++----- test/arrays.jl | 51 +++++++++++ test/autodiff.jl | 16 ++-- test/correctness.jl | 89 +++++++++---------- test/dna.jl | 21 ++--- test/logarithmic.jl | 32 ------- test/misc.jl | 29 ++++++ test/numbers.jl | 34 +++++++ test/permuted.jl | 25 ------ test/runtests.jl | 24 +++-- test/sparse.jl | 25 ------ test/static.jl | 22 ----- test/type_stability.jl | 99 +++++++-------------- 32 files changed, 614 insertions(+), 508 deletions(-) create mode 100644 docs/src/custom.md create mode 100644 src/HMMTest.jl create mode 100644 src/utils/mul.jl create mode 100644 test/arrays.jl delete mode 100644 test/logarithmic.jl create mode 100644 test/misc.jl create mode 100644 test/numbers.jl delete mode 100644 test/permuted.jl delete mode 100644 test/sparse.jl delete mode 100644 test/static.jl diff --git a/Project.toml b/Project.toml index 4c752a6e..690b4dc2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" [weakdeps] diff --git a/docs/make.jl b/docs/make.jl index 959e5c1d..85085118 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,7 +30,11 @@ benchmarks_done = ( pages = [ "Home" => "index.md", "Essentials" => ["Background" => "background.md", "API reference" => "api.md"], - "Tutorials" => ["Built-in HMM" => "builtin.md", "Debugging" => "debugging.md"], + "Tutorials" => [ + "Built-in HMM" => "builtin.md", + "Custom HMM" => "custom.md", + "Debugging" => "debugging.md", + ], "Alternatives" => if benchmarks_done ["Features" => "features.md", "Benchmarks" => "benchmarks.md"] else diff --git a/docs/src/api.md b/docs/src/api.md index 6c2d8aab..c027e18c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,13 +31,8 @@ logdensityof forward viterbi 
forward_backward -``` - -## Learning - -```@docs -fit! baum_welch +fit! ``` ## Misc @@ -48,17 +43,13 @@ rand_prob_vec rand_trans_mat ``` -## Storage +## Internals ```@docs -HiddenMarkovModels.ForwardBackwardStorage HiddenMarkovModels.ForwardStorage HiddenMarkovModels.ViterbiStorage -``` - -## Internals - -```@docs +HiddenMarkovModels.ForwardBackwardStorage +HiddenMarkovModels.BaumWelchStorage HiddenMarkovModels.fit_element_from_sequence! HiddenMarkovModels.LightDiagNormal HiddenMarkovModels.PermutedHMM @@ -77,10 +68,11 @@ HiddenMarkovModels.PermutedHMM - `p` or `init`: initialization (vector of state probabilities) - `A` or `trans`: transition_matrix (matrix of transition probabilities) -- `dists`: observation distribution (vector of `rand`-able and `logdensityof`-able objects) +- `d` or `dists`: observation distribution (vector of `rand`-able and `logdensityof`-able objects) - `state_seq`: a sequence of states (vector of integers) - `obs_seq`: a sequence of observations (vector of individual observations) - `obs_seqs`: several sequences of observations +- `nb_seqs`: number of observation sequences ### Forward backward @@ -88,9 +80,9 @@ HiddenMarkovModels.PermutedHMM - `(log)B`: matrix of observation (log)likelihoods by state for a sequence of observations - `α`: scaled forward variables - `β`: scaled backward variables -- `γ`: one-state marginals -- `ξ`: two-state marginals -- `logL`: loglikelihood of a sequence of observations +- `γ`: state marginals +- `ξ`: transition marginals +- `logL`: posterior loglikelihood of a sequence of observations ## Index diff --git a/docs/src/builtin.md b/docs/src/builtin.md index 4ccf748d..afa429f1 100644 --- a/docs/src/builtin.md +++ b/docs/src/builtin.md @@ -15,8 +15,8 @@ Creating a model: function gaussian_hmm(N; noise=0) p = ones(N) / N # initial distribution A = rand_trans_mat(N) # transition matrix - dists = [Normal(i + noise * randn(), 0.5) for i in 1:N] # observation distributions - return HMM(p, A, dists) + d = [Normal(i + noise * randn(), 0.5) for i in 1:N] # observation distributions + return HMM(p, A, d) end ``` diff --git a/docs/src/custom.md b/docs/src/custom.md new file mode 100644 index 00000000..c21154f1 --- /dev/null +++ b/docs/src/custom.md @@ -0,0 +1,13 @@ +# Tutorial - custom HMM + +```@example tuto +using Distributions +using HiddenMarkovModels + +using Random; Random.seed!(63) +``` + +Here we demonstrate how to build your own HMM structure satisfying the interface. + +!!! danger + Work in progress. \ No newline at end of file diff --git a/src/HMMTest.jl b/src/HMMTest.jl new file mode 100644 index 00000000..d218f1a5 --- /dev/null +++ b/src/HMMTest.jl @@ -0,0 +1,49 @@ +module HMMTest + +using Distributions +using Distributions: PDiagMat +using HiddenMarkovModels +using HiddenMarkovModels: LightDiagNormal, sum_to_one! +using LinearAlgebra +using SparseArrays + +export rand_categorical_hmm +export rand_gaussian_hmm_1d +export rand_gaussian_hmm_2d +export rand_gaussian_hmm_2d_light + +function sparse_trans_mat(N) + A = sparse(SymTridiagonal(ones(N), ones(N - 1))) + foreach(sum_to_one!, eachrow(A)) + return A +end + +function rand_categorical_hmm(N, D; sparse_trans=false) + p = ones(N) / N + A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N) + d = [Categorical(rand_prob_vec(D)) for i in 1:N] + return HMM(p, A, d) +end + +function rand_gaussian_hmm_1d(N; sparse_trans=false) + p = ones(N) / N + A = sparse_trans ? 
sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [Normal(randn(), 1) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+function rand_gaussian_hmm_2d(N, D; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+function rand_gaussian_hmm_2d_light(N, D; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+end
diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl
index 7d6b7ee2..b39e850e 100644
--- a/src/HiddenMarkovModels.jl
+++ b/src/HiddenMarkovModels.jl
@@ -26,6 +26,7 @@ using PrecompileTools: @compile_workload, @setup_workload
 using Random: Random, AbstractRNG, default_rng
 using Requires: @require
 using SimpleUnPack: @unpack
+using SparseArrays: SparseMatrixCSC, nzrange, nnz
 using StatsAPI: StatsAPI, fit, fit!
 
 export AbstractHiddenMarkovModel, AbstractHMM
@@ -44,6 +45,7 @@ include("utils/probvec.jl")
 include("utils/transmat.jl")
 include("utils/fit.jl")
 include("utils/lightdiagnormal.jl")
+include("utils/mul.jl")
 
 include("inference/forward.jl")
 include("inference/viterbi.jl")
@@ -52,6 +54,8 @@ include("inference/baum_welch.jl")
 
 include("types/hmm.jl")
 
+include("HMMTest.jl")
+
 if !isdefined(Base, :get_extension)
     function __init__()
         @require HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" include(
diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl
index d0cd2131..c4ef69e8 100644
--- a/src/inference/baum_welch.jl
+++ b/src/inference/baum_welch.jl
@@ -1,3 +1,69 @@
+"""
+$(TYPEDEF)
+
+Store Baum-Welch quantities with element type `R`.
+
+This storage is relative to several sequences.
+
+# Fields
+
+These fields (except `limits`) are passed to the `fit!(hmm, ...)` method along with `obs_seqs_concat`. 
+ +$(TYPEDFIELDS) +""" +struct BaumWelchStorage{R,M<:AbstractMatrix{R}} + "posterior initialization counts for each state" + init_count::Vector{R} + "posterior transition counts for each state" + trans_count::M + "concatenation along time of the state marginal matrices `γ[i,t] = ℙ(X[t]=i | Y[1:T])` for all observation sequences" + state_marginals_concat::Matrix{R} + "temporal separations between observation sequences: `state_marginals_concat[limits[k]+1:limits[k+1]]` refers to sequence `k`" + limits::Vector{Int} +end + +function initialize_baum_welch( + hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer +) + check_lengths(obs_seqs, nb_seqs) + N, T_concat = length(hmm), sum(length, obs_seqs) + A = transition_matrix(hmm) + R = eltype(hmm, obs_seqs[1][1]) + init_count = Vector{R}(undef, N) + trans_count = similar(A, R) + state_marginals_concat = Matrix{R}(undef, N, T_concat) + limits = vcat(0, cumsum(length.(obs_seqs))) + return BaumWelchStorage(init_count, trans_count, state_marginals_concat, limits) +end + +function initialize_logL_evolution( + hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer +) + check_lengths(obs_seqs, nb_seqs) + R = eltype(hmm, obs_seqs[1][1]) + logL_evolution = R[] + sizehint!(logL_evolution, max_iterations) + return logL_evolution +end + +function update_sufficient_statistics!( + bw::BaumWelchStorage{R}, fbs::Vector{<:ForwardBackwardStorage} +) where {R} + @unpack init_count, trans_count, state_marginals_concat, limits = bw + init_count .= zero(R) + trans_count .= zero(R) + state_marginals_concat .= zero(R) + for k in eachindex(fbs) + @unpack γ, ξ, B̃β = fbs[k] + init_count .+= view(γ, :, 1) + for t in eachindex(ξ) + trans_count .+= ξ[t] + end + state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= γ + end + return nothing +end + function baum_welch_has_converged( logL_evolution::Vector; atol::Real, loglikelihood_increasing::Bool ) @@ -13,29 +79,15 @@ function baum_welch_has_converged( return false end -function baum_welch!( - fb::ForwardBackwardStorage, - logL_evolution::Vector, - hmm::AbstractHMM, - obs_seq::Vector; - atol::Real, - max_iterations::Integer, - loglikelihood_increasing::Bool, -) - for _ in 1:max_iterations - forward_backward!(fb, hmm, obs_seq) - push!(logL_evolution, fb.logL[]) - fit!(hmm, (fb,), (obs_seq,), fb, obs_seq) - if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) - break - end - end - return nothing +function StatsAPI.fit!(hmm::AbstractHMM, bw::BaumWelchStorage, obs_seqs_concat::Vector) + return fit!( + hmm, bw.init_count, bw.trans_count, obs_seqs_concat, bw.state_marginals_concat + ) end function baum_welch!( fbs::Vector{<:ForwardBackwardStorage}, - fb_concat::ForwardBackwardStorage, + bw::BaumWelchStorage, logL_evolution::Vector, hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, @@ -48,8 +100,10 @@ function baum_welch!( @threads for k in eachindex(obs_seqs, fbs) forward_backward!(fbs[k], hmm, obs_seqs[k]) end + update_sufficient_statistics!(bw, fbs) push!(logL_evolution, sum(fb.logL[] for fb in fbs)) - fit!(hmm, fbs, obs_seqs, fb_concat, obs_seqs_concat) + fit!(hmm, bw, obs_seqs_concat) + check_hmm(hmm) if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) break end @@ -61,7 +115,7 @@ end baum_welch(hmm_init, obs_seq; kwargs...) baum_welch(hmm_init, obs_seqs, nb_seqs; kwargs...) -Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on one or several observation sequences. 
+Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`. Return a tuple `(hmm_est, logL_evolution)`. @@ -71,24 +125,6 @@ Return a tuple `(hmm_est, logL_evolution)`. - `max_iterations`: maximum number of iterations of the algorithm - `loglikelihood_increasing`: whether to throw an error if the loglikelihood decreases """ -function baum_welch( - hmm_init::AbstractHMM, - obs_seq::Vector; - atol=1e-5, - max_iterations=100, - loglikelihood_increasing=true, -) - hmm = deepcopy(hmm_init) - fb = initialize_forward_backward(hmm, obs_seq) - R = eltype(hmm, obs_seq[1]) - logL_evolution = R[] - sizehint!(logL_evolution, max_iterations) - baum_welch!( - fb, logL_evolution, hmm, obs_seq; atol, max_iterations, loglikelihood_increasing - ) - return hmm, logL_evolution -end - function baum_welch( hmm_init::AbstractHMM, obs_seqs::Vector{<:Vector}, @@ -97,20 +133,15 @@ function baum_welch( max_iterations=100, loglikelihood_increasing=true, ) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) - end + check_lengths(obs_seqs, nb_seqs) hmm = deepcopy(hmm_init) - limits = vcat(0, cumsum(length.(obs_seqs))) + fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + bw = initialize_baum_welch(hmm, obs_seqs, nb_seqs) + logL_evolution = initialize_logL_evolution(hmm, obs_seqs, nb_seqs; max_iterations) obs_seqs_concat = reduce(vcat, obs_seqs) - fb_concat = initialize_forward_backward(hmm, obs_seqs_concat) - fbs = [view(fb_concat, (limits[k] + 1):limits[k + 1]) for k in eachindex(obs_seqs)] - R = eltype(hmm, obs_seqs[1][1]) - logL_evolution = R[] - sizehint!(logL_evolution, max_iterations) baum_welch!( fbs, - fb_concat, + bw, logL_evolution, hmm, obs_seqs, @@ -121,3 +152,15 @@ function baum_welch( ) return hmm, logL_evolution end + +function baum_welch( + hmm_init::AbstractHMM, + obs_seq::Vector; + atol=1e-5, + max_iterations=100, + loglikelihood_increasing=true, +) + return baum_welch( + hmm_init, [obs_seq], 1; atol, max_iterations, loglikelihood_increasing + ) +end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index eff7e06b..02f1a3d3 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -7,17 +7,19 @@ This storage is relative to a single sequence. # Fields +The only fields useful outside of the algorithm are `αₜ` and `logL`. 
$(TYPEDFIELDS)
 """
 struct ForwardStorage{R}
     "total loglikelihood"
     logL::RefValue{R}
-    "vector of observation loglikelihoods `logb[i]`"
+    "observation loglikelihoods `logbₜ[i] = ℙ(Y[t] | X[t]=i)`"
     logb::Vector{R}
-    "scaled forward variables `α[t]`"
-    αₜ::Vector{R}
-    "scaled forward variables `α[t+1]`"
-    αₜ₊₁::Vector{R}
+    "scaled forward messages for a given time step"
+    α::Vector{R}
+    "same as `α` but for the next time step"
+    α_next::Vector{R}
 end
 
 function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
@@ -26,9 +28,9 @@ function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
 
     logL = RefValue{R}(zero(R))
     logb = Vector{R}(undef, N)
-    αₜ = Vector{R}(undef, N)
-    αₜ₊₁ = Vector{R}(undef, N)
-    f = ForwardStorage(logL, logb, αₜ, αₜ₊₁)
+    α = Vector{R}(undef, N)
+    α_next = Vector{R}(undef, N)
+    f = ForwardStorage(logL, logb, α, α_next)
     return f
 end
 
@@ -37,30 +39,34 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
     p = initialization(hmm)
     A = transition_matrix(hmm)
     d = obs_distributions(hmm)
-    @unpack logL, logb, αₜ, αₜ₊₁ = f
+    @unpack logL, logb, α, α_next = f
     logb .= logdensityof.(d, (obs_seq[1],))
     logm = maximum(logb)
-    αₜ .= p .* exp.(logb .- logm)
-    c = inv(sum(αₜ))
-    αₜ .*= c
+    α .= p .* exp.(logb .- logm)
+    c = inv(sum(α))
+    α .*= c
     logL[] = -log(c) + logm
     for t in 1:(T - 1)
         logb .= logdensityof.(d, (obs_seq[t + 1],))
         logm = maximum(logb)
-        mul!(αₜ₊₁, A', αₜ)
-        αₜ₊₁ .*= exp.(logb .- logm)
-        c = inv(sum(αₜ₊₁))
-        αₜ₊₁ .*= c
-        αₜ .= αₜ₊₁
+        mul!(α_next, A', α)
+        α_next .*= exp.(logb .- logm)
+        c = inv(sum(α_next))
+        α_next .*= c
+        α .= α_next
         logL[] += -log(c) + logm
     end
     return nothing
 end
 
 function forward!(
-    fs::Vector{<:ForwardStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector}
+    fs::Vector{<:ForwardStorage},
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer,
 )
+    check_lengths(obs_seqs, nb_seqs)
     @threads for k in eachindex(fs, obs_seqs)
         forward!(fs[k], hmm, obs_seqs[k])
     end
 
 """
     forward(hmm, obs_seq)
     forward(hmm, obs_seqs, nb_seqs)
 
 Run the forward algorithm to infer the current state of an HMM.
 
 When applied on a single sequence, this function returns a tuple `(α, logL)` where
 
 - `α[i]` is the posterior probability of state `i` at the end of the sequence
 - `logL` is the loglikelihood of the sequence
 
 When applied on multiple sequences, this function returns a vector of tuples.
 """
-function forward(hmm::AbstractHMM, obs_seq::Vector)
-    f = initialize_forward(hmm, obs_seq)
-    forward!(f, hmm, obs_seq)
-    return f.αₜ, f.logL[]
-end
-
 function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
+    check_lengths(obs_seqs, nb_seqs)
     fs = [initialize_forward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
-    forward!(fs, hmm, obs_seqs)
-    return [(f.αₜ, f.logL[]) for f in fs]
+    forward!(fs, hmm, obs_seqs, nb_seqs)
+    return [(f.α, f.logL[]) for f in fs]
+end
+
+function forward(hmm::AbstractHMM, obs_seq::Vector)
+    return only(forward(hmm, [obs_seq], 1))
 end
 
 """
     logdensityof(hmm, obs_seq)
     logdensityof(hmm, obs_seqs, nb_seqs)
 
 Run the forward algorithm to compute the posterior loglikelihood of observations for an HMM.
 
 Whether it is applied on one or multiple sequences, this function returns a number. 
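+
+For instance, a quick sketch of both call forms (`hmm` is any `AbstractHMM`, the sequences are illustrative):
+
+```julia
+logL = logdensityof(hmm, obs_seq)                      # one sequence
+logL_tot = logdensityof(hmm, [obs_seq, obs_seq], 2)    # two sequences, loglikelihoods summed
+```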
""" -function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector) - _, logL = forward(hmm, obs_seq) - return logL -end - function DensityInterface.logdensityof( hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer ) logαs_and_logLs = forward(hmm, obs_seqs, nb_seqs) return sum(last, logαs_and_logLs) end + +function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector) + return logdensityof(hmm, [obs_seq], 1) +end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 5066828f..3aefcdbf 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -7,78 +7,68 @@ This storage is relative to a single sequence. # Fields -Let `X` denote the vector of hidden states and `Y` denote the vector of observations. +The only fields useful outside of the algorithm are `γ`, `ξ` and `logL`. $(TYPEDFIELDS) """ -struct ForwardBackwardStorage{ - R,V<:AbstractVector{R},M<:AbstractMatrix{R},A3<:AbstractArray{R,3} -} +struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} "total loglikelihood" logL::RefValue{R} - "scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" - α::M - "scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" - β::M + "scaled forward messsages `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)" + α::Matrix{R} + "scaled backward messsages `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)" + β::Matrix{R} "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" - γ::M - "posterior transition marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" - ξ::A3 - "forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])`" - c::V - "observation loglikelihoods `logB[i, t]`" - logB::M + γ::Matrix{R} + "posterior transition marginals `ξ[t][i,j] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`" + ξ::Vector{M} + "forward message inverse normalizations `c[t] = 1 / sum(α[:,t])`" + c::Vector{R} + "observation loglikelihoods `logB[i,t] = ℙ(Y[t] | X[t]=i)`" + logB::Matrix{R} "maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])`" - logm::V + logm::Vector{R} "numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])`" - B̃::M - "numerically stabilized product `B̃β[i,t] = B̃[i,t] * β[i,t]`" - B̃β::M + B̃::Matrix{R} + "product `B̃β[i,t] = B̃[i,t] * β[i,t]`" + B̃β::Matrix{R} end Base.eltype(::ForwardBackwardStorage{R}) where {R} = R -Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1) duration(fb::ForwardBackwardStorage) = size(fb.α, 2) -function Base.view(fb::ForwardBackwardStorage{R}, r::AbstractUnitRange) where {R} - logL = Ref(zero(R)) - α = view(fb.α, :, r) - β = view(fb.β, :, r) - γ = view(fb.γ, :, r) - ξ = view(fb.ξ, :, :, r) - c = view(fb.c, r) - logB = view(fb.logB, :, r) - logm = view(fb.logm, r) - B̃ = view(fb.B̃, :, r) - B̃β = view(fb.B̃β, :, r) - return ForwardBackwardStorage(logL, α, β, γ, ξ, c, logB, logm, B̃, B̃β) -end - function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector) N, T = length(hmm), length(obs_seq) + A = transition_matrix(hmm) R = eltype(hmm, obs_seq[1]) + M = typeof(similar(A, R)) logL = RefValue{R}(zero(R)) α = Matrix{R}(undef, N, T) β = Matrix{R}(undef, N, T) γ = Matrix{R}(undef, N, T) - ξ = Array{R,3}(undef, N, N, T) + ξ = Vector{M}(undef, T - 1) + for t in 1:(T - 1) + ξ[t] = similar(A, R) + end c = Vector{R}(undef, T) logB = Matrix{R}(undef, N, T) logm = Vector{R}(undef, T) B̃ = Matrix{R}(undef, N, T) B̃β = 
@@ -146,8 +132,12 @@ function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq
 end
 
 function forward_backward!(
-    fbs::Vector{<:ForwardBackwardStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector}
+    fbs::Vector{<:ForwardBackwardStorage},
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer,
 )
+    check_lengths(obs_seqs, nb_seqs)
     @threads for k in eachindex(fbs, obs_seqs)
         forward_backward!(fbs[k], hmm, obs_seqs[k])
     end
@@ -163,22 +153,17 @@ Run the forward-backward algorithm to infer the posterior state and transition m
 
-When applied on a single sequence, this function returns a tuple `(γ, ξ, logL)` where
+When applied on a single sequence, this function returns a tuple `(γ, logL)` where
 
 - `γ` is a matrix containing the posterior state marginals `γ[i, t]`
-- `ξ` is a 3-tensor containing the posterior transition marginals `ξ[i, j, t]`
 - `logL` is the loglikelihood of the sequence
 
-WHen applied on multiple sequences, it returns a vector of tuples.
+When applied on multiple sequences, it returns a vector of tuples.
 """
-function forward_backward(hmm::AbstractHMM, obs_seq::Vector)
-    fb = initialize_forward_backward(hmm, obs_seq)
-    forward_backward!(fb, hmm, obs_seq)
-    return (fb.γ, fb.ξ, fb.logL[])
-end
-
 function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
+    check_lengths(obs_seqs, nb_seqs)
     fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
-    forward_backward!(fbs, hmm, obs_seqs)
-    return [(fb.γ, fb.ξ, fb.logL[]) for fb in fbs]
+    forward_backward!(fbs, hmm, obs_seqs, nb_seqs)
+    return [(fb.γ, fb.logL[]) for fb in fbs]
+end
+
+function forward_backward(hmm::AbstractHMM, obs_seq::Vector)
+    return only(forward_backward(hmm, [obs_seq], 1))
+end
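A quick sanity check of the new return convention (the two-state Gaussian model is hypothetical; note that `ξ` is no longer part of the public output, only of the internal storage):

```julia
using Distributions, HiddenMarkovModels

hmm = HMM([0.5, 0.5], [0.9 0.1; 0.1 0.9], [Normal(-1.0), Normal(1.0)])
obs_seq = rand(hmm, 10).obs_seq

γ, logL = forward_backward(hmm, obs_seq)  # single sequence: a 2-tuple
all(col -> sum(col) ≈ 1, eachcol(γ))      # each column of γ is a distribution over states
```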
""" -function forward_backward(hmm::AbstractHMM, obs_seq::Vector) - fb = initialize_forward_backward(hmm, obs_seq) - forward_backward!(fb, hmm, obs_seq) - return (fb.γ, fb.ξ, fb.logL[]) -end - function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer) - if nb_seqs != length(obs_seqs) - throw(ArgumentError("nb_seqs != length(obs_seqs)")) - end + check_lengths(obs_seqs, nb_seqs) fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] - forward_backward!(fbs, hmm, obs_seqs) - return [(fb.γ, fb.ξ, fb.logL[]) for fb in fbs] + forward_backward!(fbs, hmm, obs_seqs, nb_seqs) + return [(fb.γ, fb.logL[]) for fb in fbs] +end + +function forward_backward(hmm::AbstractHMM, obs_seq::Vector) + return only(forward_backward(hmm, [obs_seq], 1)) end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index ec4fde91..6a65f92c 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -7,16 +7,22 @@ This storage is relative to a single sequence. # Fields +The only field useful outside of the algorithm is `q`. + $(TYPEDFIELDS) """ struct ViterbiStorage{R} - "vector of observation loglikelihoods `logb[i]`" + "observation loglikelihoods at a given time step" logb::Vector{R} - δₜ::Vector{R} - δₜ₋₁::Vector{R} - δₜ₋₁Aⱼ::Vector{R} + "highest path scores when accounting for the first `t` observations and ending at a given state" + δ::Vector{R} + "same as `δ` but for the previous time step" + δ_prev::Vector{R} + "temporary variable used to store products `δ_prev .* A[:, j]`" + δA::Vector{R} + "penultimate state maximizing the path score" ψ::Matrix{Int} - "vector of most likely state at each time" + "most likely state at each time `q[t] = argmaxᵢ ℙ(X[t]=i | Y[1:T])`" q::Vector{Int} end @@ -25,12 +31,12 @@ function initialize_viterbi(hmm::AbstractHMM, obs_seq::Vector) R = eltype(hmm, obs_seq[1]) logb = Vector{R}(undef, N) - δₜ = Vector{R}(undef, N) - δₜ₋₁ = Vector{R}(undef, N) - δₜ₋₁Aⱼ = Vector{R}(undef, N) + δ = Vector{R}(undef, N) + δ_prev = Vector{R}(undef, N) + δA = Vector{R}(undef, N) ψ = Matrix{Int}(undef, N, T) q = Vector{Int}(undef, T) - return ViterbiStorage(logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q) + return ViterbiStorage(logb, δ, δ_prev, δA, ψ, q) end function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector) @@ -38,25 +44,25 @@ function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector) p = initialization(hmm) A = transition_matrix(hmm) d = obs_distributions(hmm) - @unpack logb, δₜ, δₜ₋₁, δₜ₋₁Aⱼ, ψ, q = v + @unpack logb, δ, δ_prev, δA, ψ, q = v logb .= logdensityof.(d, (obs_seq[1],)) logm = maximum(logb) - δₜ .= p .* exp.(logb .- logm) - δₜ₋₁ .= δₜ + δ .= p .* exp.(logb .- logm) + δ_prev .= δ @views ψ[:, 1] .= zero(eltype(ψ)) for t in 2:T logb .= logdensityof.(d, (obs_seq[t],)) logm = maximum(logb) for j in 1:N - @views δₜ₋₁Aⱼ .= δₜ₋₁ .* A[:, j] - i_max = argmax(δₜ₋₁Aⱼ) + @views δA .= δ_prev .* A[:, j] + i_max = argmax(δA) ψ[j, t] = i_max - δₜ[j] = δₜ₋₁Aⱼ[i_max] * exp(logb[j] - logm) + δ[j] = δA[i_max] * exp(logb[j] - logm) end - δₜ₋₁ .= δₜ + δ_prev .= δ end - q[T] = argmax(δₜ) + q[T] = argmax(δ) for t in (T - 1):-1:1 q[t] = ψ[q[t + 1], t + 1] end @@ -64,8 +70,12 @@ function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector) end function viterbi!( - vs::Vector{<:ViterbiStorage}, hmm::AbstractHMM, obs_seqs::Vector{<:Vector} + vs::Vector{<:ViterbiStorage}, + hmm::AbstractHMM, + obs_seqs::Vector{<:Vector}, + nb_seqs::Integer, ) + check_lengths(obs_seqs, nb_seqs) @threads for k in eachindex(vs, obs_seqs) 
diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl
index 133ecb90..658dabbb 100644
--- a/src/types/abstract_hmm.jl
+++ b/src/types/abstract_hmm.jl
@@ -5,20 +5,25 @@ Abstract supertype for an HMM amenable to simulation, inference and learning.
 
 # Interface
 
-- `length(hmm)`
-- `initialization(hmm)`
-- `transition_matrix(hmm)`
-- `obs_distributions(hmm)`
-- `fit!(hmm, fbs, obs_seqs, fb_concat, obs_seqs_concat)`
+To create your own subtype of `AbstractHiddenMarkovModel`, you need to implement the following methods:
 
-# Applicable methods
+- [`length(hmm)`](@ref)
+- [`eltype(hmm, obs)`](@ref)
+- [`initialization(hmm)`](@ref)
+- [`transition_matrix(hmm)`](@ref)
+- [`obs_distributions(hmm)`](@ref)
+- [`fit!(hmm, init_count, trans_count, obs_seq, state_marginals)`](@ref) (optional)
 
-- `rand([rng,] hmm, T)`
-- `logdensityof(hmm, obs_seq)`
-- `forward(hmm, obs_seq)`
-- `viterbi(hmm, obs_seq)`
-- `forward_backward(hmm, obs_seq)`
-- `baum_welch(hmm, obs_seq)` if `fit!` is implemented
+# Applicable functions
+
+Any HMM object which satisfies the interface can be given as input to the following functions:
+
+- [`rand(rng, hmm, T)`](@ref)
+- [`logdensityof(hmm, obs_seq)`](@ref)
+- [`forward(hmm, obs_seq)`](@ref)
+- [`viterbi(hmm, obs_seq)`](@ref)
+- [`forward_backward(hmm, obs_seq)`](@ref)
+- [`baum_welch(hmm, obs_seq)`](@ref) (if the optional `fit!` is implemented)
 """
 abstract type AbstractHiddenMarkovModel end
@@ -79,16 +84,31 @@ Each element `dist` of this vector must implement
 function obs_distributions end
 
 """
-    fit!(hmm, obs_seqs, fbs)
+    fit!(hmm, init_count, trans_count, obs_seq, state_marginals)
+
+Update `hmm` in-place based on information generated during forward-backward.
+
+This method is only necessary for the Baum-Welch algorithm.
+
+# Arguments
+
+- `init_count::Vector`: posterior initialization counts for each state (size `N`)
+- `trans_count::AbstractMatrix`: posterior transition counts for each pair of states (size `(N, N)`)
+- `obs_seq::Vector`: sequence of observations, possibly concatenated (size `T`)
+- `state_marginals::Matrix`: posterior probabilities of being in each state at each time, to be used as weights during maximum likelihood fitting of the observation distributions (size `(N, T)`).
+
+# See also
 
-Update `hmm` in-place based on several observation sequences `obs_seqs` as well as information `fbs` generated during forward-backward.
+- [`BaumWelchStorage`](@ref)
+- [`ForwardBackwardStorage`](@ref)
 """
-StatsAPI.fit!
+StatsAPI.fit! # TODO: complete
 
 ## Sampling
 
 """
-    rand([rng,] hmm, T)
+    rand(hmm, T)
+    rand(rng, hmm, T)
 
 Simulate `hmm` for `T` time steps.
""" diff --git a/src/types/hmm.jl b/src/types/hmm.jl index a2c05f6d..2a27f0d7 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -7,19 +7,17 @@ Basic implementation of an HMM. $(TYPEDFIELDS) """ -struct HiddenMarkovModel{D,U<:AbstractVector,M<:AbstractMatrix,V<:AbstractVector{D}} <: +struct HiddenMarkovModel{I<:AbstractVector,T<:AbstractMatrix,D<:AbstractVector} <: AbstractHMM "initial state probabilities" - init::U + init::I "state transition matrix" - trans::M + trans::T "observation distributions" - dists::V + dists::D - function HiddenMarkovModel( - init::U, trans::M, dists::V - ) where {D,U<:AbstractVector,M<:AbstractMatrix,V<:AbstractVector{D}} - hmm = new{D,U,M,V}(init, trans, dists) + function HiddenMarkovModel(init::I, trans::T, dists::D) where {I,T,D} + hmm = new{I,T,D}(init, trans, dists) check_hmm(hmm) return hmm end @@ -42,21 +40,21 @@ transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists function StatsAPI.fit!( - hmm::HMM, fbs, obs_seqs, fb_concat::ForwardBackwardStorage, obs_seqs_concat::Vector + hmm::HMM, + init_count::Vector, + trans_count::AbstractMatrix, + obs_seq::Vector, + state_marginals::Matrix, ) # Initialization - hmm.init .= zero(eltype(hmm.init)) - for k in eachindex(fbs) - hmm.init .+= view(fbs[k].γ, :, 1) - end + hmm.init .= init_count sum_to_one!(hmm.init) # Transition matrix - hmm.trans .= zero(eltype(hmm.trans)) - sum!(hmm.trans, fb_concat.ξ; init=false) + hmm.trans .= trans_count foreach(sum_to_one!, eachrow(hmm.trans)) # Observation distributions for i in eachindex(hmm.dists) - fit_element_from_sequence!(hmm.dists, i, obs_seqs_concat, view(fb_concat.γ, i, :)) + fit_element_from_sequence!(hmm.dists, i, obs_seq, view(state_marginals, i, :)) end check_hmm(hmm) return nothing diff --git a/src/utils/check.jl b/src/utils/check.jl index 2e0a0377..a0c34838 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -65,3 +65,9 @@ function check_hmm(hmm::AbstractHMM) check_dists(d) return nothing end + +function check_lengths(obs_seqs::Vector{<:Vector}, nb_seqs::Integer) + if nb_seqs != length(obs_seqs) + throw(ArgumentError("nb_seqs != length(obs_seqs)")) + end +end diff --git a/src/utils/fit.jl b/src/utils/fit.jl index 6e749b99..7321af9b 100644 --- a/src/utils/fit.jl +++ b/src/utils/fit.jl @@ -14,23 +14,23 @@ end function fit_element_from_sequence!( dists::AbstractVector{D}, i::Integer, x::AbstractVector, w::AbstractVector ) where {D<:Distribution} - dists[i] = _fit_from_sequence(D, x, w) + dists[i] = fit_from_sequence(D, x, w) return nothing end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_nums::AbstractVector, w::AbstractVector ) where {D<:UnivariateDistribution} return fit(D, x_nums, w) end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_vecs::AbstractVector{<:AbstractVector}, w::AbstractVector ) where {D<:MultivariateDistribution} return fit(D, reduce(hcat, x_vecs), w) end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_mats::AbstractVector{<:AbstractMatrix}, w::AbstractVector ) where {D<:MatrixDistribution} return fit(D, reduce(dcat, x_mats), w) diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 0974b69a..a2b36a31 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -26,6 +26,10 @@ function LightDiagNormal(μ, σ) return LightDiagNormal(μ, σ, log.(σ)) end +function Base.show(io::IO, dist::LightDiagNormal) + return print(io, "LightDiagNormal($(dist.μ), $(dist.σ))") +end + @inline 
DensityInterface.DensityKind(::LightDiagNormal) = HasDensity() Base.length(dist::LightDiagNormal) = length(dist.μ) diff --git a/src/utils/mul.jl b/src/utils/mul.jl new file mode 100644 index 00000000..5b66fc9e --- /dev/null +++ b/src/utils/mul.jl @@ -0,0 +1,20 @@ +function mul_rows_cols!( + B::AbstractMatrix, l::AbstractVector, A::AbstractMatrix, r::AbstractVector +) + B .= l .* A .* r' + return nothing +end + +function mul_rows_cols!( + B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector +) + @assert size(B) == size(A) == (length(l), length(r)) + B .= A + for j in axes(B, 2) + for k in nzrange(B, j) + i = B.rowval[k] + B.nzval[k] *= l[i] * r[j] + end + end + return nothing +end diff --git a/src/utils/probvec.jl b/src/utils/probvec.jl index 7456558f..7925b56c 100644 --- a/src/utils/probvec.jl +++ b/src/utils/probvec.jl @@ -5,7 +5,8 @@ end sum_to_one!(x) = x ./= sum(x) """ - rand_prob_vec([rng,] N) + rand_prob_vec(N) + rand_prob_vec(rng, N) Generate a random probability distribution of size `N`. """ diff --git a/src/utils/transmat.jl b/src/utils/transmat.jl index 1e40831f..1f30e5f8 100644 --- a/src/utils/transmat.jl +++ b/src/utils/transmat.jl @@ -16,7 +16,8 @@ function is_trans_mat(A::AbstractMatrix; atol=1e-2) end """ - rand_trans_mat([rng,] N) + rand_trans_mat(N) + rand_trans_mat(rng, N) Generate a random transition matrix of size `(N, N)`. """ diff --git a/test/allocations.jl b/test/allocations.jl index d6fa829c..667e4306 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -1,54 +1,53 @@ using BenchmarkTools -using Distributions using HiddenMarkovModels +using HiddenMarkovModels.HMMTest import HiddenMarkovModels as HMMs -using HiddenMarkovModels: LightDiagNormal -using SimpleUnPack using Test function test_allocations(hmm; T) obs_seq = rand(hmm, T).obs_seq + nb_seqs = 2 + obs_seqs = [rand(hmm, T).obs_seq for _ in 1:nb_seqs] ## Forward f = HMMs.initialize_forward(hmm, obs_seq) - allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) samples = 2 + allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) @test allocs == 0 ## Viterbi v = HMMs.initialize_viterbi(hmm, obs_seq) - allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) samples = 2 + allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) @test allocs == 0 ## Forward-backward fb = HMMs.initialize_forward_backward(hmm, obs_seq) - allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) samples = 2 + allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) @test allocs == 0 ## Baum-Welch - fb = HMMs.initialize_forward_backward(hmm, obs_seq) - R = eltype(hmm, obs_seq[1]) - logL_evolution = R[] - sizehint!(logL_evolution, 2) - allocs = @ballocated HMMs.baum_welch!( - $fb, - $logL_evolution, - $hmm, - $obs_seq; - atol=-Inf, - max_iterations=2, - loglikelihood_increasing=false, - ) samples = 2 + fbs = [HMMs.initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + bw = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs) + obs_seqs_concat = reduce(vcat, obs_seqs) + HMMs.forward_backward!(fbs, hmm, obs_seqs, nb_seqs) + HMMs.update_sufficient_statistics!(bw, fbs) + fit!(hmm, bw, obs_seqs_concat) + allocs = @ballocated HMMs.update_sufficient_statistics!($bw, $fbs) + @test allocs == 0 + allocs = @ballocated fit!($hmm, $bw, $obs_seqs_concat) @test allocs == 0 end -N = 5 -D = 3 -T = 100 +N, D, T = 3, 2, 100 -p = rand_prob_vec(N) -A = rand_trans_mat(N) -dists = [LightDiagNormal(randn(2), ones(2)) for i in 1:N] +@testset "Normal" begin + 
test_allocations(rand_gaussian_hmm_1d(N); T) +end -hmm = HMM(p, A, dists) +@testset "Normal sparse" begin + # see https://discourse.julialang.org/t/why-does-mul-u-a-v-allocate-when-a-is-sparse-and-u-v-are-views/105995 + @test_skip test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) +end -test_allocations(hmm; T) +@testset "LightDiagNormal" begin + test_allocations(rand_gaussian_hmm_2d_light(N, D); T) +end diff --git a/test/arrays.jl b/test/arrays.jl new file mode 100644 index 00000000..6f63794a --- /dev/null +++ b/test/arrays.jl @@ -0,0 +1,51 @@ +using Distributions +using HiddenMarkovModels +using HiddenMarkovModels: sum_to_one! +using LinearAlgebra +using SparseArrays +using StaticArrays +using SimpleUnPack +using Test + +N, T = 3, 1000 + +## Sparse + +p = ones(N) / N; +A = SparseMatrixCSC(SymTridiagonal(ones(N), ones(N - 1))); +foreach(sum_to_one!, eachrow(A)); +d = [Normal(i + randn(), 1.0) for i in 1:N]; +d_init = [Normal(i + randn(), 1.0) for i in 1:N]; + +hmm = HMM(p, A, d); +hmm_init = HMM(p, A, d_init); + +obs_seq = rand(hmm, T).obs_seq; + +γ, ξ, logL = forward_backward(hmm, obs_seq); +hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); + +@testset "Sparse" begin + @test eltype(ξ) <: AbstractSparseArray + @test typeof(hmm_est) == typeof(hmm_init) + @test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm)) +end + +## Static + +p = MVector{N}(ones(N) / N); +A = MMatrix{N,N}(rand_trans_mat(N)); +d = MVector{N}([Normal(randn(), 1.0) for i in 1:N]); +d_init = MVector{N}([Normal(randn(), 1.0) for i in 1:N]); + +hmm = HMM(p, A, d); +hmm_init = HMM(p, A, d_init); +obs_seq = rand(hmm, T).obs_seq; + +γ, ξ, logL = forward_backward(hmm, obs_seq); +hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); + +@testset "Static" begin + @test eltype(ξ) <: StaticArray + @test typeof(hmm_est) == typeof(hmm_init) +end diff --git a/test/autodiff.jl b/test/autodiff.jl index cccaf133..eca642ca 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -12,13 +12,13 @@ T = 100 p = rand_prob_vec(N) A = rand_trans_mat(N) μ = rand(N) -dists = [Normal(μ[i], 1.0) for i in 1:N] -hmm = HMM(p, A, dists) +d = [Normal(μ[i], 1.0) for i in 1:N] +hmm = HMM(p, A, d) obs_seq = rand(hmm, T).obs_seq; function f_init(_p) - hmm = HMM(_p, A, dists) + hmm = HMM(_p, A, d) return logdensityof(hmm, obs_seq) end @@ -27,7 +27,7 @@ g2 = Zygote.gradient(f_init, p)[1] @test isapprox(g1, g2) function f_trans(_A) - hmm = HMM(p, _A, dists) + hmm = HMM(p, _A, d) return logdensityof(hmm, obs_seq) end @@ -35,13 +35,13 @@ g1 = ForwardDiff.gradient(f_trans, A) g2 = Zygote.gradient(f_trans, A)[1] @test isapprox(g1, g2) -function f_dists(_μ) +function f_d(_μ) hmm = HMM(p, A, [Normal(_μ[i], 1.0) for i in 1:N]) return logdensityof(hmm, obs_seq) end -g0 = FiniteDifferences.grad(central_fdm(5, 1), f_dists, μ)[1] -g1 = ForwardDiff.gradient(f_dists, μ) -g2 = Zygote.gradient(f_dists, μ)[1] +g0 = FiniteDifferences.grad(central_fdm(5, 1), f_d, μ)[1] +g1 = ForwardDiff.gradient(f_d, μ) +g2 = Zygote.gradient(f_d, μ)[1] @test isapprox(g0, g1) @test isapprox(g0, g2) diff --git a/test/correctness.jl b/test/correctness.jl index aad72e6a..10aeb1e6 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -1,95 +1,88 @@ using Distributions -using Distributions: PDiagMat using HMMBase: HMMBase using HiddenMarkovModels +using HiddenMarkovModels.HMMTest using SimpleUnPack using Test function test_correctness(hmm, hmm_init; T) - obs_seq = rand(hmm, T).obs_seq - obs_mat = collect(reduce(hcat, obs_seq)') + obs_seq1 = 
rand(hmm, T).obs_seq + obs_seq2 = rand(hmm, T).obs_seq + obs_mat1 = collect(reduce(hcat, obs_seq1)') + obs_mat2 = collect(reduce(hcat, obs_seq2)') nb_seqs = 2 - obs_seqs = [obs_seq for _ in 1:nb_seqs] + obs_seqs = [obs_seq1, obs_seq2] hmm_base = HMMBase.HMM(deepcopy(hmm)) hmm_init_base = HMMBase.HMM(deepcopy(hmm_init)) @testset "Logdensity" begin - _, logL_base = HMMBase.forward(hmm_base, obs_mat) - logL = @inferred logdensityof(hmm, obs_seqs, nb_seqs) - @test logL ≈ logL_base * nb_seqs + logL1_base = HMMBase.forward(hmm_base, obs_mat1)[2] + logL2_base = HMMBase.forward(hmm_base, obs_mat2)[2] + logL = logdensityof(hmm, obs_seqs, nb_seqs) + @test logL ≈ logL1_base + logL2_base end @testset "Forward" begin - α_base, logL_base = HMMBase.forward(hmm_base, obs_mat) - α, logL = @inferred first(forward(hmm, obs_seqs, nb_seqs)) - @test isapprox(α, α_base[end, :]) - @test logL ≈ logL_base + (α1_base, logL1_base), (α2_base, logL2_base) = [ + HMMBase.forward(hmm_base, obs_mat1), HMMBase.forward(hmm_base, obs_mat2) + ] + (α1, logL1), (α2, logL2) = forward(hmm, obs_seqs, nb_seqs) + @test isapprox(α1, α1_base[end, :]) + @test isapprox(α2, α2_base[end, :]) + @test logL1 ≈ logL1_base + @test logL2 ≈ logL2_base end @testset "Viterbi" begin - q_base = HMMBase.viterbi(hmm_base, obs_mat) - q = @inferred first(viterbi(hmm, obs_seqs, nb_seqs)) - @test isequal(q, q_base) + q1_base = HMMBase.viterbi(hmm_base, obs_mat1) + q2_base = HMMBase.viterbi(hmm_base, obs_mat2) + q1, q2 = viterbi(hmm, obs_seqs, nb_seqs) + @test isequal(q1, q1_base) + @test isequal(q2, q2_base) end @testset "Forward-backward" begin - γ_base = HMMBase.posteriors(hmm_base, obs_mat) - γ, ξ = @inferred first(forward_backward(hmm, obs_seqs, nb_seqs)) - @test isapprox(γ, γ_base') + γ1_base = HMMBase.posteriors(hmm_base, obs_mat1) + γ2_base = HMMBase.posteriors(hmm_base, obs_mat2) + (γ1, _), (γ2, _) = forward_backward(hmm, obs_seqs, nb_seqs) + @test isapprox(γ1, γ1_base') + @test isapprox(γ2, γ2_base') end @testset "Baum-Welch" begin hmm_est_base, hist_base = HMMBase.fit_mle( - hmm_init_base, obs_mat; maxiter=10, tol=-Inf + hmm_init_base, obs_mat1; maxiter=10, tol=-Inf ) logL_evolution_base = hist_base.logtots - hmm_est, logL_evolution = @inferred baum_welch( - hmm_init, obs_seqs, nb_seqs; max_iterations=10, atol=-Inf + hmm_est, logL_evolution = baum_welch( + hmm_init, [obs_seq1, obs_seq1], 2; max_iterations=10, atol=-Inf ) @test isapprox( - logL_evolution[(begin + 1):end], logL_evolution_base[begin:(end - 1)] .* nb_seqs + logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) @test isapprox(initialization(hmm_est), hmm_est_base.a) @test isapprox(transition_matrix(hmm_est), hmm_est_base.A) for (dist, dist_base) in zip(hmm.dists, hmm_base.B) - @test isapprox(dist.μ, dist_base.μ) + for n in fieldnames(typeof(dist)) + @test isapprox(getfield(dist, n), getfield(dist_base, n)) + end end end end -N = 5 -D = 3 -T = 100 +N, D, T = 3, 2, 100 -p = rand_prob_vec(N) -p_init = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_init = rand_trans_mat(N) - -# Normal - -dists_norm = [Normal(randn(), 1.0) for i in 1:N] -dists_norm_init = [Normal(randn(), 1) for i in 1:N] - -hmm_norm = HMM(p, A, dists_norm) -hmm_norm_init = HMM(p_init, A_init, dists_norm_init) +@testset "Categorical" begin + test_correctness(rand_categorical_hmm(N, 2D), rand_categorical_hmm(N, 2D); T) +end @testset "Normal" begin - test_correctness(hmm_norm, hmm_norm_init; T) + test_correctness(rand_gaussian_hmm_1d(N), rand_gaussian_hmm_1d(N); T) end -# DiagNormal - -dists_diagnorm = 
[DiagNormal(randn(D), PDiagMat(ones(D))) for i in 1:N] -dists_diagnorm_init = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] - -hmm_diagnorm = HMM(p, A, dists_diagnorm) -hmm_diagnorm_init = HMM(p, A, dists_diagnorm_init) - @testset "DiagNormal" begin - test_correctness(hmm_diagnorm, hmm_diagnorm_init; T) + test_correctness(rand_gaussian_hmm_2d(N, D), rand_gaussian_hmm_2d(N, D); T) end diff --git a/test/dna.jl b/test/dna.jl index 9fa5cff5..6279d122 100644 --- a/test/dna.jl +++ b/test/dna.jl @@ -91,20 +91,13 @@ function HiddenMarkovModels.obs_distributions(hmm::DNACodingHMM) return [Dirac(get_nucleotide(s)) for s in 1:length(hmm)] end -function StatsAPI.fit!(dchmm::DNACodingHMM, bw::HiddenMarkovModels.BaumWelchStorage) - @unpack fbs, obs_seqs_concat, state_marginals_concat = bw - - init_count = zeros(eltype(initialization(dchmm)), 8) - for k in eachindex(fbs) - @views init_count .+= fbs[k].γ[:, 1] - end - sum_to_one!(init_count) - trans_count = zeros(eltype(transition_matrix(dchmm)), 8, 8) - for k in eachindex(fbs) - sum!(trans_count, fbs[k].ξ; init=false) - end - foreach(sum_to_one!, eachrow(hmm.trans)) - +function StatsAPI.fit!( + dchmm::DNACodingHMM, + init_count::Vector, + trans_count::Matrix, + obs_seq::Vector, + state_marginals::Matrix, +) # Initializations for c in 1:2 dchmm.cod_init[c] = sum(init_count[get_state(c, n)] for n in 1:4) diff --git a/test/logarithmic.jl b/test/logarithmic.jl deleted file mode 100644 index 6fa297b5..00000000 --- a/test/logarithmic.jl +++ /dev/null @@ -1,32 +0,0 @@ -using Distributions -using HiddenMarkovModels -using HiddenMarkovModels: LightDiagNormal -using LinearAlgebra -using LogarithmicNumbers -using SimpleUnPack -using Test - -N = 3 -D = 2 -T = 1000 - -p = ones(N) / N -A = rand_trans_mat(N) -dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; -dists_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; -dists_init_log = [LightDiagNormal(randn(D), LogFloat64.(ones(D))) for i in 1:N]; - -hmm = HMM(p, A, dists); -obs_seq = rand(hmm, T).obs_seq; - -hmm_init = HMM(LogFloat64.(p), A, dists_init); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) - -hmm_init = HMM(p, LogFloat64.(A), dists_init); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) - -hmm_init = HMM(p, A, dists_init_log); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/misc.jl b/test/misc.jl new file mode 100644 index 00000000..e035dc46 --- /dev/null +++ b/test/misc.jl @@ -0,0 +1,29 @@ +using HiddenMarkovModels +using HiddenMarkovModels.HMMTest +using HiddenMarkovModels: PermutedHMM +using Distributions +using Test + +## Permuted + +perm = [3, 1, 2] +hmm = rand_gaussian_hmm_1d(3) +hmm_perm = PermutedHMM(hmm, perm) + +p = initialization(hmm) +A = transition_matrix(hmm) +d = obs_distributions(hmm) + +p_perm = initialization(hmm_perm) +A_perm = transition_matrix(hmm_perm) +d_perm = obs_distributions(hmm_perm) + +@testset "PermutedHMM" begin + for i in 1:3 + @test p_perm[i] ≈ p[perm[i]] + @test d_perm[i] ≈ d[perm[i]] + end + for i in 1:3, j in 1:3 + @test A_perm[i, j] ≈ A[perm[i], perm[j]] + end +end diff --git a/test/numbers.jl b/test/numbers.jl new file mode 100644 index 00000000..35e906cc --- /dev/null +++ b/test/numbers.jl @@ -0,0 +1,34 @@ +using HiddenMarkovModels +using HiddenMarkovModels: LightDiagNormal +using LinearAlgebra +using LogarithmicNumbers +using 
SimpleUnPack +using Test + +N, D, T = 3, 2, 1000 + +## LogarithmicNumbers + +p = ones(N) / N; +A = rand_trans_mat(N); +d = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; +d_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; +d_init_log = [LightDiagNormal(randn(D), LogFloat64.(ones(D))) for i in 1:N]; + +hmm = HMM(p, A, d); +obs_seq = rand(hmm, T).obs_seq; + +hmm_init1 = HMM(LogFloat64.(p), A, d_init); +hmm_est1, logL_evolution1 = @inferred baum_welch(hmm_init1, obs_seq); + +hmm_init2 = HMM(p, LogFloat64.(A), d_init); +hmm_est2, logL_evolution2 = @inferred baum_welch(hmm_init2, obs_seq); + +hmm_init3 = HMM(p, A, d_init_log); +hmm_est3, logL_evolution3 = @inferred baum_welch(hmm_init3, obs_seq); + +@testset "Logarithmic" begin + @test typeof(hmm_est1) == typeof(hmm_init1) + @test typeof(hmm_est2) == typeof(hmm_init2) + @test typeof(hmm_est3) == typeof(hmm_init3) +end diff --git a/test/permuted.jl b/test/permuted.jl deleted file mode 100644 index 5ea38685..00000000 --- a/test/permuted.jl +++ /dev/null @@ -1,25 +0,0 @@ -using HiddenMarkovModels -using HiddenMarkovModels: PermutedHMM -using Distributions -using Test - -p = rand_prob_vec(3) -A = rand_trans_mat(3) -dists = [Normal(i) for i in 1:3] - -hmm = HMM(p, A, dists) -perm = [3, 1, 2] - -hmm_perm = PermutedHMM(hmm, perm) -p_perm = initialization(hmm_perm) -A_perm = transition_matrix(hmm_perm) -dists_perm = obs_distributions(hmm_perm) - -for i in 1:3 - @test p_perm[i] ≈ p[perm[i]] - @test dists_perm[i] ≈ dists[perm[i]] -end - -for i in 1:3, j in 1:3 - @test A_perm[i, j] ≈ A[perm[i], perm[j]] -end diff --git a/test/runtests.jl b/test/runtests.jl index 70a79d8a..a49bbe92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,20 +28,20 @@ using Test end end - @testset "Correctness" begin - include("correctness.jl") + @testset "Doctests" begin + Documenter.doctest(HiddenMarkovModels) end - @testset "Sparse" begin - include("sparse.jl") + @testset "Correctness" begin + include("correctness.jl") end - @testset "Static" begin - include("static.jl") + @testset "Array types" begin + include("arrays.jl") end - @testset "Logarithmic" begin - include("logarithmic.jl") + @testset "Number types" begin + include("numbers.jl") end @testset "Autodiff" begin @@ -52,11 +52,7 @@ using Test include("dna.jl") end - @testset "Permuted" begin - include("permuted.jl") - end - - @testset "Doctests" begin - Documenter.doctest(HiddenMarkovModels) + @testset "Misc" begin + include("misc.jl") end end diff --git a/test/sparse.jl b/test/sparse.jl deleted file mode 100644 index dec543fa..00000000 --- a/test/sparse.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Distributions -using HiddenMarkovModels -using HiddenMarkovModels: sum_to_one! 
-using LinearAlgebra -using SparseArrays -using SimpleUnPack -using Test - -N = 3 -T = 2000 - -p = ones(N) / N -A = SparseMatrixCSC(SymTridiagonal(ones(N), ones(N - 1))) -foreach(sum_to_one!, eachrow(A)) -dists = [Normal(i + randn(), 1) for i in 1:N] -dists_init = [Normal(i + randn(), 1) for i in 1:N] - -hmm = HMM(p, A, dists) -hmm_init = HMM(p, A, dists_init) - -obs_seq = rand(hmm, T).obs_seq -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) - -@test typeof(hmm_est) == typeof(hmm_init) -@test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm)) diff --git a/test/static.jl b/test/static.jl deleted file mode 100644 index 23976aa2..00000000 --- a/test/static.jl +++ /dev/null @@ -1,22 +0,0 @@ -using Distributions -using HiddenMarkovModels -using LinearAlgebra -using StaticArrays -using SimpleUnPack -using Test - -N = 5 -T = 1000 - -p = MVector{N}(ones(N) / N) -A = MMatrix{N,N}(rand_trans_mat(N)) -dists = MVector{N}([Normal(randn(), 1.0) for i in 1:N]) -dists_init = MVector{N}([Normal(randn(), 1.0) for i in 1:N]) - -hmm = HMM(p, A, dists) -hmm_init = HMM(p, A, dists_init) - -obs_seq = rand(hmm, T).obs_seq -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) - -@test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/type_stability.jl b/test/type_stability.jl index 4aecc646..642ae0ef 100644 --- a/test/type_stability.jl +++ b/test/type_stability.jl @@ -1,101 +1,68 @@ -using Distributions -using Distributions: PDiagMat using HiddenMarkovModels -using HiddenMarkovModels: LightDiagNormal +using HiddenMarkovModels.HMMTest +import HiddenMarkovModels as HMMs using JET using SimpleUnPack using Test function test_type_stability(hmm, hmm_init; T) obs_seq = rand(hmm, T).obs_seq - nb_seqs = 2 - obs_seqs = [obs_seq for _ in 1:nb_seqs] @testset "Logdensity" begin - @inferred logdensityof(hmm, obs_seqs, nb_seqs) - @test_opt target_modules = (HiddenMarkovModels,) logdensityof( - hmm, obs_seqs, nb_seqs - ) - @test_call target_modules = (HiddenMarkovModels,) logdensityof( - hmm, obs_seqs, nb_seqs - ) + @inferred logdensityof(hmm, obs_seq) + @test_opt target_modules = (HMMs,) logdensityof(hmm, obs_seq) + @test_call target_modules = (HMMs,) logdensityof(hmm, obs_seq) end @testset "Forward" begin - @inferred forward(hmm, obs_seqs, nb_seqs) - @test_opt target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, nb_seqs) - @test_call target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, nb_seqs) + @inferred forward(hmm, obs_seq) + @test_opt target_modules = (HMMs,) forward(hmm, obs_seq) + @test_call target_modules = (HMMs,) forward(hmm, obs_seq) end @testset "Viterbi" begin - @inferred viterbi(hmm, obs_seqs, nb_seqs) - @test_opt target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, nb_seqs) - @test_call target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, nb_seqs) + @inferred viterbi(hmm, obs_seq) + @test_opt target_modules = (HMMs,) viterbi(hmm, obs_seq) + @test_call target_modules = (HMMs,) viterbi(hmm, obs_seq) end @testset "Forward-backward" begin - @inferred forward_backward(hmm, obs_seqs, nb_seqs) - @test_opt target_modules = (HiddenMarkovModels,) forward_backward( - hmm, obs_seqs, nb_seqs - ) - @test_call target_modules = (HiddenMarkovModels,) forward_backward( - hmm, obs_seqs, nb_seqs - ) + @inferred forward_backward(hmm, obs_seq) + @test_opt target_modules = (HMMs,) forward_backward(hmm, obs_seq) + @test_call target_modules = (HMMs,) forward_backward(hmm, obs_seq) end @testset "Baum-Welch" begin - @inferred baum_welch(hmm_init, obs_seqs, 
nb_seqs) - @test_opt target_modules = (HiddenMarkovModels,) baum_welch( - hmm_init, obs_seqs, nb_seqs - ) - @test_call target_modules = (HiddenMarkovModels,) baum_welch( - hmm_init, obs_seqs, nb_seqs - ) + @inferred baum_welch(hmm_init, obs_seq) + @test_opt target_modules = (HMMs,) baum_welch(hmm_init, obs_seq; max_iterations=2) + @test_call target_modules = (HMMs,) baum_welch(hmm_init, obs_seq; max_iterations=2) end end -N = 2 -D = 3 -T = 100 +N, D, T = 3, 2, 100 -p = rand_prob_vec(N) -p_init = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_init = rand_trans_mat(N) - -# Normal - -dists_norm = [Normal(randn(), 1.0) for i in 1:N] -dists_norm_init = [Normal(randn(), 1) for i in 1:N] - -hmm_norm = HMM(p, A, dists_norm) -hmm_norm_init = HMM(p_init, A_init, dists_norm_init) +@testset "Categorical" begin + test_type_stability(rand_categorical_hmm(N, D), rand_categorical_hmm(N, D); T) +end @testset "Normal" begin - test_type_stability(hmm_norm, hmm_norm_init; T) + test_type_stability(rand_gaussian_hmm_1d(N), rand_gaussian_hmm_1d(N); T) end -# DiagNormal - -dists_diagnorm = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] -dists_diagnorm_init = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] - -hmm_diagnorm = HMM(p, A, dists_diagnorm) -hmm_diagnorm_init = HMM(p, A, dists_diagnorm_init) +@testset "Normal sparse" begin + test_type_stability( + rand_gaussian_hmm_1d(N; sparse_trans=true), + rand_gaussian_hmm_1d(N; sparse_trans=true); + T, + ) +end @testset "DiagNormal" begin - test_type_stability(hmm_diagnorm, hmm_diagnorm_init; T) + test_type_stability(rand_gaussian_hmm_2d(N, D), rand_gaussian_hmm_2d(N, D); T) end -## LightDiagNormal - -dists_lightdiagnorm = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] -dists_lightdiagnorm_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] - -hmm_lightdiagnorm = HMM(p, A, dists_lightdiagnorm) -hmm_lightdiagnorm_init = HMM(p, A, dists_lightdiagnorm_init) - @testset "LightDiagNormal" begin - test_type_stability(hmm_lightdiagnorm, hmm_lightdiagnorm_init; T) + test_type_stability( + rand_gaussian_hmm_2d_light(N, D), rand_gaussian_hmm_2d_light(N, D); T + ) end From d12ea1f1b50915ec40be0512f3f9606a35125612 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:28:24 +0100 Subject: [PATCH 10/14] Reactivate precompile workload and fix viterbi --- src/HiddenMarkovModels.jl | 40 ++++++++++++++------------------------- test/correctness.jl | 2 +- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index b39e850e..1fd68ac7 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -67,31 +67,19 @@ if !isdefined(Base, :get_extension) end end -# @compile_workload begin -# N, D, T = 5, 3, 100 -# p = rand_prob_vec(N) -# A = rand_trans_mat(N) -# dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] -# hmm = HMM(p, A, dists) - -# obs_seq = rand(hmm, T).obs_seq -# obs_seqs = [rand(hmm, T).obs_seq for _ in 1:3] -# nb_seqs = 3 - -# logdensityof(hmm, obs_seq) -# logdensityof(hmm, obs_seqs, nb_seqs) - -# forward(hmm, obs_seq) -# forward(hmm, obs_seqs, nb_seqs) - -# viterbi(hmm, obs_seq) -# viterbi(hmm, obs_seqs, nb_seqs) - -# forward_backward(hmm, obs_seq) -# forward_backward(hmm, obs_seqs, nb_seqs) - -# baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf) -# baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf) -# end +@compile_workload begin + N, D, T = 3, 2, 100 + p = rand_prob_vec(N) + A = 
rand_trans_mat(N) + dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] + hmm = HMM(p, A, dists) + obs_seq = rand(hmm, T).obs_seq + + logdensityof(hmm, obs_seq) + forward(hmm, obs_seq) + viterbi(hmm, obs_seq) + forward_backward(hmm, obs_seq) + baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf) +end end diff --git a/test/correctness.jl b/test/correctness.jl index 10aeb1e6..868ef938 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -76,7 +76,7 @@ end N, D, T = 3, 2, 100 @testset "Categorical" begin - test_correctness(rand_categorical_hmm(N, 2D), rand_categorical_hmm(N, 2D); T) + test_correctness(rand_categorical_hmm(N, 10), rand_categorical_hmm(N, 10); T) end @testset "Normal" begin From 0142f6d81e0a6dec822ba564ba0878329b0c6651 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:35:15 +0100 Subject: [PATCH 11/14] Reactivate sparse allocation test --- test/allocations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/allocations.jl b/test/allocations.jl index 667e4306..ece482f6 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -45,7 +45,7 @@ end @testset "Normal sparse" begin # see https://discourse.julialang.org/t/why-does-mul-u-a-v-allocate-when-a-is-sparse-and-u-v-are-views/105995 - @test_skip test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) + test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) end @testset "LightDiagNormal" begin From 0681ec2610b379f81369106feee3b2125d4537f6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:47:42 +0100 Subject: [PATCH 12/14] Fix FB output --- test/arrays.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/arrays.jl b/test/arrays.jl index 6f63794a..316f019a 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -22,11 +22,10 @@ hmm_init = HMM(p, A, d_init); obs_seq = rand(hmm, T).obs_seq; -γ, ξ, logL = forward_backward(hmm, obs_seq); +γ, logL = forward_backward(hmm, obs_seq); hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); @testset "Sparse" begin - @test eltype(ξ) <: AbstractSparseArray @test typeof(hmm_est) == typeof(hmm_init) @test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm)) end @@ -42,10 +41,9 @@ hmm = HMM(p, A, d); hmm_init = HMM(p, A, d_init); obs_seq = rand(hmm, T).obs_seq; -γ, ξ, logL = forward_backward(hmm, obs_seq); +γ, logL = forward_backward(hmm, obs_seq); hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); @testset "Static" begin - @test eltype(ξ) <: StaticArray @test typeof(hmm_est) == typeof(hmm_init) end From da964f51c1101ba8ece7f73e5a6a0fea35a1d408 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 10 Nov 2023 17:46:21 +0100 Subject: [PATCH 13/14] Skip allocations test sparse --- test/allocations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/allocations.jl b/test/allocations.jl index ece482f6..667e4306 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -45,7 +45,7 @@ end @testset "Normal sparse" begin # see https://discourse.julialang.org/t/why-does-mul-u-a-v-allocate-when-a-is-sparse-and-u-v-are-views/105995 - test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) + @test_skip test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) end @testset "LightDiagNormal" begin From c68f8e9dcc325d8c3116fac231855427534d2cec Mon Sep 17 00:00:00 2001 From: 
Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 10 Nov 2023 18:02:22 +0100 Subject: [PATCH 14/14] Fix tests for 1.6 --- test/correctness.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/correctness.jl b/test/correctness.jl index 868ef938..29b1b55a 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -66,8 +66,10 @@ function test_correctness(hmm, hmm_init; T) @test isapprox(transition_matrix(hmm_est), hmm_est_base.A) for (dist, dist_base) in zip(hmm.dists, hmm_base.B) - for n in fieldnames(typeof(dist)) - @test isapprox(getfield(dist, n), getfield(dist_base, n)) + if hasfield(typeof(dist), :μ) + @test isapprox(dist.μ, dist_base.μ) + elseif hasfield(typeof(dist), :p) + @test isapprox(dist.p, dist_base.p) end end end