From 9d50e54e6983dcd4964385fc87d0a72b02f5cac8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Nov 2023 15:59:37 +0100 Subject: [PATCH] Better periodic example --- docs/Project.toml | 2 + examples/periodic.jl | 201 ++++++++++++++++++++++++++----------------- 2 files changed, 122 insertions(+), 81 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index b7ce75df..a3d01c80 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,6 +5,8 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" diff --git a/examples/periodic.jl b/examples/periodic.jl index 76b754a4..b7e78789 100644 --- a/examples/periodic.jl +++ b/examples/periodic.jl @@ -1,89 +1,51 @@ -""" -$(TYPEDEF) +# # Periodic HMM -Basic implementation of a time-heterogeneous HMM with periodic transition matrices and observation distributions. +using Distributions +using HiddenMarkovModels +import HiddenMarkovModels as HMMs +using Plots +using SimpleUnPack +using StatsAPI -The period is the first type parameter `L`. +# ## Structure -# Fields +""" + PeriodicHMM{L} -$(TYPEDFIELDS) +Basic implementation of a periodic HMM with time-dependent transition matrices and observation distributions, repeating every `L` time steps. """ struct PeriodicHMM{L,V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM - "initial state probabilities" init::V - "one state transition matrix per time" trans_periodic::NTuple{L,M} - "one vector of observation distributions per time (must be amenable to `logdensityof` and `rand`)" dists_periodic::NTuple{L,VD} end -function Base.copy(phmm::PeriodicHMM) - return PeriodicHMM( - copy(phmm.init), copy(phmm.trans_periodic), copy(phmm.dists_periodic) - ) -end +period(::PeriodicHMM{L}) where {L} = L Base.length(phmm::PeriodicHMM) = length(phmm.init) -initialization(phmm::PeriodicHMM) = phmm.init +HMMs.initialization(phmm::PeriodicHMM) = phmm.init -function transition_matrix(phmm::PeriodicHMM{L}, t::Integer) where {L} - return phmm.trans_periodic[(t - 1) % L + 1] +function HMMs.transition_matrix(phmm::PeriodicHMM, t::Integer) + return phmm.trans_periodic[(t - 1) % period(hmm) + 1] end -function obs_distributions(phmm::PeriodicHMM{L}, t::Integer) where {L} - return phmm.dists_periodic[(t - 1) % L + 1] +function HMMs.obs_distributions(phmm::PeriodicHMM, t::Integer) + return phmm.dists_periodic[(t - 1) % period(hmm) + 1] end ## Fitting -struct BaumWelchStoragePeriodicHMM{L,O,R} <: AbstractBaumWelchStorage - obs_seqs_concat_periodic::NTuple{L,Vector{O}} - state_marginals_concat_periodic::NTuple{L,Matrix{R}} -end - -function initialize_baum_welch( - ::PeriodicHMM{L}, - fb_storages::Vector{<:ForwardBackwardStorage}, - obs_seqs::Vector{<:Vector}, -) where {L} - obs_seqs_concat_periodic = ntuple( - l -> reduce(vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs)), L - ) - state_marginals_concat_periodic = ntuple( - l -> reduce(hcat, fb_storages[k].γ[:, l:L:end] for k in eachindex(fb_storages)), L - ) - return BaumWelchStoragePeriodicHMM( - obs_seqs_concat_periodic, state_marginals_concat_periodic - ) +struct BaumWelchStoragePeriodicHMM <: HMMs.AbstractBaumWelchStorage end +function HMMs.initialize_baum_welch(::PeriodicHMM, fb_storages, obs_seqs) + return BaumWelchStoragePeriodicHMM() end -function update_baum_welch!( - bw_storage::BaumWelchStoragePeriodicHMM{L}, - fb_storages::Vector{<:ForwardBackwardStorage}, - obs_seqs::Vector{<:Vector}, -) where {L} - @unpack state_marginals_concat_periodic = bw_storage - for l in 1:L - tl = 1 - for k in eachindex(obs_seqs, fb_storages) - @unpack γ = fb_storages[k] - γl = @view γ[:, l:L:end] - Tl = size(γl, 2) - state_marginals_concat_periodic[l][:, tl:(tl + Tl - 1)] .= γl - tl += Tl - end - end - return nothing -end - -function fit_states!( - hmm::PeriodicHMM{L}, fb_storages::Vector{<:ForwardBackwardStorage} -) where {L} +function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwardStorage}) + L = period(hmm) # Reset - hmm.init .= zero(eltype(hmm.init)) + hmm.init .= 0 for l in 1:L - hmm.trans_periodic[l] .= zero(eltype(hmm.trans_periodic[l])) + hmm.trans_periodic[l] .= 0 end # Accumulate sufficient stats for k in eachindex(fb_storages) @@ -91,43 +53,120 @@ function fit_states!( hmm.init .+= view(γ, :, 1) for t in eachindex(ξ) l = (t - 1) % L + 1 - mynonzeros(hmm.trans_periodic[l]) .+= mynonzeros(ξ[t]) + hmm.trans_periodic[l] .+= ξ[t] end end # Normalize - sum_to_one!(hmm.init) + hmm.init ./= sum(hmm.init) for l in 1:L - foreach(sum_to_one!, eachrow(hmm.trans_periodic[l])) + hmm.trans_periodic[l] ./= sum(hmm.trans_periodic[l]; dims=2) end return nothing end function fit_observations!( - hmm::PeriodicHMM{L}, bw_storage::BaumWelchStoragePeriodicHMM -) where {L} - @unpack obs_seqs_concat_periodic, state_marginals_concat_periodic = bw_storage - # Fit observation distributions + hmm::PeriodicHMM, + fb_storages::Vector{<:HMMs.ForwardBackwardStorage}, + obs_seqs::Vector{<:Vector}, +) for l in 1:L + obs_seq_periodic = reduce(vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs)) + state_marginals_periodic = reduce( + hcat, fb_storages[k].γ[:, l:L:end] for k in eachindex(fb_storages) + ) for i in 1:length(hmm) - fit_element_from_sequence!( - hmm.dists_periodic[l], - i, - obs_seqs_concat_periodic[l], - view(state_marginals_concat_periodic[l], i, :), - ) + D = typeof(hmm.dists_periodic[l][i]) + x = obs_seq_periodic + w = view(state_marginals_periodic, i, :) + hmm.dists_periodic[l][i] = fit(D, x, w) end end return nothing end function StatsAPI.fit!( - hmm::PeriodicHMM{L}, - bw_storage::BaumWelchStoragePeriodicHMM, - fb_storages::Vector{<:ForwardBackwardStorage}, + hmm::PeriodicHMM, + ::BaumWelchStoragePeriodicHMM, + fb_storages::Vector{<:HMMs.ForwardBackwardStorage}, obs_seqs::Vector{<:Vector}, -) where {L} - update_baum_welch!(bw_storage, fb_storages, obs_seqs) +) fit_states!(hmm, fb_storages) - fit_observations!(hmm, bw_storage) + fit_observations!(hmm, fb_storages, obs_seqs) return nothing end + +# ## Example + +N = 2 # Number of hidden states +L = 10 # Period of the HMM +T = 50_000 # Number of observation + +function make_trans(l, L) + A = Matrix{Float64}(undef, 2, 2) + A[1, 1] = 0.25 + 0.1 + 0.5cos(2π / L * l + 1)^2 + A[1, 2] = 0.25 - 0.1 + 0.5sin(2π / L * l + 1)^2 + A[2, 2] = 0.25 + 0.2 + 0.5cos(2π / L * (l - L / 3))^2 + A[2, 1] = 0.25 - 0.2 + 0.5sin(2π / L * (l - L / 3))^2 + return A +end + +function make_dists(l, L, N) + dists = [Normal(2i * cos(2π * l / L), i + cos(2π / L * (l - i / 2 + 1))^2) for i in 1:N] + return dists +end + +init = ones(N) / N; +trans_periodic = ntuple(l -> make_trans(l, L), L); +dists_periodic = ntuple(l -> make_dists(l, L, N), L); + +hmm = PeriodicHMM(init, trans_periodic, dists_periodic); + +state_seq, obs_seq = rand(hmm, T); + +hmm_est, logL_evolution = baum_welch(hmm, obs_seq; max_iterations=100); +length(logL_evolution) + +## Plotting + +p = [plot(; xlabel="l", title="transitions from state $i") for i in 1:N] +for i in 1:N, j in 1:N + plot!( + p[i], + 1:L, + [transition_matrix(hmm, l)[i, j] for l in 1:L]; + label="p$((i,j)) - true", + c=j, + ) + plot!( + p[i], + 1:L, + [transition_matrix(hmm_est, l)[i, j] for l in 1:L]; + label="p$((i,j)) - est", + c=j, + s=:dash, + ) +end +plot(p...; size=(1000, 500)) + +p = [plot(; xlabel="l", title="emissions from state $i") for i in 1:N] +for i in 1:N + plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].μ for l in 1:L]; label="μ - true", c=1) + plot!( + p[i], + 1:L, + [obs_distributions(hmm_est, l)[i].μ for l in 1:L]; + label="μ - est", + c=1, + s=:dash, + ) + plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].σ for l in 1:L]; label="σ - true", c=2) + plot!( + p[i], + 1:L, + [obs_distributions(hmm_est, l)[i].σ for l in 1:L]; + label="σ - est", + c=2, + s=:dash, + ) +end +plot(p...; size=(1000, 500))