From 1dfdf90723034697527b31f615347d1fabe66a5e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 31 Oct 2023 20:19:48 +0100 Subject: [PATCH] Give more information to fit! --- Project.toml | 2 +- docs/src/tuto_custom.md | 6 +++-- src/HiddenMarkovModels.jl | 4 ++-- src/inference/baum_welch.jl | 16 ++++++------- src/inference/sufficient_stats.jl | 40 ------------------------------- src/types/abstract_hmm.jl | 36 +++++----------------------- src/types/hmm.jl | 22 +++++++++++++---- src/types/permuted_hmm.jl | 28 ++++++++++++++++++++++ src/utils/lightdiagnormal.jl | 4 ++++ 9 files changed, 70 insertions(+), 88 deletions(-) delete mode 100644 src/inference/sufficient_stats.jl create mode 100644 src/types/permuted_hmm.jl diff --git a/Project.toml b/Project.toml index 29a61cce..05ce7b5a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.3.1" +version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/tuto_custom.md b/docs/src/tuto_custom.md index 7a6edc95..d73a9b6f 100644 --- a/docs/src/tuto_custom.md +++ b/docs/src/tuto_custom.md @@ -70,11 +70,13 @@ As for fitting, we simply ignore the initialization count and copy the rest of t ```@example tuto function StatsAPI.fit!( - hmm::EquilibriumHMM{R,D}, init_count, trans_count, obs_seq, state_marginals + hmm::EquilibriumHMM{R,D}, obs_seqs, fbs ) where {R,D} hmm.trans .= trans_count ./ sum(trans_count, dims=2) + obs_seqs_concat = reduce(vcat, obs_seqs) + state_marginals_concat = reduce(hcat, fb.γ for fb in fbs) for i in 1:N - hmm.dists[i] = fit(D, obs_seq, state_marginals[i, :]) + hmm.dists[i] = fit(D, obs_seqs_concat, state_marginals_concat[i, :]) end end ``` diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index c40acf3e..bc1b75ef 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -41,6 +41,7 @@ export LightDiagNormal include("types/abstract_mc.jl") include("types/mc.jl") include("types/abstract_hmm.jl") +include("types/permuted_hmm.jl") include("types/hmm.jl") include("utils/check.jl") @@ -53,7 +54,6 @@ include("inference/loglikelihoods.jl") include("inference/forward.jl") include("inference/viterbi.jl") include("inference/forward_backward.jl") -include("inference/sufficient_stats.jl") include("inference/baum_welch.jl") if !isdefined(Base, :get_extension) @@ -74,8 +74,8 @@ end dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] hmm = HMM(p, A, dists) - obs_seqs = [last(rand(hmm, T)) for _ in 1:3] nb_seqs = 3 + obs_seqs = [last(rand(hmm, T)) for _ in 1:nb_seqs] logdensityof(hmm, obs_seqs, nb_seqs) forward(hmm, obs_seqs, nb_seqs) viterbi(hmm, obs_seqs, nb_seqs) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 87bc8ac3..83db2c8b 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -3,18 +3,18 @@ function baum_welch!( ) # Pre-allocate nearly all necessary memory logB = loglikelihoods(hmm, obs_seqs[1]) - fb = initialize_forward_backward(hmm, logB) + fb = forward_backward(hmm, logB) logBs = Vector{typeof(logB)}(undef, length(obs_seqs)) fbs = Vector{typeof(fb)}(undef, length(obs_seqs)) + logBs[1], fbs[1] = logB, fb @threads for k in eachindex(obs_seqs) - logBs[k] = loglikelihoods(hmm, obs_seqs[k]) - fbs[k] = forward_backward(hmm, logBs[k]) + if k > 1 + logBs[k] = loglikelihoods(hmm, obs_seqs[k]) + fbs[k] = forward_backward(hmm, logBs[k]) + end end - init_count, trans_count = initialize_states_stats(fbs) - state_marginals_concat = initialize_observations_stats(fbs) - obs_seqs_concat = reduce(vcat, obs_seqs) logL = loglikelihood(fbs) logL_evolution = [logL] @@ -30,9 +30,7 @@ function baum_welch!( end # M step - update_states_stats!(init_count, trans_count, fbs) - update_observations_stats!(state_marginals_concat, fbs) - fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat) + fit!(hmm, obs_seqs, fbs) # Stopping criterion if iteration > 1 diff --git a/src/inference/sufficient_stats.jl b/src/inference/sufficient_stats.jl deleted file mode 100644 index 9bf55c58..00000000 --- a/src/inference/sufficient_stats.jl +++ /dev/null @@ -1,40 +0,0 @@ - -function initialize_states_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - N = length(first(fbs)) - init_count = Vector{R}(undef, N) - trans_count = Matrix{R}(undef, N, N) - return init_count, trans_count -end - -function initialize_observations_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R} - N = length(first(fbs)) - T_total = sum(duration, fbs) - state_marginals_concat = Matrix{R}(undef, N, T_total) - return state_marginals_concat -end - -function update_states_stats!( - init_count, trans_count, fbs::Vector{ForwardBackwardStorage{R}} -) where {R} - init_count .= zero(R) - for k in eachindex(fbs) - @views init_count .+= fbs[k].γ[:, 1] - end - trans_count .= zero(R) - for k in eachindex(fbs) - sum!(trans_count, fbs[k].ξ; init=false) - end - return nothing -end - -function update_observations_stats!( - state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}} -) where {R} - T = 1 - for k in eachindex(fbs) - Tk = duration(fbs[k]) - @views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ - T += Tk - end - return nothing -end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 18c2f558..2dfec949 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -8,7 +8,7 @@ Abstract supertype for an HMM amenable to simulation, inference and learning. - `initial_distribution(hmm)` - `transition_matrix(hmm)` - `obs_distribution(hmm, i)` -- `fit!(hmm, init_count, trans_count, obs_seq, state_marginals)` (optional) +- `fit!(hmm, obs_seqs, fbs)` (optional) # Applicable methods @@ -48,6 +48,11 @@ The returned object `dist` must implement """ function obs_distribution end +""" + StatsAPI.fit!(hmm::AbstractHMM, obs_seqs, fbs) +""" +StatsAPI.fit! # TODO: docstring + function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) mc = MarkovChain(hmm) state_seq = rand(rng, mc, T) @@ -63,32 +68,3 @@ end function MarkovChain(hmm::AbstractHMM) return MarkovChain(initial_distribution(hmm), transition_matrix(hmm)) end - -""" - PermutedHMM{H<:AbstractHMM} - -Wrapper around an `AbstractHMM` that permutes its states. - -This is computationally inefficient and mostly useful for evaluation. - -# Fields - -- `hmm:H`: the old HMM -- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old. -""" -struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM - hmm::H - perm::Vector{Int} -end - -Base.length(p::PermutedHMM) = length(p.hmm) - -HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm] - -function HMMs.transition_matrix(p::PermutedHMM) - return transition_matrix(p.hmm)[p.perm, :][:, p.perm] -end - -function HMMs.obs_distribution(p::PermutedHMM, i::Integer) - return obs_distribution(p.hmm, p.perm[i]) -end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 17845bb9..c8cf11fe 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -45,13 +45,27 @@ obs_distribution(hmm::HMM, i::Integer) = hmm.dists[i] Update `hmm` in-place based on information generated during forward-backward. """ -function StatsAPI.fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals) - hmm.init .= init_count +function StatsAPI.fit!(hmm::HMM, obs_seqs, fbs) + # Initial distribution + hmm.init .= zero(eltype(hmm.init)) + for k in eachindex(fbs) + @views hmm.init .+= fbs[k].γ[:, 1] + end sum_to_one!(hmm.init) - hmm.trans .= trans_count + # Transition matrix + hmm.trans .= zero(eltype(hmm.trans)) + for k in eachindex(fbs) + sum!(hmm.trans, fbs[k].ξ; init=false) + end foreach(sum_to_one!, eachrow(hmm.trans)) + # Observation distributions + obs_seqs_concat = reduce(vcat, obs_seqs) # TODO: allocation-free + state_marginals_concat = reduce(hcat, fb.γ for fb in fbs) # TODO: allocation-free + @show size(obs_seqs_concat) size(state_marginals_concat) @views for i in eachindex(hmm.dists) - fit_element_from_sequence!(hmm.dists, i, obs_seq, state_marginals[i, :]) + fit_element_from_sequence!( + hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :] + ) end check_hmm(hmm) return nothing diff --git a/src/types/permuted_hmm.jl b/src/types/permuted_hmm.jl new file mode 100644 index 00000000..580e5d42 --- /dev/null +++ b/src/types/permuted_hmm.jl @@ -0,0 +1,28 @@ +""" + PermutedHMM{H<:AbstractHMM} + +Wrapper around an `AbstractHMM` that permutes its states. + +This is computationally inefficient and mostly useful for evaluation. + +# Fields + +- `hmm:H`: the old HMM +- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old. +""" +struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM + hmm::H + perm::Vector{Int} +end + +Base.length(p::PermutedHMM) = length(p.hmm) + +HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm] + +function HMMs.transition_matrix(p::PermutedHMM) + return transition_matrix(p.hmm)[p.perm, :][:, p.perm] +end + +function HMMs.obs_distribution(p::PermutedHMM, i::Integer) + return obs_distribution(p.hmm, p.perm[i]) +end diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 850f0089..fe31d105 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -13,6 +13,10 @@ struct LightDiagNormal{ logσ::V3 end +function Base.show(io::IO, dist::LightDiagNormal) + return print(io, "LightDiagNormal(μ=$(dist.μ), σ=$(dist.σ))") +end + function LightDiagNormal(μ, σ) check_no_nan(μ) check_positive(σ)