Skip to content

Commit

Permalink
Give more information to fit!
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 31, 2023
1 parent 5abf0d6 commit 1dfdf90
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 88 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.3.1"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 4 additions & 2 deletions docs/src/tuto_custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ As for fitting, we simply ignore the initialization count and copy the rest of t

```@example tuto
function StatsAPI.fit!(
hmm::EquilibriumHMM{R,D}, init_count, trans_count, obs_seq, state_marginals
hmm::EquilibriumHMM{R,D}, obs_seqs, fbs
) where {R,D}
hmm.trans .= trans_count ./ sum(trans_count, dims=2)
obs_seqs_concat = reduce(vcat, obs_seqs)
state_marginals_concat = reduce(hcat, fb.γ for fb in fbs)
for i in 1:N
hmm.dists[i] = fit(D, obs_seq, state_marginals[i, :])
hmm.dists[i] = fit(D, obs_seqs_concat, state_marginals_concat[i, :])
end
end
```
Expand Down
4 changes: 2 additions & 2 deletions src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export LightDiagNormal
include("types/abstract_mc.jl")
include("types/mc.jl")
include("types/abstract_hmm.jl")
include("types/permuted_hmm.jl")
include("types/hmm.jl")

include("utils/check.jl")
Expand All @@ -53,7 +54,6 @@ include("inference/loglikelihoods.jl")
include("inference/forward.jl")
include("inference/viterbi.jl")
include("inference/forward_backward.jl")
include("inference/sufficient_stats.jl")
include("inference/baum_welch.jl")

if !isdefined(Base, :get_extension)
Expand All @@ -74,8 +74,8 @@ end
dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
hmm = HMM(p, A, dists)

obs_seqs = [last(rand(hmm, T)) for _ in 1:3]
nb_seqs = 3
obs_seqs = [last(rand(hmm, T)) for _ in 1:nb_seqs]
logdensityof(hmm, obs_seqs, nb_seqs)
forward(hmm, obs_seqs, nb_seqs)
viterbi(hmm, obs_seqs, nb_seqs)
Expand Down
16 changes: 7 additions & 9 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ function baum_welch!(
)
# Pre-allocate nearly all necessary memory
logB = loglikelihoods(hmm, obs_seqs[1])
fb = initialize_forward_backward(hmm, logB)
fb = forward_backward(hmm, logB)

logBs = Vector{typeof(logB)}(undef, length(obs_seqs))
fbs = Vector{typeof(fb)}(undef, length(obs_seqs))
logBs[1], fbs[1] = logB, fb
@threads for k in eachindex(obs_seqs)
logBs[k] = loglikelihoods(hmm, obs_seqs[k])
fbs[k] = forward_backward(hmm, logBs[k])
if k > 1
logBs[k] = loglikelihoods(hmm, obs_seqs[k])
fbs[k] = forward_backward(hmm, logBs[k])
end
end

init_count, trans_count = initialize_states_stats(fbs)
state_marginals_concat = initialize_observations_stats(fbs)
obs_seqs_concat = reduce(vcat, obs_seqs)
logL = loglikelihood(fbs)
logL_evolution = [logL]

Expand All @@ -30,9 +30,7 @@ function baum_welch!(
end

# M step
update_states_stats!(init_count, trans_count, fbs)
update_observations_stats!(state_marginals_concat, fbs)
fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat)
fit!(hmm, obs_seqs, fbs)

# Stopping criterion
if iteration > 1
Expand Down
40 changes: 0 additions & 40 deletions src/inference/sufficient_stats.jl

This file was deleted.

