Commit 27d6e13

Stuff

gdalle committed Nov 11, 2023
1 parent 967cf06 commit 27d6e13

Showing 15 changed files with 187 additions and 220 deletions.
1 change: 0 additions & 1 deletion Project.toml
@@ -30,7 +30,6 @@ DocStringExtensions = "0.9"
LinearAlgebra = "1.6"
PrecompileTools = "1.1"
Random = "1.6"
RequiredInterfaces = "0.1.3"
Requires = "1.3"
SimpleUnPack = "1.1"
StatsAPI = "1.6"
1 change: 0 additions & 1 deletion docs/Project.toml
@@ -4,7 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

27 changes: 24 additions & 3 deletions docs/src/api.md
@@ -40,17 +40,38 @@ fit!
```@docs
rand_prob_vec
rand_trans_mat
HiddenMarkovModels.fit_element_from_sequence!
HiddenMarkovModels.LightDiagNormal
```

## Internals
## In-place algorithms (internals)

### Storage types

```@docs
HiddenMarkovModels.ForwardStorage
HiddenMarkovModels.ViterbiStorage
HiddenMarkovModels.ForwardBackwardStorage
HiddenMarkovModels.BaumWelchStorage
HiddenMarkovModels.fit_element_from_sequence!
HiddenMarkovModels.LightDiagNormal
```

### Initializing storage

```@docs
HiddenMarkovModels.initialize_forward
HiddenMarkovModels.initialize_viterbi
HiddenMarkovModels.initialize_forward_backward
HiddenMarkovModels.initialize_baum_welch
HiddenMarkovModels.initialize_logL_evolution
```

### Modifying storage

```@docs
HiddenMarkovModels.forward!
HiddenMarkovModels.viterbi!
HiddenMarkovModels.forward_backward!
HiddenMarkovModels.baum_welch!
```
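
A minimal sketch of the allocate-once, fill-in-place workflow that this split suggests, using the forward algorithm as an example. It relies on the `initialize_forward`/`forward!` signatures that appear later in this diff; `hmm` (an `AbstractHMM`) and `obs_seq` (a vector of observations) are assumed to already exist.

```julia
import HiddenMarkovModels as HMMs

# Assumed inputs: `hmm::AbstractHMM` and `obs_seq::Vector` of observations.
storage = HMMs.initialize_forward(hmm, obs_seq)  # allocate the storage once
HMMs.forward!(storage, hmm, obs_seq)             # fill it in place
α, logL = storage.α, storage.logL[]              # filtered state marginals and loglikelihood
```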

## Notations
16 changes: 8 additions & 8 deletions ext/HiddenMarkovModelsChainRulesCoreExt.jl
@@ -1,7 +1,6 @@
module HiddenMarkovModelsChainRulesCoreExt

using ChainRulesCore:
ChainRulesCore, NoTangent, ZeroTangent, RuleConfig, rrule_via_ad, @not_implemented
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: logdensityof
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
@@ -13,20 +12,21 @@ function obs_logdensities_matrix(hmm::AbstractHMM, obs_seq::Vector)
return logB

[Codecov warning: added line ext/HiddenMarkovModelsChainRulesCoreExt.jl#L12 was not covered by tests]
end

function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq)
function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq::Vector)
p = initialization(hmm)
A = transition_matrix(hmm)
logB = obs_logdensities_matrix(hmm, obs_seq)
return p, A, logB
end

function ChainRulesCore.rrule(
rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq
rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq::Vector
)
@info "Chain rule used"
(p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq)
fb = HMMs.initialize_forward_backward(hmm, obs_seq)
HMMs.forward_backward!(fb, hmm, obs_seq)
@unpack α, β, γ, c, Bβ = fb
storage = HMMs.initialize_forward_backward(hmm, obs_seq)
HMMs.forward_backward!(storage, hmm, obs_seq)
@unpack logL, α, β, γ, c, Bβ = storage
T = length(obs_seq)

function logdensityof_hmm_pullback(ΔlogL)
@@ -42,7 +42,7 @@ function ChainRulesCore.rrule(
return Δlogdensityof, Δhmm, Δobs_seq
end

return fb.logL[], logdensityof_hmm_pullback
return logL[], logdensityof_hmm_pullback
end

end
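
For context, this `rrule` is what a ChainRules-aware reverse-mode backend hits when differentiating `logdensityof` with respect to the model parameters. A hedged usage sketch follows; Zygote is an assumption here (it is not part of this diff), and `hmm`/`obs_seq` are assumed to already exist.

```julia
using DensityInterface: logdensityof
import Zygote

# Assumed inputs: `hmm::AbstractHMM` and `obs_seq::Vector` of observations.
# Zygote should pick up the rrule above rather than differentiating through the
# forward-backward recursion itself (the `@info "Chain rule used"` line confirms this).
∇hmm, = Zygote.gradient(h -> logdensityof(h, obs_seq), hmm)
```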
7 changes: 1 addition & 6 deletions src/HiddenMarkovModels.jl
@@ -2,17 +2,12 @@
HiddenMarkovModels
A Julia package for HMM modeling, simulation, inference and learning.
# Exports
$(EXPORTS)
"""
module HiddenMarkovModels

using Base: RefValue
using Base.Threads: @threads
using DensityInterface:
DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using Distributions:
Distributions,
Categorical,
57 changes: 40 additions & 17 deletions src/inference/baum_welch.jl
@@ -22,6 +22,15 @@ struct BaumWelchStorage{R,M<:AbstractMatrix{R}}
limits::Vector{Int}
end

function check_nb_seqs(obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
if nb_seqs != length(obs_seqs)
throw(ArgumentError("Incoherent sizes provided: `nb_seqs != length(obs_seqs)`"))

[Codecov warning: added line src/inference/baum_welch.jl#L27 was not covered by tests]
end
end

"""
initialize_baum_welch(hmm, obs_seqs, nb_seqs)
"""
function initialize_baum_welch(
hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
)
@@ -36,6 +45,9 @@ function initialize_baum_welch(
return BaumWelchStorage(init_count, trans_count, state_marginals_concat, limits)
end

"""
initialize_logL_evolution(hmm, obs_seqs, nb_seqs; max_iterations)
"""
function initialize_logL_evolution(
hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer
)
@@ -47,16 +59,19 @@ function initialize_logL_evolution(
end

function update_sufficient_statistics!(
bw::BaumWelchStorage{R}, fbs::Vector{<:ForwardBackwardStorage}
bw_storage::BaumWelchStorage{R}, fb_storages::Vector{<:ForwardBackwardStorage}
) where {R}
@unpack init_count, trans_count, state_marginals_concat, limits = bw
@unpack init_count, trans_count, state_marginals_concat, limits = bw_storage
init_count .= zero(R)
trans_count .= zero(R)
state_marginals_concat .= zero(R)
for k in eachindex(fbs) # TODO: ThreadsX?
init_count .+= fbs[k].init_count
mynonzeros(trans_count) .+= mynonzeros(fbs[k].trans_count)
state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= fbs[k].γ
for k in eachindex(fb_storages) # TODO: ThreadsX?
@unpack γ, ξ = fb_storages[k]
init_count .+= @view γ[:, 1]
for t in eachindex(ξ)
mynonzeros(trans_count) .+= mynonzeros(ξ[t])
end
state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= γ
end
return nothing
end
@@ -76,15 +91,23 @@ function baum_welch_has_converged(
return false
end

function StatsAPI.fit!(hmm::AbstractHMM, bw::BaumWelchStorage, obs_seqs_concat::Vector)
return fit!(
hmm, bw.init_count, bw.trans_count, obs_seqs_concat, bw.state_marginals_concat
)
function StatsAPI.fit!(
hmm::AbstractHMM, bw_storage::BaumWelchStorage, obs_seqs_concat::Vector
)
@unpack init_count, trans_count, state_marginals_concat = bw_storage
return fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat)
end

"""
baum_welch!(
fb_storages, bw_storage, logL_evolution,
hmm, obs_seqs, obs_seqs_concat;
atol, max_iterations, loglikelihood_increasing
)
"""
function baum_welch!(
fbs::Vector{<:ForwardBackwardStorage},
bw::BaumWelchStorage,
fb_storages::Vector{<:ForwardBackwardStorage},
bw_storage::BaumWelchStorage,
logL_evolution::Vector,
hmm::AbstractHMM,
obs_seqs::Vector{<:Vector},
@@ -94,12 +117,12 @@ function baum_welch!(
loglikelihood_increasing::Bool,
)
for _ in 1:max_iterations
@threads for k in eachindex(obs_seqs, fbs)
forward_backward!(fbs[k], hmm, obs_seqs[k])
for k in eachindex(obs_seqs, fb_storages)
forward_backward!(fb_storages[k], hmm, obs_seqs[k])
end
update_sufficient_statistics!(bw, fbs)
push!(logL_evolution, sum(fb.logL[] for fb in fbs))
fit!(hmm, bw, obs_seqs_concat)
update_sufficient_statistics!(bw_storage, fb_storages)
push!(logL_evolution, sum(fb.logL[] for fb in fb_storages))
fit!(hmm, bw_storage, obs_seqs_concat)
check_hmm(hmm)
if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing)
break
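
Taken together, a minimal sketch of driving the in-place Baum-Welch loop by hand with the renamed storage arguments might look as follows. The keyword values are illustrative assumptions, the construction of `obs_seqs_concat` is an educated guess, and the exported high-level `baum_welch` wrapper (not shown in this diff) is presumably the usual entry point; note that `fit!` mutates `hmm` at every iteration.

```julia
import HiddenMarkovModels as HMMs

# Assumed inputs: `hmm::AbstractHMM` and `obs_seqs::Vector` of observation sequences.
nb_seqs = length(obs_seqs)
obs_seqs_concat = reduce(vcat, obs_seqs)

# Allocate every storage object up front, matching the signatures shown in this diff.
fb_storages = [HMMs.initialize_forward_backward(hmm, obs_seq) for obs_seq in obs_seqs]
bw_storage = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs)
logL_evolution = HMMs.initialize_logL_evolution(hmm, obs_seqs, nb_seqs; max_iterations=100)

# Run the in-place EM loop; `hmm` is updated through `fit!` at each iteration.
HMMs.baum_welch!(
    fb_storages, bw_storage, logL_evolution,
    hmm, obs_seqs, obs_seqs_concat;
    atol=1e-5, max_iterations=100, loglikelihood_increasing=true,
)
```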
62 changes: 20 additions & 42 deletions src/inference/forward.jl
@@ -7,7 +7,7 @@ This storage is relative to a single sequence.
# Fields
The only fields useful outside of the algorithm are `αₜ` and `logL`.
The only fields useful outside of the algorithm are `α` and `logL`; the rest does not belong to the public API.
$(TYPEDFIELDS)
"""
Expand All @@ -22,6 +22,9 @@ struct ForwardStorage{R}
α_next::Vector{R}
end

"""
initialize_forward(hmm, obs_seq)
"""
function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
N = length(hmm)
R = eltype(hmm, obs_seq[1])
@@ -30,15 +33,18 @@ function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
logb = Vector{R}(undef, N)
α = Vector{R}(undef, N)
α_next = Vector{R}(undef, N)
f = ForwardStorage(logL, logb, α, α_next)
return f
storage = ForwardStorage(logL, logb, α, α_next)
return storage
end

function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
"""
forward!(storage, hmm, obs_seq)
"""
function forward!(storage::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
T = length(obs_seq)
p = initialization(hmm)
A = transition_matrix(hmm)
@unpack logL, logb, α, α_next = f
@unpack logL, logb, α, α_next = storage

obs_logdensities!(logb, hmm, obs_seq[1])
check_right_finite(logb)
@@ -63,58 +69,30 @@ function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
return nothing
end

function forward!(
fs::Vector{<:ForwardStorage},
hmm::AbstractHMM,
obs_seqs::Vector{<:Vector},
nb_seqs::Integer,
)
check_nb_seqs(obs_seqs, nb_seqs)
@threads for k in eachindex(fs, obs_seqs)
forward!(fs[k], hmm, obs_seqs[k])
end
return nothing
end

"""
forward(hmm, obs_seq)
forward(hmm, obs_seqs, nb_seqs)
Run the forward algorithm to infer the current state of an HMM.
Run the forward algorithm to infer the current state of `hmm` after sequence `obs_seq`.
When applied on a single sequence, this function returns a tuple `(α, logL)` where
This function returns a tuple `(α, logL)` where
- `α[i]` is the posterior probability of state `i` at the end of the sequence
- `logL` is the loglikelihood of the sequence
When applied on multiple sequences, this function returns a vector of tuples.
"""
function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
check_nb_seqs(obs_seqs, nb_seqs)
fs = [initialize_forward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
forward!(fs, hmm, obs_seqs, nb_seqs)
return [(f.α, f.logL[]) for f in fs]
end

function forward(hmm::AbstractHMM, obs_seq::Vector)
return only(forward(hmm, [obs_seq], 1))
storage = initialize_forward(hmm, obs_seq)
forward!(storage, hmm, obs_seq)
return storage.α, storage.logL[]
end

"""
logdensityof(hmm, obs_seq)
logdensityof(hmm, obs_seqs, nb_seqs)
Run the forward algorithm to compute the posterior loglikelihood of observations for an HMM.
Run the forward algorithm to compute the posterior loglikelihood of sequence `obs_seq` for `hmm`.
Whether it is applied on one or multiple sequences, this function returns a number.
This function returns a number.
"""
function DensityInterface.logdensityof(
hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
)
logαs_and_logLs = forward(hmm, obs_seqs, nb_seqs)
return sum(last, logαs_and_logLs)
end

function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector)
return logdensityof(hmm, [obs_seq], 1)
_, logL = forward(hmm, obs_seq)
return logL
end
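
With the multi-sequence methods removed, a usage sketch of this file reduces to the two single-sequence entry points. It assumes `hmm` and `obs_seq` already exist and that `forward` is exported (as the docstrings above suggest); `logdensityof` is imported from DensityInterface, which this package builds on.

```julia
using HiddenMarkovModels
using DensityInterface: logdensityof

# Assumed inputs: `hmm::AbstractHMM` and `obs_seq::Vector` of observations.
α, logL = forward(hmm, obs_seq)         # posterior state probabilities at the last time step, plus loglikelihood
logL_only = logdensityof(hmm, obs_seq)  # same forward pass, returning only the loglikelihood
```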