Dirichlet prior
gdalle committed Dec 8, 2023
1 parent 0979801 commit c1a5327
Showing 4 changed files with 77 additions and 63 deletions.
52 changes: 26 additions & 26 deletions README.md
@@ -10,24 +10,7 @@
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![JET](https://img.shields.io/badge/%E2%9C%88%EF%B8%8F%20tested%20with%20-%20JET.jl%20-%20red)](https://github.com/aviatesk/JET.jl)

A Julia package for HMM modeling, simulation, inference and learning.

## Mathematical background

[Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) are a statistical modeling framework that is ubiquitous in signal processing, bioinformatics and plenty of other fields. They capture the distribution of an observation sequence $(Y_t)$ by assuming the existence of a latent state sequence $(X_t)$ such that:

* the state follows a (discrete time, discrete space) Markov chain $\mathbb{P}_\theta(X_t | X_{t-1})$
* the observation distribution is determined at each time by the state $\mathbb{P}_\theta(Y_t | X_t)$

HMMs are associated with several statistical problems, each of which has an efficient solution algorithm that our package implements:

| Problem | Goal | Algorithm |
| ---------- | --------------------------------------------------------------------------------------------------------- | ---------------- |
| Evaluation | Likelihood of the observation sequence $\mathbb{P}_\theta(Y_{1:T})$ | Forward |
| Filtering | Non-anticipative state marginals $\mathbb{P}_\theta(X_t \vert Y_{1:t})$ | Forward |
| Smoothing | State marginals $\mathbb{P}_\theta(X_t \vert Y_{1:T})$ | Forward-backward |
| Decoding | Most likely state sequence $\underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}_\theta(X_{1:T} \vert Y_{1:T})$ | Viterbi |
| Learning | Maximum likelihood parameter $\underset{\theta}{\mathrm{argmax}}~\mathbb{P}_\theta(Y_{1:T})$ | Baum-Welch |
A Julia package for simulation, inference and learning of [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model).

## Getting started

@@ -37,7 +20,7 @@ This package can be installed using Julia's package manager:
pkg> add HiddenMarkovModels
```

Then, you can create your first HMM as follows:
Then, you can create your first model as follows:

```julia
using Distributions, HiddenMarkovModels
@@ -49,21 +32,38 @@ hmm = HMM(init, trans, dists)

Take a look at the [documentation](https://gdalle.github.io/HiddenMarkovModels.jl/stable/) to know what to do next!

## Some background

HMMs are a widely used modeling framework in signal processing, bioinformatics and plenty of other fields.
They capture the distribution of an observation sequence $(Y_t)$ by assuming the existence of a latent state sequence $(X_t)$ such that:

* the state follows a (discrete time, discrete space) Markov chain $\mathbb{P}_\theta(X_t | X_{t-1})$
* the observation distribution is determined at each time by the state $\mathbb{P}_\theta(Y_t | X_t)$
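
The two assumptions above translate directly into a sampling procedure. The following sketch is independent of the package: it relies only on Julia's standard library, uses discrete observations for simplicity, and every name in it (`sample_categorical`, `simulate_hmm`, the matrices `trans` and `emit`) is a hypothetical choice made for illustration.

```julia
using Random

# Draw an index from a probability vector `p` (assumed to sum to 1).
function sample_categorical(rng::AbstractRNG, p::AbstractVector{<:Real})
    u = rand(rng)
    cumulative = zero(eltype(p))
    for (i, p_i) in enumerate(p)
        cumulative += p_i
        u <= cumulative && return i
    end
    return length(p)  # guard against rounding when u is very close to 1
end

# Simulate T steps: X_1 ~ init, X_t ~ trans[X_{t-1}, :], Y_t ~ emit[X_t, :].
function simulate_hmm(rng, init, trans, emit, T)
    states = Vector{Int}(undef, T)
    obs = Vector{Int}(undef, T)
    for t in 1:T
        p = t == 1 ? init : view(trans, states[t - 1], :)
        states[t] = sample_categorical(rng, p)
        obs[t] = sample_categorical(rng, view(emit, states[t], :))
    end
    return states, obs
end

rng = MersenneTwister(63)
init = [0.8, 0.2]            # initial state probabilities
trans = [0.7 0.3; 0.2 0.8]   # hypothetical transition matrix (rows sum to 1)
emit = [0.9 0.1; 0.3 0.7]    # hypothetical emission probabilities (rows sum to 1)
states, obs = simulate_hmm(rng, init, trans, emit, 20)
```

Each state is drawn from the row of `trans` indexed by the previous state, and each observation from the row of `emit` indexed by the current state.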

Each of the problems below has an efficient solution algorithm which our package implements:

| Problem | Goal | Algorithm |
| ---------- | --------------------------------------------------------------------------------------------------------- | ---------------- |
| Evaluation | Likelihood of the observation sequence $\mathbb{P}_\theta(Y_{1:T})$ | Forward |
| Filtering | Non-anticipative state marginals $\mathbb{P}_\theta(X_t \vert Y_{1:t})$ | Forward |
| Smoothing | State marginals $\mathbb{P}_\theta(X_t \vert Y_{1:T})$ | Forward-backward |
| Decoding | Most likely state sequence $\underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}_\theta(X_{1:T} \vert Y_{1:T})$ | Viterbi |
| Learning | Maximum likelihood parameter $\underset{\theta}{\mathrm{argmax}}~\mathbb{P}_\theta(Y_{1:T})$ | Baum-Welch |
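
As an example of the first two rows of the table, here is a minimal sketch of the forward recursion for discrete observations. It is written from the textbook definition with hypothetical names (`forward_sketch`, `emit`) and made-up parameter values; it is not the package's implementation.

```julia
# Forward recursion for an HMM with discrete observations.
# init[i]     = P(X_1 = i)
# trans[i, j] = P(X_{t+1} = j | X_t = i)
# emit[i, y]  = P(Y_t = y | X_t = i)
function forward_sketch(init, trans, emit, obs)
    N, T = length(init), length(obs)
    α = zeros(N, T)
    logL = 0.0
    for t in 1:T
        # Predict the state, then weight by the likelihood of the observation.
        pred = t == 1 ? init : trans' * α[:, t - 1]
        α[:, t] = pred .* emit[:, obs[t]]
        # Normalizing yields P(X_t | Y_{1:t}); the normalizers multiply to P(Y_{1:T}).
        c = sum(α[:, t])
        α[:, t] ./= c
        logL += log(c)
    end
    return α, logL
end

init = [0.8, 0.2]
trans = [0.7 0.3; 0.2 0.8]   # hypothetical values
emit = [0.9 0.1; 0.3 0.7]    # hypothetical values
obs = [1, 2, 2]
α, logL = forward_sketch(init, trans, emit, obs)
```

The per-step normalization keeps the recursion numerically stable on long sequences, which is why the loglikelihood is accumulated as a sum of logs rather than as a product of probabilities.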

## Main features

This package is **generic**.
Observations can be arbitrary Julia objects, not just scalars or arrays, because their distributions only need to implement `rand(rng, dist)` and `logdensityof(dist, x)` ([DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl)).
Number types are not restricted to floating point, and automatic differentiation is supported in forward mode ([ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)).
Observations can be arbitrary Julia objects, not just scalars or arrays.
Number types are not restricted to floating point, which enables automatic differentiation.
Time-heterogeneous or controlled HMMs are supported out of the box.

This package is **fast**.
All the inference functions have allocation-free versions, which leverage efficient linear algebra subroutines.
We will include extensive benchmarks against Julia and Python competitors ([BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) + [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl)).
We will include extensive benchmarks against Julia and Python competitors.

This package is **reliable**.
It gives the same results as the previous reference package ([HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl)) up to numerical accuracy.
The test suite incorporates quality checks ([Aqua.jl](https://github.com/JuliaTesting/Aqua.jl)), as well as type stability analysis ([JET.jl](https://github.com/aviatesk/JET.jl)).
A detailed documentation will help you find the functions you need.
It gives the same results as the previous reference package up to numerical accuracy.
The test suite incorporates quality checks as well as type stability and allocation analysis.

## Contributing

@@ -72,5 +72,5 @@ Once the issue receives positive feedback, feel free to try and fix it with a pu

## Acknowledgements

A big thank you to [Maxime Mouchet](https://www.maxmouchet.com/) and [Jacob Schreiber](https://jmschrei.github.io/), the respective lead devs of [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl) and [pomegranate](https://github.com/jmschrei/pomegranate), for their help and advice.
A big thank you to [Maxime Mouchet](https://www.maxmouchet.com/) and [Jacob Schreiber](https://jmschrei.github.io/), the respective lead devs of alternative packages [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl) and [pomegranate](https://github.com/jmschrei/pomegranate), for their help and advice.
Logo by [Clément Mantoux](https://cmantoux.github.io/) based on a portrait of [Andrey Markov](https://en.wikipedia.org/wiki/Andrey_Markov).
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -43,14 +43,14 @@ end

pages = [
"Home" => "index.md",
"API reference" => "api.md",
"Tutorials" => [
"Basics" => joinpath("examples", "basics.md"),
"Interfaces" => joinpath("examples", "interfaces.md"),
"Autodiff" => joinpath("examples", "autodiff.md"),
"Periodic" => joinpath("examples", "periodic.md"),
"Controlled" => joinpath("examples", "controlled.md"),
],
"API reference" => "api.md",
"Advanced" => [
"Alternatives" => "alternatives.md",
"Debugging" => "debugging.md",
35 changes: 18 additions & 17 deletions examples/basics.jl
@@ -19,12 +19,12 @@ Random.seed!(rng, 63);
# ## Model construction

#=
The package provides a versatile `HMM` type with three attributes:
The package provides a versatile [`HMM`](@ref) type with three attributes:
- a vector of state initialization probabilities
- a matrix of state transition probabilities
- a vector of observation distributions, one for each state
Here we keep it simple by using Distributions.jl, but there are other ways.
We keep it simple for now by leveraging Distributions.jl.
=#

init = [0.8, 0.2]
@@ -35,14 +35,13 @@ hmm = HMM(init, trans, dists);
# ## Simulation

#=
You can simulate a pair of state and observation sequences by specifying how long you want them to be.
You can simulate a pair of state and observation sequences with [`rand`](@ref) by specifying how long you want them to be.
=#

state_seq, obs_seq = rand(rng, hmm, 20);

#=
Note that the observation sequence is a vector, whose elements have whatever type an observation distribution returns when sampled.
As we will see, this can be more generic than just numbers or arrays.
=#

state_seq[1], obs_seq[1]
@@ -54,7 +53,7 @@ In practical applications, the state sequence is not known, which is why we need
# ## Inference

#=
The Viterbi algorithm returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\log \mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$.
The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\log \mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$.
=#

best_state_seq, best_joint_loglikelihood = viterbi(hmm, obs_seq);
@@ -66,7 +65,7 @@ As we can see, it is very close to the true state sequence, but not necessarily
vcat(state_seq', best_state_seq')

#=
The forward algorithm returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
=#

filtered_state_marginals, obs_seq_loglikelihood1 = forward(hmm, obs_seq);
@@ -79,7 +78,7 @@ This is particularly useful to infer the marginal distribution of the last state
filtered_state_marginals[:, end]

#=
Conversely, the forward-backward algorithm returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
=#

smoothed_state_marginals, obs_seq_loglikelihood2 = forward_backward(hmm, obs_seq);
@@ -93,7 +92,7 @@ Note that forward and forward-backward only coincide at the last time step.
collect(zip(filtered_state_marginals, smoothed_state_marginals))

#=
Finally, we provide a thin wrapper around the forward algorithm for observation sequence loglikelihoods $\log \mathbb{P}(Y_{1:T})$.
Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward algorithm for observation sequence loglikelihoods $\log \mathbb{P}(Y_{1:T})$.
=#

logdensityof(hmm, obs_seq)
@@ -113,7 +112,7 @@ logdensityof(hmm, obs_seq, best_state_seq)
# ## Learning

#=
The Baum-Welch algorithm is a variant of Expectation-Maximization, designed specifically to estimate HMM parameters.
The Baum-Welch algorithm ([`baum_welch`](@ref)) is a variant of Expectation-Maximization, designed specifically to estimate HMM parameters.
Since it is a local optimization procedure, it requires a starting point that is close enough to the true model.
=#

@@ -123,11 +122,11 @@ dists_guess = [MvNormal(-0.5 * ones(3), I), MvNormal(+0.5 * ones(3), I)]
hmm_guess = HMM(init_guess, trans_guess, dists_guess);

#=
Now we can estimate parameters based on a slightly longer sequence.
Let's estimate parameters based on a slightly longer sequence.
=#

_, long_obs_seq = rand(rng, hmm, 1000)
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, long_obs_seq);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq);

#=
An essential guarantee of this algorithm is that the loglikelihood of the observation sequence keeps increasing as the model improves.
@@ -166,12 +165,13 @@ In many applications, we have access to various observation sequences of differe
=#

_, obs_seq2 = rand(rng, hmm, 30)
obs_seqs = [obs_seq, obs_seq2];
_, obs_seq3 = rand(rng, hmm, 10)
obs_seqs = [obs_seq, obs_seq2, obs_seq3];

#=
Every algorithm in the package accepts multiple sequences in a concatenated form.
The user must also specify where each sequence ends in the concatenated vector, by passing `seq_ends` as a keyword argument.
Otherwise, the input would be treated as a single observation sequence, which is mathematically incorrect.
Otherwise, the input will be treated as a single observation sequence, which is mathematically incorrect.
=#

obs_seq_concat = reduce(vcat, obs_seqs)
@@ -184,20 +184,21 @@ The outputs of inference algorithms are then concatenated, and the associated lo
best_state_seq_concat, _ = viterbi(hmm, obs_seq_concat; seq_ends);

#=
The function `seq_limits` returns the start and end indices of a given sequence in the concatenated vector.
The function [`seq_limits`](@ref) returns the start and end indices of a given sequence in the concatenated vector.
It can be used to untangle the results.
=#

start2, stop2 = seq_limits(seq_ends, 2)
best_state_seq_concat[start2:stop2] == first(viterbi(hmm, obs_seq2))
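
#=
As an illustration of the bookkeeping involved, here is a minimal sketch on plain integer vectors.
It assumes that `seq_ends` holds the cumulative lengths of the sequences, and the helper `toy_seq_limits` is a hypothetical stand-in written for this example, not the package's implementation of `seq_limits`.

```julia
toy_seqs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
toy_concat = reduce(vcat, toy_seqs)
toy_seq_ends = cumsum(length.(toy_seqs))  # [3, 5, 9]

# Sequence k starts right after the end of sequence k - 1.
toy_seq_limits(ends, k) = (k == 1 ? 1 : ends[k - 1] + 1, ends[k])

toy_start2, toy_stop2 = toy_seq_limits(toy_seq_ends, 2)
toy_concat[toy_start2:toy_stop2] == toy_seqs[2]  # true
```
=#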

#=
While inference algorithms can also be run separately on each sequence without changing the results, considering multiple sequences together is nontrivial for Baum-Welch, and the package takes care of it automatically.
While inference algorithms can also be run separately on each sequence without changing the results, considering multiple sequences together is nontrivial for Baum-Welch.
That is why the package takes care of it automatically.
=#

baum_welch(hmm_guess, obs_seq_concat; seq_ends);
hmm_est_concat, _ = baum_welch(hmm_guess, obs_seq_concat; seq_ends);

# ## Tests #src

control_seq, seq_ends = fill(nothing, 1000), 100:10:1000 #src
control_seq, seq_ends = fill(nothing, 1000), 100:100:1000 #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=1e-1) #src
51 changes: 32 additions & 19 deletions examples/interfaces.jl
@@ -1,10 +1,11 @@
# # Interfaces

#=
Here we discuss how to extend the observation distributions or HMM behavior to fit specific needs.
Here we discuss how to extend the observation distributions or HMM behavior to satisfy specific needs.
=#

using DensityInterface
using Distributions
using HiddenMarkovModels
using HiddenMarkovModels: test_coherent_algorithms #src
using LinearAlgebra
@@ -28,8 +29,6 @@ They only need to implement three methods:
In addition, observations can be of arbitrary Julia types.
So let's construct a distribution that generates stuff.
If you want more sophisticated examples, check out the definitions of `HiddenMarkovModels.LightDiagNormal` and `HiddenMarkovModels.LightCategorical`, which are designed to be fast and allocation-free.
=#

struct Stuff{T}
@@ -55,14 +54,11 @@ end

#=
It is important to declare to DensityInterface.jl that the custom distribution has a density, thanks to the following trait.
The logdensity itself can be computed up to an additive constant without issue.
=#

DensityInterface.DensityKind(::StuffDist) = HasDensity()

#=
The logdensity itself can be computed up to an additive constant without issue.
=#

function DensityInterface.logdensityof(dist::StuffDist, obs::Stuff)
return -abs2(obs.quantity - dist.quantity_mean)
end
@@ -114,10 +110,9 @@ hmm_guess = HMM(init_guess, trans_guess, dists_guess);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq)
obs_distributions(hmm_est)

# ## Tests #src

control_seq, seq_ends = fill(nothing, 1000), 100:10:1000 #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.2, init=false) #src
#=
If you want more sophisticated examples, check out [`HiddenMarkovModels.LightDiagNormal`](@ref) and [`HiddenMarkovModels.LightCategorical`](@ref), which are designed to be fast and allocation-free.
=#

# ## Creating a new HMM type

@@ -126,13 +121,14 @@ In some scenarios, the vanilla Baum-Welch algorithm is not exactly what we want.
For instance, we might have a prior on the parameters of our model, which we want to apply during the fitting step of the iterative procedure.
Then we need to create a new type that satisfies the `AbstractHMM` interface.
Let's make a simpler version of the built-in `HMM`, with a prior saying that each transition has been observed exactly once.
Let's make a simpler version of the built-in `HMM`, with a prior saying that each transition has already been observed a certain number of times.
=#

struct PriorHMM{T,D} <: AbstractHMM
init::Vector{T}
trans::Matrix{T}
dists::Vector{D}
trans_prior_count::Int
end

#=
@@ -144,16 +140,17 @@ HiddenMarkovModels.transition_matrix(hmm::PriorHMM) = hmm.trans
HiddenMarkovModels.obs_distributions(hmm::PriorHMM) = hmm.dists

#=
In addition, we want to overload `logdensityof` to specify our prior loglikelihood.
In addition, we want to overload [`logdensityof(hmm)`](@ref) to specify our prior loglikelihood.
It corresponds to a Dirichlet distribution over each row of the transition matrix, where each Dirichlet parameter is one plus the number of times the corresponding transition has been observed.
=#

function DensityInterface.logdensityof(hmm::PriorHMM)
prior = Dirichlet(ones(length(hmm)))
prior = Dirichlet(fill(hmm.trans_prior_count + 1, length(hmm)))
return sum(logdensityof(prior, row) for row in eachrow(transition_matrix(hmm)))
end

#=
And finally, we redefine the specific method of `fit!` that is used during Baum-Welch.
And finally, we redefine the specific method of `fit!` that is used during Baum-Welch: [`fit!(hmm, obs_seq; control_seq, seq_ends, fb_storage)`](@ref).
It accepts the same inputs as `baum_welch` for multiple sequences (disregard `control_seq` for now), and an additional `fb_storage` containing the results of the forward-backward algorithm.
The goal is to modify `hmm` in-place to update its parameters with their current maximum likelihood estimates.
@@ -168,7 +165,7 @@ function StatsAPI.fit!(
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
)
hmm.init .= 0
hmm.trans .= 1 # this is where the prior comes in
hmm.trans .= hmm.trans_prior_count
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
hmm.init .+= fb_storage.γ[:, t1]
@@ -178,7 +175,7 @@ function StatsAPI.fit!(
hmm.trans ./= sum(hmm.trans; dims=2)

for i in 1:length(hmm)
weight_seq = storage.γ[:, i]
weight_seq = fb_storage.γ[i, :]
fit!(hmm.dists[i], obs_seq, weight_seq)
end
return nothing
@@ -189,6 +186,22 @@ Some distributions, such as those from Distributions.jl
- do not support in-place fitting
- might expect different formats, e.g. higher-order arrays instead of a vector of objects
The function `HiddenMarkovModels.fit_in_sequence!` is a replacement for `fit!` which you can overload at will.
It is already designed to handle Distributions.jl.
The function [`HiddenMarkovModels.fit_in_sequence!`](@ref) is a replacement for `fit!` that is designed to handle Distributions.jl.
You can overload it for your own objects too.
=#

trans_prior_count = 10
prior_hmm_guess = PriorHMM(init_guess, trans_guess, dists_guess, trans_prior_count)

prior_hmm_est, prior_logl_evolution = baum_welch(prior_hmm_guess, obs_seq)

#=
As we can see, the transition matrix for our Bayesian version is slightly more spread out, although this effect would nearly disappear with enough data.
=#

cat(transition_matrix(hmm_est), transition_matrix(prior_hmm_est); dims=3)
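
#=
To see the effect of the prior count in isolation, here is a small sketch of the row update applied to a hand-picked matrix of expected transition counts.
The names `toy_counts` and `smoothed_rows` are hypothetical and only serve this illustration: adding the same pseudo-count to every entry before normalizing pulls each row towards the uniform distribution, which matches the mode of a Dirichlet posterior whose prior parameters all equal the pseudo-count plus one.

```julia
toy_counts = [8.0 2.0; 1.0 9.0]

function smoothed_rows(counts, prior_count)
    mat = counts .+ prior_count      # add the pseudo-counts from the prior
    return mat ./ sum(mat; dims=2)   # renormalize each row
end

smoothed_rows(toy_counts, 0)   # plain maximum likelihood: rows [0.8, 0.2] and [0.1, 0.9]
smoothed_rows(toy_counts, 10)  # first row becomes [0.6, 0.4], closer to uniform
```
=#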

# ## Tests #src

control_seq, seq_ends = fill(nothing, 1000), 100:100:1000 #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.2, init=false) #src