36 changes: 6 additions & 30 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Abstract supertype for an HMM amenable to simulation, inference and learning.
- `initial_distribution(hmm)`
- `transition_matrix(hmm)`
- `obs_distribution(hmm, i)`
- `fit!(hmm, init_count, trans_count, obs_seq, state_marginals)` (optional)
- `fit!(hmm, obs_seqs, fbs)` (optional)
# Applicable methods
Expand Down Expand Up @@ -48,6 +48,11 @@ The returned object `dist` must implement
"""
function obs_distribution end

"""
StatsAPI.fit!(hmm::AbstractHMM, obs_seqs, fbs)
"""
StatsAPI.fit! # TODO: docstring

function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
mc = MarkovChain(hmm)
state_seq = rand(rng, mc, T)
Expand All @@ -63,32 +68,3 @@ end
function MarkovChain(hmm::AbstractHMM)
return MarkovChain(initial_distribution(hmm), transition_matrix(hmm))
end

"""
PermutedHMM{H<:AbstractHMM}
Wrapper around an `AbstractHMM` that permutes its states.
This is computationally inefficient and mostly useful for evaluation.
# Fields
- `hmm:H`: the old HMM
- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old.
"""
struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM
hmm::H
perm::Vector{Int}
end

Base.length(p::PermutedHMM) = length(p.hmm)

HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm]

function HMMs.transition_matrix(p::PermutedHMM)
return transition_matrix(p.hmm)[p.perm, :][:, p.perm]
end

function HMMs.obs_distribution(p::PermutedHMM, i::Integer)
return obs_distribution(p.hmm, p.perm[i])
end
22 changes: 18 additions & 4 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,27 @@ obs_distribution(hmm::HMM, i::Integer) = hmm.dists[i]
Update `hmm` in-place based on information generated during forward-backward.
"""
function StatsAPI.fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals)
hmm.init .= init_count
function StatsAPI.fit!(hmm::HMM, obs_seqs, fbs)
# Initial distribution
hmm.init .= zero(eltype(hmm.init))
for k in eachindex(fbs)
@views hmm.init .+= fbs[k].γ[:, 1]
end
sum_to_one!(hmm.init)
hmm.trans .= trans_count
# Transition matrix
hmm.trans .= zero(eltype(hmm.trans))
for k in eachindex(fbs)
sum!(hmm.trans, fbs[k].ξ; init=false)
end
foreach(sum_to_one!, eachrow(hmm.trans))
# Observation distributions
obs_seqs_concat = reduce(vcat, obs_seqs) # TODO: allocation-free
state_marginals_concat = reduce(hcat, fb.γ for fb in fbs) # TODO: allocation-free
@show size(obs_seqs_concat) size(state_marginals_concat)
@views for i in eachindex(hmm.dists)
fit_element_from_sequence!(hmm.dists, i, obs_seq, state_marginals[i, :])
fit_element_from_sequence!(
hmm.dists, i, obs_seqs_concat, state_marginals_concat[i, :]
)
end
check_hmm(hmm)
return nothing
Expand Down
28 changes: 28 additions & 0 deletions src/types/permuted_hmm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
PermutedHMM{H<:AbstractHMM}
Wrapper around an `AbstractHMM` that permutes its states.
This is computationally inefficient and mostly useful for evaluation.
# Fields
- `hmm:H`: the old HMM
- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old.
"""
struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM
hmm::H
perm::Vector{Int}
end

Base.length(p::PermutedHMM) = length(p.hmm)

HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm]

function HMMs.transition_matrix(p::PermutedHMM)
return transition_matrix(p.hmm)[p.perm, :][:, p.perm]
end

function HMMs.obs_distribution(p::PermutedHMM, i::Integer)
return obs_distribution(p.hmm, p.perm[i])
end
4 changes: 4 additions & 0 deletions src/utils/lightdiagnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ struct LightDiagNormal{
logσ::V3
end

function Base.show(io::IO, dist::LightDiagNormal)
return print(io, "LightDiagNormal(μ=$(dist.μ), σ=$(dist.σ))")
end

function LightDiagNormal(μ, σ)
check_no_nan(μ)
check_positive(σ)
Expand Down

0 comments on commit 1dfdf90

Please sign in to comment.