From c1a53278faadee374f068a4bc214a3ebef708b83 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:55:56 +0100 Subject: [PATCH] Dirichlet prior --- README.md | 52 +++++++++++++++++++++--------------------- docs/make.jl | 2 +- examples/basics.jl | 35 ++++++++++++++-------------- examples/interfaces.jl | 51 ++++++++++++++++++++++++++--------------- 4 files changed, 77 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 53b69424..6370e7e4 100644 --- a/README.md +++ b/README.md @@ -10,24 +10,7 @@ [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![JET](https://img.shields.io/badge/%E2%9C%88%EF%B8%8F%20tested%20with%20-%20JET.jl%20-%20red)](https://github.com/aviatesk/JET.jl) -A Julia package for HMM modeling, simulation, inference and learning. - -## Mathematical background - -[Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) are a statistical modeling framework that is ubiquitous in signal processing, bioinformatics and plenty of other fields. They capture the distribution of an observation sequence $(Y_t)$ by assuming the existence of a latent state sequence $(X_t)$ such that: - -* the state follows a (discrete time, discrete space) Markov chain $\mathbb{P}_\theta(X_t | X_{t-1})$ -* the observation distribution is determined at each time by the state $\mathbb{P}_\theta(Y_t | X_t)$ - -HMMs are associated with several statistical problems, each of which has an efficient solution algorithm that our package implements: - -| Problem | Goal | Algorithm | -| ---------- | --------------------------------------------------------------------------------------------------------- | ---------------- | -| Evaluation | Likelihood of the observation sequence $\mathbb{P}_\theta(Y_{1:T})$ | Forward | -| Filtering | Non-anticipative state marginals $\mathbb{P}_\theta(X_t \vert Y_{1:t})$ | Forward | -| Smoothing | State marginals $\mathbb{P}_\theta(X_t \vert Y_{1:T})$ | Forward-backward | -| Decoding | Most likely state sequence $\underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}_\theta(X_{1:T} \vert Y_{1:T})$ | Viterbi | -| Learning | Maximum likelihood parameter $\underset{\theta}{\mathrm{argmax}}~\mathbb{P}_\theta(Y_{1:T})$ | Baum-Welch | +A Julia package for simulation, inference and learning of [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model). ## Getting started @@ -37,7 +20,7 @@ This package can be installed using Julia's package manager: pkg> add HiddenMarkovModels ``` -Then, you can create your first HMM as follows: +Then, you can create your first model as follows: ```julia using Distributions, HiddenMarkovModels @@ -49,21 +32,38 @@ hmm = HMM(init, trans, dists) Take a look at the [documentation](https://gdalle.github.io/HiddenMarkovModels.jl/stable/) to know what to do next! +## Some background + +HMMs are a widely used modeling framework in signal processing, bioinformatics and plenty of other fields. +They capture the distribution of an observation sequence $(Y_t)$ by assuming the existence of a latent state sequence $(X_t)$ such that: + +* the state follows a (discrete time, discrete space) Markov chain $\mathbb{P}_\theta(X_t | X_{t-1})$ +* the observation distribution is determined at each time by the state $\mathbb{P}_\theta(Y_t | X_t)$ + +Each of the problems below has an efficient solution algorithm which our package implements: + +| Problem | Goal | Algorithm | +| ---------- | --------------------------------------------------------------------------------------------------------- | ---------------- | +| Evaluation | Likelihood of the observation sequence $\mathbb{P}_\theta(Y_{1:T})$ | Forward | +| Filtering | Non-anticipative state marginals $\mathbb{P}_\theta(X_t \vert Y_{1:t})$ | Forward | +| Smoothing | State marginals $\mathbb{P}_\theta(X_t \vert Y_{1:T})$ | Forward-backward | +| Decoding | Most likely state sequence $\underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}_\theta(X_{1:T} \vert Y_{1:T})$ | Viterbi | +| Learning | Maximum likelihood parameter $\underset{\theta}{\mathrm{argmax}}~\mathbb{P}_\theta(Y_{1:T})$ | Baum-Welch | + ## Main features This package is **generic**. -Observations can be arbitrary Julia objects, not just scalars or arrays, because their distributions only need to implement `rand(rng, dist)` and `logdensityof(dist, x)` ([DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl)). -Number types are not restricted to floating point, and automatic differentiation is supported in forward mode ([ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)). +Observations can be arbitrary Julia objects, not just scalars or arrays. +Number types are not restricted to floating point, which enables automatic differentiation. Time-heterogeneous or controlled HMMs are supported out of the box. This package is **fast**. All the inference functions have allocation-free versions, which leverage efficient linear algebra subroutines. -We will include extensive benchmarks against Julia and Python competitors ([BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) + [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl)). +We will include extensive benchmarks against Julia and Python competitors. This package is **reliable**. -It gives the same results as the previous reference package ([HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl)) up to numerical accuracy. -The test suite incorporates quality checks ([Aqua.jl](https://github.com/JuliaTesting/Aqua.jl)), as well as type stability analysis ([JET.jl](https://github.com/aviatesk/JET.jl)). -A detailed documentation will help you find the functions you need. +It gives the same results as the previous reference package up to numerical accuracy. +The test suite incorporates quality checks as well as type stability and allocation analysis. ## Contributing @@ -72,5 +72,5 @@ Once the issue receives positive feedback, feel free to try and fix it with a pu ## Acknowledgements -A big thank you to [Maxime Mouchet](https://www.maxmouchet.com/) and [Jacob Schreiber](https://jmschrei.github.io/), the respective lead devs of [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl) and [pomegranate](https://github.com/jmschrei/pomegranate), for their help and advice. +A big thank you to [Maxime Mouchet](https://www.maxmouchet.com/) and [Jacob Schreiber](https://jmschrei.github.io/), the respective lead devs of alternative packages [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl) and [pomegranate](https://github.com/jmschrei/pomegranate), for their help and advice. Logo by [Clément Mantoux](https://cmantoux.github.io/) based on a portrait of [Andrey Markov](https://en.wikipedia.org/wiki/Andrey_Markov). diff --git a/docs/make.jl b/docs/make.jl index ba095091..46efedd2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -43,6 +43,7 @@ end pages = [ "Home" => "index.md", + "API reference" => "api.md", "Tutorials" => [ "Basics" => joinpath("examples", "basics.md"), "Interfaces" => joinpath("examples", "interfaces.md"), @@ -50,7 +51,6 @@ pages = [ "Periodic" => joinpath("examples", "periodic.md"), "Controlled" => joinpath("examples", "controlled.md"), ], - "API reference" => "api.md", "Advanced" => [ "Alternatives" => "alternatives.md", "Debugging" => "debugging.md", diff --git a/examples/basics.jl b/examples/basics.jl index 683e7f3d..23a0d05f 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -19,12 +19,12 @@ Random.seed!(rng, 63); # ## Model construction #= -The package provides a versatile `HMM` type with three attributes: +The package provides a versatile [`HMM`](@ref) type with three attributes: - a vector of state initialization probabilities - a matrix of state transition probabilities - a vector of observation distributions, one for each state -Here we keep it simple by using Distributions.jl, but there are other ways. +We keep it simple for now by leveraging Distributions.jl. =# init = [0.8, 0.2] @@ -35,14 +35,13 @@ hmm = HMM(init, trans, dists); # ## Simulation #= -You can simulate a pair of state and observation sequences by specifying how long you want them to be. +You can simulate a pair of state and observation sequences with [`rand`](@ref) by specifying how long you want them to be. =# state_seq, obs_seq = rand(rng, hmm, 20); #= Note that the observation sequence is a vector, whose elements have whatever type an observation distribution returns when sampled. -As we will see, this can be more generic than just numbers or arrays. =# state_seq[1], obs_seq[1] @@ -54,7 +53,7 @@ In practical applications, the state sequence is not known, which is why we need # ## Inference #= -The Viterbi algorithm returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$. +The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$. =# best_state_seq, best_joint_loglikelihood = viterbi(hmm, obs_seq); @@ -66,7 +65,7 @@ As we can see, it is very close to the true state sequence, but not necessarily vcat(state_seq', best_state_seq') #= -The forward algorithm returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence. +The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence. =# filtered_state_marginals, obs_seq_loglikelihood1 = forward(hmm, obs_seq); @@ -79,7 +78,7 @@ This is particularly useful to infer the marginal distribution of the last state filtered_state_marginals[:, end] #= -Conversely, the forward-backward algorithm returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence. +Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence. =# smoothed_state_marginals, obs_seq_loglikelihood2 = forward_backward(hmm, obs_seq); @@ -93,7 +92,7 @@ Note that forward and forward-backward only coincide at the last time step. collect(zip(filtered_state_marginals, smoothed_state_marginals)) #= -Finally, we provide a thin wrapper around the forward algorithm for observation sequence loglikelihoods $\mathbb{P}(Y_{1:T})$. +Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward algorithm for observation sequence loglikelihoods $\mathbb{P}(Y_{1:T})$. =# logdensityof(hmm, obs_seq) @@ -113,7 +112,7 @@ logdensityof(hmm, obs_seq, best_state_seq) # ## Learning #= -The Baum-Welch algorithm is a variant of Expectation-Maximization, designed specifically to estimate HMM parameters. +The Baum-Welch algorithm ([`baum_welch`](@ref)) is a variant of Expectation-Maximization, designed specifically to estimate HMM parameters. Since it is a local optimization procedure, it requires a starting point that is close enough to the true model. =# @@ -123,11 +122,11 @@ dists_guess = [MvNormal(-0.5 * ones(3), I), MvNormal(+0.5 * ones(3), I)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #= -Now we can estimate parameters based on a slightly longer sequence. +Let's estimate parameters based on a slightly longer sequence. =# _, long_obs_seq = rand(rng, hmm, 1000) -hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, long_obs_seq); +hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq); #= An essential guarantee of this algorithm is that the loglikelihood of the observation sequence keeps increasing as the model improves. @@ -166,12 +165,13 @@ In many applications, we have access to various observation sequences of differe =# _, obs_seq2 = rand(rng, hmm, 30) -obs_seqs = [obs_seq, obs_seq2]; +_, obs_seq3 = rand(rng, hmm, 10) +obs_seqs = [obs_seq, obs_seq2, obs_seq3]; #= Every algorithm in the package accepts multiple sequences in a concatenated form. The user must also specify where each sequence ends in the concatenated vector, by passing `seq_ends` as a keyword argument. -Otherwise, the input would be treated as a single observation sequence, which is mathematically incorrect. +Otherwise, the input will be treated as a unique observation sequence, which is mathematically incorrect. =# obs_seq_concat = reduce(vcat, obs_seqs) @@ -184,7 +184,7 @@ The outputs of inference algorithms are then concatenated, and the associated lo best_state_seq_concat, _ = viterbi(hmm, obs_seq_concat; seq_ends); #= -The function `seq_limits` returns the begin and end of a given sequence in the concatenated vector. +The function [`seq_limits`](@ref) returns the begin and end of a given sequence in the concatenated vector. It can be used to untangle the results. =# @@ -192,12 +192,13 @@ start2, stop2 = seq_limits(seq_ends, 2) best_state_seq_concat[start2:stop2] == first(viterbi(hmm, obs_seq2)) #= -While inference algorithms can also be run separately on each sequence without changing the results, considering multiple sequences together is nontrivial for Baum-Welch, and the package takes care of it automatically. +While inference algorithms can also be run separately on each sequence without changing the results, considering multiple sequences together is nontrivial for Baum-Welch. +That is why the package takes care of it automatically. =# -baum_welch(hmm_guess, obs_seq_concat; seq_ends); +hmm_est_concat, _ = baum_welch(hmm_guess, obs_seq_concat; seq_ends); # ## Tests #src -control_seq, seq_ends = fill(nothing, 1000), 100:10:1000 #src +control_seq, seq_ends = fill(nothing, 1000), 100:100:1000 #src test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=1e-1) #src diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 4c71c65f..fcd96663 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -1,10 +1,11 @@ # # Interfaces #= -Here we discuss how to extend the observation distributions or HMM behavior to fit specific needs. +Here we discuss how to extend the observation distributions or HMM behavior to satisfy specific needs. =# using DensityInterface +using Distributions using HiddenMarkovModels using HiddenMarkovModels: test_coherent_algorithms #src using LinearAlgebra @@ -28,8 +29,6 @@ They only need to implement three methods: In addition, the observation can be arbitrary Julia types. So let's construct a distribution that generates stuff. - -If you want more sophisticated examples, check out the definitions of `HiddenMarkovModels.LightDiagNormal` and `HiddenMarkovModels.LightCategorical`, which are designed to be fast and allocation-free. =# struct Stuff{T} @@ -55,14 +54,11 @@ end #= It is important to declare to DensityInterface.jl that the custom distribution has a density, thanks to the following trait. +The logdensity itself can be computed up to an additive constant without issue. =# DensityInterface.DensityKind(::StuffDist) = HasDensity() -#= -The logdensity itself can be computed up to an additive constant without issue. -=# - function DensityInterface.logdensityof(dist::StuffDist, obs::Stuff) return -abs2(obs.quantity - dist.quantity_mean) end @@ -114,10 +110,9 @@ hmm_guess = HMM(init_guess, trans_guess, dists_guess); hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq) obs_distributions(hmm_est) -# ## Tests #src - -control_seq, seq_ends = fill(nothing, 1000), 100:10:1000 #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.2, init=false) #src +#= +If you want more sophisticated examples, check out [`HiddenMarkovModels.LightDiagNormal`](@ref) and [`HiddenMarkovModels.LightCategorical`](@ref), which are designed to be fast and allocation-free. +=# # ## Creating a new HMM type @@ -126,13 +121,14 @@ In some scenarios, the vanilla Baum-Welch algorithm is not exactly what we want. For instance, we might have a prior on the parameters of our model, which we want to apply during the fitting step of the iterative procedure. Then we need to create a new type that satisfies the `AbstractHMM` interface. -Let's make a simpler version of the built-in `HMM`m with a prior saying that each transition has been observed exactly once. +Let's make a simpler version of the built-in `HMM`m with a prior saying that each transition has already been observed a certain number of times. =# struct PriorHMM{T,D} <: AbstractHMM init::Vector{T} trans::Matrix{T} dists::Vector{D} + trans_prior_count::Int end #= @@ -144,16 +140,17 @@ HiddenMarkovModels.transition_matrix(hmm::PriorHMM) = hmm.trans HiddenMarkovModels.obs_distributions(hmm::PriorHMM) = hmm.dists #= -In addition, we want to overload `logdensityof` to specify our prior loglikelihood. +In addition, we want to overload [`logdensityof(hmm)`](@ref) to specify our prior loglikelihood. +It corresponds to a Dirichlet distribution over each row of the transition matrix, where each Dirichlet parameter is one plus the number of times the corresponding transition has been observed. =# function DensityInterface.logdensityof(hmm::PriorHMM) - prior = Dirichlet(ones(length(hmm))) + prior = Dirichlet(fill(hmm.trans_prior_count + 1, length(hmm))) return sum(logdensityof(prior, row) for row in eachrow(transition_matrix(hmm))) end #= -And finally, we redefine the specific method of `fit!` that is used during Baum-Welch. +And finally, we redefine the specific method of `fit!` that is used during Baum-Welch: [`fit!(hmm, obs_seq; control_seq, seq_ends, fb_storage)`](@ref). It accepts the same inputs as `baum_welch` for multiple sequences (disregard `control_seq` for now), and an additional `fb_storage` containing the results of the forward-backward algorithm. The goal is to modify `hmm` in-place to update its parameters with their current maximum likelihood estimates. @@ -168,7 +165,7 @@ function StatsAPI.fit!( fb_storage::HiddenMarkovModels.ForwardBackwardStorage, ) hmm.init .= 0 - hmm.trans .= 1 # this is where the prior comes in + hmm.trans .= hmm.trans_prior_count for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) hmm.init .+= fb_storage.γ[:, t1] @@ -178,7 +175,7 @@ function StatsAPI.fit!( hmm.trans ./= sum(hmm.trans; dims=2) for i in 1:length(hmm) - weight_seq = storage.γ[:, i] + weight_seq = fb_storage.γ[i, :] fit!(hmm.dists[i], obs_seq, weight_seq) end return nothing @@ -189,6 +186,22 @@ Some distributions, such as those from Distributions.jl - do not support in-place fitting - might expect different formats, e.g. higher-order arrays instead of a vector of objects -The function `HiddenMarkovModels.fit_in_sequence!` is a replacement for `fit!` which you can overload at will. -It is already designed to handle Distributions.jl. +The function [`HiddenMarkovModels.fit_in_sequence!`](@ref) is a replacement for `fit!` that is designed to handle Distributions.jl. +You can overload it for your own objects too. =# + +trans_prior_count = 10 +prior_hmm_guess = PriorHMM(init_guess, trans_guess, dists_guess, trans_prior_count) + +prior_hmm_est, prior_logl_evolution = baum_welch(prior_hmm_guess, obs_seq) + +#= +As we can see, the transition matrix for our Bayesian version is slightly more spread out, although this effect would nearly disappear with enough data. +=# + +cat(transition_matrix(hmm_est), transition_matrix(prior_hmm_est); dims=3) + +# ## Tests #src + +control_seq, seq_ends = fill(nothing, 1000), 100:100:1000 #src +test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.2, init=false) #src