Skip to content

Commit

Permalink
Periodic HMM
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Dec 1, 2023
1 parent 5861b5d commit 325417b
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 119 deletions.
1 change: 0 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ pages = [
"Basics" => joinpath("examples", "basics.md"),
"Distributions" => joinpath("examples", "distributions.md"),
"Controlled" => joinpath("examples", "controlled.md"),
"Periodic" => joinpath("examples", "periodic.md"),
],
"API reference" => "api.md",
"Advanced" => [
Expand Down
1 change: 1 addition & 0 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ first(logL_evolution), last(logL_evolution)
#-

cat(hmm_est.trans, hmm.trans; dims=3)
@test hmm_est.trans hmm.trans atol = 1e-1 #src
62 changes: 62 additions & 0 deletions examples/distributions.jl
Original file line number Diff line number Diff line change
@@ -1 +1,63 @@
# # Distributions

using DensityInterface
using Distributions
using HiddenMarkovModels
using Random: Random, AbstractRNG
#md using Plots
using StatsAPI
using Test #src

#-

mutable struct PoissonProcess{R}
λ::R
end

DensityInterface.DensityKind(::PoissonProcess) = HasDensity()

function Random.rand(rng::AbstractRNG, pp::PoissonProcess)
nb_events = rand(rng, Poisson(pp.λ))
event_times = rand(rng, Uniform(0, 1), nb_events)
return event_times
end

function DensityInterface.logdensityof(pp::PoissonProcess, event_times::Vector)
return -pp.λ + length(event_times) * log(pp.λ)
end

function StatsAPI.fit!(pp::PoissonProcess, x, w)
pp.λ = sum(length(xᵢ) * wᵢ for (xᵢ, wᵢ) in zip(x, w)) / sum(w)
return nothing
end

#-

init = [0.3, 0.7]
trans = [0.8 0.2; 0.1 0.9]
dists = [PoissonProcess(1.0), PoissonProcess(5.0)]

hmm = HMM(init, trans, dists)

T = 100
state_seq, obs_seq = rand(hmm, T)

#-

#md scatter(reduce(vcat, t .+ obs_seq[t] for t in 1:T), state_seq .+ 0.02 * rand(T))

#-

forward_backward(hmm, obs_seq)

#-

init_guess = [0.5, 0.5]
trans_guess = [0.5 0.5; 0.5 0.5]
dists_guess = [PoissonProcess(2.0), PoissonProcess(3.0)]
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq)

@test hmm_est.dists[1].λ hmm.dists[1].λ atol = 0.5 #src
@test hmm_est.dists[2].λ hmm.dists[2].λ atol = 0.5 #src
3 changes: 2 additions & 1 deletion src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ include("inference/logdensity.jl")
include("inference/chainrules.jl")

include("types/hmm.jl")
include("types/periodic_hmm.jl")

if !isdefined(Base, :get_extension)
function __init__()
Expand All @@ -56,6 +57,6 @@ if !isdefined(Base, :get_extension)
end
end

# include("precompile.jl")
include("precompile.jl")

end
2 changes: 1 addition & 1 deletion src/inference/logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector)
end

