Skip to content

Commit

Permalink
Better periodic example
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Nov 30, 2023
1 parent fe634f0 commit 9d50e54
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 81 deletions.
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

Expand Down
201 changes: 120 additions & 81 deletions examples/periodic.jl
Original file line number Diff line number Diff line change
@@ -1,133 +1,172 @@
"""
$(TYPEDEF)
# # Periodic HMM

Basic implementation of a time-heterogeneous HMM with periodic transition matrices and observation distributions.
using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using Plots
using SimpleUnPack
using StatsAPI

The period is the first type parameter `L`.
# ## Structure

# Fields
"""
PeriodicHMM{L}
$(TYPEDFIELDS)
Basic implementation of a periodic HMM with time-dependent transition matrices and observation distributions, repeating every `L` time steps.
"""
struct PeriodicHMM{L,V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM
"initial state probabilities"
init::V
"one state transition matrix per time"
trans_periodic::NTuple{L,M}
"one vector of observation distributions per time (must be amenable to `logdensityof` and `rand`)"
dists_periodic::NTuple{L,VD}
end

function Base.copy(phmm::PeriodicHMM)
return PeriodicHMM(
copy(phmm.init), copy(phmm.trans_periodic), copy(phmm.dists_periodic)
)
end
period(::PeriodicHMM{L}) where {L} = L

Base.length(phmm::PeriodicHMM) = length(phmm.init)
initialization(phmm::PeriodicHMM) = phmm.init
HMMs.initialization(phmm::PeriodicHMM) = phmm.init

function transition_matrix(phmm::PeriodicHMM{L}, t::Integer) where {L}
return phmm.trans_periodic[(t - 1) % L + 1]
function HMMs.transition_matrix(phmm::PeriodicHMM, t::Integer)
return phmm.trans_periodic[(t - 1) % period(hmm) + 1]
end

function obs_distributions(phmm::PeriodicHMM{L}, t::Integer) where {L}
return phmm.dists_periodic[(t - 1) % L + 1]
function HMMs.obs_distributions(phmm::PeriodicHMM, t::Integer)
return phmm.dists_periodic[(t - 1) % period(hmm) + 1]
end

## Fitting

struct BaumWelchStoragePeriodicHMM{L,O,R} <: AbstractBaumWelchStorage
obs_seqs_concat_periodic::NTuple{L,Vector{O}}
state_marginals_concat_periodic::NTuple{L,Matrix{R}}
end

function initialize_baum_welch(
::PeriodicHMM{L},
fb_storages::Vector{<:ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
) where {L}
obs_seqs_concat_periodic = ntuple(
l -> reduce(vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs)), L
)
state_marginals_concat_periodic = ntuple(
l -> reduce(hcat, fb_storages[k].γ[:, l:L:end] for k in eachindex(fb_storages)), L
)
return BaumWelchStoragePeriodicHMM(
obs_seqs_concat_periodic, state_marginals_concat_periodic
)
struct BaumWelchStoragePeriodicHMM <: HMMs.AbstractBaumWelchStorage end
function HMMs.initialize_baum_welch(::PeriodicHMM, fb_storages, obs_seqs)
return BaumWelchStoragePeriodicHMM()
end

function update_baum_welch!(
bw_storage::BaumWelchStoragePeriodicHMM{L},
fb_storages::Vector{<:ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
) where {L}
@unpack state_marginals_concat_periodic = bw_storage
for l in 1:L
tl = 1
for k in eachindex(obs_seqs, fb_storages)
@unpack γ = fb_storages[k]
γl = @view γ[:, l:L:end]
Tl = size(γl, 2)
state_marginals_concat_periodic[l][:, tl:(tl + Tl - 1)] .= γl
tl += Tl
end
end
return nothing
end

function fit_states!(
hmm::PeriodicHMM{L}, fb_storages::Vector{<:ForwardBackwardStorage}
) where {L}
function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwardStorage})
L = period(hmm)
# Reset
hmm.init .= zero(eltype(hmm.init))
hmm.init .= 0
for l in 1:L
hmm.trans_periodic[l] .= zero(eltype(hmm.trans_periodic[l]))
hmm.trans_periodic[l] .= 0
end
# Accumulate sufficient stats
for k in eachindex(fb_storages)
@unpack γ, ξ = fb_storages[k]
hmm.init .+= view(γ, :, 1)
for t in eachindex(ξ)
l = (t - 1) % L + 1
mynonzeros(hmm.trans_periodic[l]) .+= mynonzeros(ξ[t])
hmm.trans_periodic[l] .+= ξ[t]
end
end
# Normalize
sum_to_one!(hmm.init)
hmm.init ./= sum(hmm.init)
for l in 1:L
foreach(sum_to_one!, eachrow(hmm.trans_periodic[l]))
hmm.trans_periodic[l] ./= sum(hmm.trans_periodic[l]; dims=2)
end
return nothing
end

function fit_observations!(
hmm::PeriodicHMM{L}, bw_storage::BaumWelchStoragePeriodicHMM
) where {L}
@unpack obs_seqs_concat_periodic, state_marginals_concat_periodic = bw_storage
# Fit observation distributions
hmm::PeriodicHMM,
fb_storages::Vector{<:HMMs.ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
)
for l in 1:L
obs_seq_periodic = reduce(vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs))
state_marginals_periodic = reduce(
hcat, fb_storages[k].γ[:, l:L:end] for k in eachindex(fb_storages)
)
for i in 1:length(hmm)
fit_element_from_sequence!(
hmm.dists_periodic[l],
i,
obs_seqs_concat_periodic[l],
view(state_marginals_concat_periodic[l], i, :),
)
D = typeof(hmm.dists_periodic[l][i])
x = obs_seq_periodic
w = view(state_marginals_periodic, i, :)
hmm.dists_periodic[l][i] = fit(D, x, w)
end
end
return nothing
end

function StatsAPI.fit!(
hmm::PeriodicHMM{L},
bw_storage::BaumWelchStoragePeriodicHMM,
fb_storages::Vector{<:ForwardBackwardStorage},
hmm::PeriodicHMM,
::BaumWelchStoragePeriodicHMM,
fb_storages::Vector{<:HMMs.ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
) where {L}
update_baum_welch!(bw_storage, fb_storages, obs_seqs)
)
fit_states!(hmm, fb_storages)
fit_observations!(hmm, bw_storage)
fit_observations!(hmm, fb_storages, obs_seqs)
return nothing
end

# ## Example

N = 2 # Number of hidden states
L = 10 # Period of the HMM
T = 50_000 # Number of observation

function make_trans(l, L)
A = Matrix{Float64}(undef, 2, 2)
A[1, 1] = 0.25 + 0.1 + 0.5cos(2π / L * l + 1)^2
A[1, 2] = 0.25 - 0.1 + 0.5sin(2π / L * l + 1)^2
A[2, 2] = 0.25 + 0.2 + 0.5cos(2π / L * (l - L / 3))^2
A[2, 1] = 0.25 - 0.2 + 0.5sin(2π / L * (l - L / 3))^2
return A
end

function make_dists(l, L, N)
dists = [Normal(2i * cos(2π * l / L), i + cos(2π / L * (l - i / 2 + 1))^2) for i in 1:N]
return dists
end

init = ones(N) / N;
trans_periodic = ntuple(l -> make_trans(l, L), L);
dists_periodic = ntuple(l -> make_dists(l, L, N), L);

hmm = PeriodicHMM(init, trans_periodic, dists_periodic);

state_seq, obs_seq = rand(hmm, T);

hmm_est, logL_evolution = baum_welch(hmm, obs_seq; max_iterations=100);
length(logL_evolution)

## Plotting

p = [plot(; xlabel="l", title="transitions from state $i") for i in 1:N]
for i in 1:N, j in 1:N
plot!(
p[i],
1:L,
[transition_matrix(hmm, l)[i, j] for l in 1:L];
label="p$((i,j)) - true",
c=j,
)
plot!(
p[i],
1:L,
[transition_matrix(hmm_est, l)[i, j] for l in 1:L];
label="p$((i,j)) - est",
c=j,
s=:dash,
)
end
plot(p...; size=(1000, 500))

p = [plot(; xlabel="l", title="emissions from state $i") for i in 1:N]
for i in 1:N
plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].μ for l in 1:L]; label="μ - true", c=1)
plot!(
p[i],
1:L,
[obs_distributions(hmm_est, l)[i].μ for l in 1:L];
label="μ - est",
c=1,
s=:dash,
)
plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].σ for l in 1:L]; label="σ - true", c=2)
plot!(
p[i],
1:L,
[obs_distributions(hmm_est, l)[i].σ for l in 1:L];
label="σ - est",
c=2,
s=:dash,
)
end
plot(p...; size=(1000, 500))

0 comments on commit 9d50e54

Please sign in to comment.