function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seqs::MultiSeq)
return sum(logdensity(hmm, obs_seqs[k]) for k in eachindex(obs_seqs))
return sum(logdensityof(hmm, obs_seqs[k]) for k in eachindex(obs_seqs))
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@compile_workload begin
N, D, T = 3, 2, 100
N, D, T = 3, 2, 10
init = rand_prob_vec(N)
trans = rand_trans_mat(N)
dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
Expand Down
4 changes: 2 additions & 2 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ StatsAPI.fit! # TODO: complete
Simulate `hmm` for `T` time steps.
"""
function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
dummy_log_probas = fill(-Inf, length(hmm))

init = initialization(hmm)
Expand All @@ -125,4 +125,4 @@ function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
return (; state_seq=state_seq, obs_seq=obs_seq)
end

Base.rand(hmm::AbstractHMM, T::Integer) = rand(default_rng(), hmm, T)
Random.rand(hmm::AbstractHMM, T::Integer) = rand(default_rng(), hmm, T)
111 changes: 26 additions & 85 deletions src/types/periodic_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,116 +8,57 @@ end

#-

period(hmm::HMM) = 1
period(hmm::PeriodicHMM) = length(hmm.trans_periodic)

Base.length(phmm::PeriodicHMM) = length(phmm.init)
HMMs.initialization(phmm::PeriodicHMM) = phmm.init
Base.length(hmm::PeriodicHMM) = length(hmm.init)
initialization(hmm::PeriodicHMM) = hmm.init

function HMMs.transition_matrix(phmm::PeriodicHMM, t::Integer)
return phmm.trans_periodic[(t - 1) % period(hmm) + 1]
function transition_matrix(hmm::PeriodicHMM, t::Integer)
return hmm.trans_periodic[(t - 1) % period(hmm) + 1]
end

function HMMs.obs_distributions(phmm::PeriodicHMM, t::Integer)
return phmm.dists_periodic[(t - 1) % period(hmm) + 1]
function obs_distributions(hmm::PeriodicHMM, t::Integer)
return hmm.dists_periodic[(t - 1) % period(hmm) + 1]
end

## Fitting

function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwardStorage})
function StatsAPI.fit!(hmm::PeriodicHMM, bw_storage::BaumWelchStorage, obs_seqs::MultiSeq)
@unpack fb_storages, obs_seqs_concat, state_marginals_concat, seq_limits = bw_storage
L = period(hmm)
hmm.init .= 0
# States
hmm.init .= zero(eltype(hmm.init))
for l in 1:L
hmm.trans_periodic[l] .= 0
hmm.trans_periodic[l] .= zero(eltype(hmm.trans_periodic[l]))
end
for k in eachindex(fb_storages)
@unpack γ, ξ = fb_storages[k]
hmm.init .+= view(γ, :, 1)
for t in eachindex(ξ)
l = (t - 1) % L + 1
hmm.trans_periodic[l] .+= ξ[t]
mynonzeros(hmm.trans_periodic[l]) .+= mynonzeros(ξ[t])
end
end
hmm.init ./= sum(hmm.init)
sum_to_one!(hmm.init)
for l in 1:L
hmm.trans_periodic[l] ./= sum(hmm.trans_periodic[l]; dims=2)
foreach(sum_to_one!, eachrow(hmm.trans_periodic[l]))
end
return nothing
end

#-

function fit_observations!(
hmm::PeriodicHMM,
fb_storages::Vector{<:HMMs.ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
)
L = period(hmm)
# Observations
for l in 1:L
indices_l = reduce(
vcat, (seq_limits[k] + l):L:seq_limits[k + 1] for k in eachindex(obs_seqs)
) # TODO: only allocating line if I'm right
obs_seq_periodic = view(obs_seqs_concat, indices_l)
state_marginals_periodic = view(state_marginals_concat, :, indices_l)
for i in 1:length(hmm)
obs_seq_periodic = reduce(
vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs)
fit_element_from_sequence!(
hmm.dists_periodic[l],
i,
obs_seq_periodic,
view(state_marginals_periodic, i, :),
)
state_marginals_periodic = reduce(
vcat, fb_storages[k].γ[i, l:L:end] for k in eachindex(fb_storages)
)
D = typeof(hmm.dists_periodic[l][i])
hmm.dists_periodic[l][i] = fit(D, obs_seq_periodic, state_marginals_periodic)
end
end
return nothing
end

#-

function StatsAPI.fit!(
hmm::PeriodicHMM,
::BaumWelchStoragePeriodicHMM,
fb_storages::Vector{<:HMMs.ForwardBackwardStorage},
obs_seqs::Vector{<:Vector},
)
fit_states!(hmm, fb_storages)
fit_observations!(hmm, fb_storages, obs_seqs)
return nothing
end

# ## Example

N = 2
T = 1000

init = ones(N) / N;
trans_periodic = (
[0.9 0.1; 0.1 0.9], #
[0.8 0.2; 0.2 0.8], #
[0.7 0.3; 0.3 0.7],
);
dists_periodic = (
[Normal(0), Normal(4)], #
[Normal(2), Normal(6)], #
[Normal(4), Normal(8)],
);

hmm = PeriodicHMM(init, trans_periodic, dists_periodic);

#-

state_seq, obs_seq = rand(hmm, T);
hmm_est, logL_evolution = baum_welch(hmm, obs_seq);

#md plot(logL_evolution)

#-

cat(hmm_est.init, hmm.init; dims=3)

#-

cat(hmm_est.trans_periodic[1], hmm.trans_periodic[1]; dims=3)
cat(hmm_est.trans_periodic[2], hmm.trans_periodic[2]; dims=3)
cat(hmm_est.trans_periodic[3], hmm.trans_periodic[3]; dims=3)

#-

cat(hmm_est.dists_periodic[1], hmm.dists_periodic[1]; dims=3)
cat(hmm_est.dists_periodic[2], hmm.dists_periodic[2]; dims=3)
cat(hmm_est.dists_periodic[3], hmm.dists_periodic[3]; dims=3)
2 changes: 1 addition & 1 deletion src/utils/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Default behavior:
fit!(dists[i], x, w)
Specializatoin for Distributions.jl (in the package extension)
Specialization for Distributions.jl (in the package extension)
dists[i] = fit(eltype(dists), x, w)
Expand Down
31 changes: 31 additions & 0 deletions src/utils/probvec_transmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,34 @@ rand_trans_mat(::Type{R}, N::Integer) where {R} = rand_trans_mat(default_rng(),

rand_prob_vec(N::Integer) = rand_prob_vec(default_rng(), N)
rand_trans_mat(N::Integer) = rand_trans_mat(default_rng(), N)

"""
project_prob_vec(v)
Compute the Euclidean projection of a vector `v` on the probability simplex.
Reference: <https://arxiv.org/abs/1602.02068>.
"""
function project_prob_vec(v::AbstractVector{R}) where {R}
d = length(v)
v_sorted = sort(v; rev=true)
v_sorted_cumsum = cumsum(v_sorted)
k = maximum(j for j in 1:d if (1 + j * v_sorted[j]) > v_sorted_cumsum[j])
τ = (v_sorted_cumsum[k] - 1) / k
p = v .- τ
p .= max.(p, zero(R))
return p
end

"""
project_trans_mat(M)
Compute the row-wise Euclidean projection of a matrix `M` on the space of transition matrices.
"""
function project_trans_mat(M::AbstractMatrix)
A = copy(M)
for v in eachrow(A)
v .= project_prob_vec(v)
end
return A
end
Loading

0 comments on commit 325417b

Please sign in to comment.