Better docs and numerical stability (#87)
* Better docs and numerical stability

* Fix leaky tests

* Better autodiff tuto

* Typos

* Fix table
gdalle authored Feb 22, 2024
1 parent 31de15c commit ba666fb
Showing 37 changed files with 664 additions and 480 deletions.
4 changes: 4 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Guillaume Dalle"]
version = "0.4.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -13,6 +14,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -21,6 +23,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HiddenMarkovModelsDistributionsExt = "Distributions"

[compat]
ArgCheck = "2.3"
ChainRulesCore = "1.16"
DensityInterface = "0.4"
Distributions = "0.25"
@@ -31,4 +34,5 @@ PrecompileTools = "1.1"
Random = "1"
SparseArrays = "1"
StatsAPI = "1.6"
StatsFuns = "1.3"
julia = "1.9"
8 changes: 6 additions & 2 deletions README.md
@@ -36,8 +36,8 @@ Take a look at the [documentation](https://gdalle.github.io/HiddenMarkovModels.j

[Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMMs) are a widely used modeling framework in signal processing, bioinformatics and plenty of other fields.
They explain an observation sequence $(Y_t)$ by assuming the existence of a latent Markovian state sequence $(X_t)$ whose current value determines the distribution of observations.
In our framework, both the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$.
Each of the problems below has an efficient solution algorithm which our package implements:
In some scenarios, the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$.
Each of the problems below has an efficient solution algorithm, available here:

| Problem | Goal | Algorithm |
| ---------- | -------------------------------------- | ---------------- |
@@ -47,6 +47,10 @@ Each of the problems below has an efficient solution algorithm which our package
| Decoding | Most likely state sequence | Viterbi |
| Learning | Maximum likelihood parameter | Baum-Welch |

Take a look at this tutorial to know more about the math:

> [_A tutorial on hidden Markov models and selected applications in speech recognition_](https://ieeexplore.ieee.org/document/18626), Rabiner (1989)
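As a concrete glimpse of the decoding row in the table above, here is a minimal Viterbi sketch in plain Python — an illustration of the algorithm itself, not this package's Julia implementation; the function name and toy inputs are invented for the example.

```python
import math

def viterbi(init, trans, obs_loglik):
    """Most likely state sequence: init[i] and trans[i][j] are probabilities,
    obs_loglik[t][i] = log P(Y_t | X_t = i). Works in log space for stability."""
    N, T = len(init), len(obs_loglik)
    delta = [[0.0] * N for _ in range(T)]  # best log-probability ending in each state
    psi = [[0] * N for _ in range(T)]      # argmax predecessor for backtracking
    for i in range(N):
        delta[0][i] = math.log(init[i]) + obs_loglik[0][i]
    for t in range(1, T):
        for j in range(N):
            best_i = max(range(N), key=lambda i: delta[t - 1][i] + math.log(trans[i][j]))
            delta[t][j] = delta[t - 1][best_i] + math.log(trans[best_i][j]) + obs_loglik[t][j]
            psi[t][j] = best_i
    # Backtrack from the best final state
    path = [max(range(N), key=lambda i: delta[T - 1][i])]
    for t in range(T - 1, 0, -1):
        path.append(psi[t][path[-1]])
    return path[::-1]
```

With sticky transitions and observations that strongly favor state 0 twice and then state 1, the decoded path is `[0, 0, 1]`, as expected.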
## Main features

This package is **generic**.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,6 +10,7 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -43,15 +43,15 @@ end

pages = [
"Home" => "index.md",
"API reference" => "api.md",
"Tutorials" => [
"Basics" => joinpath("examples", "basics.md"),
"Types" => joinpath("examples", "types.md"),
"Interfaces" => joinpath("examples", "interfaces.md"),
"Autodiff" => joinpath("examples", "autodiff.md"),
"Time dependency" => joinpath("examples", "temporal.md"),
"Control dependency" => joinpath("examples", "controlled.md"),
"Autodiff" => joinpath("examples", "autodiff.md"),
],
"API reference" => "api.md",
"Advanced" => [
"Alternatives" => "alternatives.md",
"Debugging" => "debugging.md",
59 changes: 41 additions & 18 deletions docs/src/alternatives.md
@@ -4,31 +4,54 @@

We compare features among the following Julia packages:

* HiddenMarkovModels.jl (abbreviated to HMMs.jl)
* HiddenMarkovModels.jl
* [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl)
* [HMMGradients.jl](https://github.com/idiap/HMMGradients.jl)

We discard [MarkovModels.jl](https://github.com/FAST-ASR/MarkovModels.jl) because it focuses on GPU computation.
There are also more generic packages for probabilistic programming, which can perform MCMC or variational inference (e.g. [Turing.jl](https://github.com/TuringLang/Turing.jl)), but we leave those aside.

| | HMMs.jl | HMMBase.jl | HMMGradients.jl |
| ------------------------- | ------------------- | ------------------- | --------------- |
| Algorithms | Sim, FB, Vit, BW | Sim, FB, Vit, BW | FB |
| Observation types | anything | `Number` / `Vector` | anything |
| Observation distributions | DensityInterface.jl | Distributions.jl | manual |
| Multiple sequences | yes | no | yes |
| Priors / structures | possible | no | possible |
| Temporal dependency | yes | no | no |
| Control dependency | yes | no | no |
| Number types | anything | `Float64` | `AbstractFloat` |
| Automatic differentiation | yes | no | yes |
| Linear algebra | yes | yes | no |
| Logarithmic probabilities | halfway | halfway | yes |

Sim = Simulation, FB = Forward-Backward, Vit = Viterbi, BW = Baum-Welch
| | HiddenMarkovModels.jl | HMMBase.jl | HMMGradients.jl |
| ------------------------- | --------------------- | ---------------- | --------------- |
| Algorithms[^1] | V, FB, BW | V, FB, BW | FB |
| Number types | anything | `Float64` | `AbstractFloat` |
| Observation types | anything | number or vector | anything |
| Observation distributions | DensityInterface.jl | Distributions.jl | manual |
| Multiple sequences | yes | no | yes |
| Priors / structures | possible | no | possible |
| Control dependency | yes | no | no |
| Automatic differentiation | yes | no | yes |
| Linear algebra speedup | yes | yes | no |
| Numerical stability | scaling+ | scaling+ | log |


!!! info "Very small probabilities"
In all HMM algorithms, we work with probabilities that may become very small as time progresses.
There are two main solutions for this problem: scaling and logarithmic computations.
This package implements the Viterbi algorithm in log scale, but the other algorithms use scaling to exploit BLAS operations.
As was done in HMMBase.jl, we enhance scaling with a division by the highest observation likelihood: instead of working with $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$, we use $b_{i,t} / \max_i b_{i,t}$.
See [Formulas](@ref) for details.
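As a toy illustration of this normalization (invented numbers, not the package's code), dividing each statewise likelihood by the column maximum keeps the values in a comfortable floating-point range while the discarded factor is reabsorbed elsewhere:

```python
def normalize_obs_likelihoods(b_t):
    """Divide the statewise observation likelihoods at one time step
    by their maximum, so the largest entry becomes exactly 1."""
    m_t = max(b_t)
    return [b / m_t for b in b_t], m_t

# Toy example: tiny likelihoods that would underflow if multiplied over many steps
b_t = [1e-200, 3e-200, 2e-200]
b_bar, m_t = normalize_obs_likelihoods(b_t)
# b_bar == [1/3, 1.0, 2/3] up to rounding; the factor m_t is reabsorbed
# into the log-likelihood, as detailed in the Formulas page
```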

## Python

We compare features among the following Python packages:

* [hmmlearn](https://github.com/hmmlearn/hmmlearn) (based on NumPy)
* [pomegrnate](https://github.com/jmschrei/pomegranate) (based on PyTorch)
* [dynamax](https://github.com/probml/dynamax) (based on JAX)
* [pomegranate](https://github.com/jmschrei/pomegranate) (based on PyTorch)
* [dynamax](https://github.com/probml/dynamax) (based on JAX)

| | hmmlearn | pomegranate | dynamax |
| ------------------------- | -------------------- | --------------------- | -------------------- |
| Algorithms[^1] | V, FB, BW, VI | FB, BW | FB, V, BW, GD |
| Number types | NumPy formats | PyTorch formats | JAX formats |
| Observation types | number or vector | number or vector | number or vector |
| Observation distributions | hmmlearn catalogue | pomegranate catalogue | dynamax catalogue |
| Multiple sequences | yes | yes | yes |
| Priors / structures | yes | no | yes |
| Control dependency | no | no | yes |
| Automatic differentiation | no | yes | yes |
| Linear algebra speedup | yes | yes | yes |
| Numerical stability | scaling / log | log | log |


[^1]: V = Viterbi, FB = Forward-Backward, BW = Baum-Welch, VI = Variational Inference, GD = Gradient Descent
2 changes: 1 addition & 1 deletion docs/src/api.md
@@ -6,7 +6,7 @@ HiddenMarkovModels

## Sequence formatting

Most algorithms below ingest the data with three (keyword) arguments: `obs_seq`, `control_seq` and `seq_ends`.
Most algorithms below ingest the data with two positional arguments `obs_seq` (mandatory) and `control_seq` (optional), and a keyword argument `seq_ends` (optional).

- If the data consists of a single sequence, `obs_seq` and `control_seq` are the corresponding vectors of observations and controls, and you don't need to provide `seq_ends`.
- If the data consists of multiple sequences, `obs_seq` and `control_seq` are concatenations of several vectors, whose end indices are given by `seq_ends`. Starting from separate sequences `obs_seqs` and `control_seqs`, you can run the following snippet:
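The snippet itself is cut off in this diff. As a hedged sketch of the same bookkeeping (in Python rather than Julia, with invented names), concatenating several sequences and recording their end indices looks like:

```python
from itertools import accumulate

def concat_sequences(obs_seqs):
    """Concatenate several observation sequences into one vector and
    record the index at which each original sequence ends."""
    obs_seq = [y for seq in obs_seqs for y in seq]
    seq_ends = list(accumulate(len(seq) for seq in obs_seqs))
    return obs_seq, seq_ends

obs_seq, seq_ends = concat_sequences([[1, 2, 3], [4, 5]])
# obs_seq == [1, 2, 3, 4, 5] and seq_ends == [3, 5]
```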
21 changes: 15 additions & 6 deletions docs/src/debugging.md
@@ -1,22 +1,31 @@
# Debugging

## Numerical overflow
## Numerical underflow

The most frequent error you will encounter is an `OverflowError` during inference, telling you that some values are infinite or `NaN`.
The most frequent error you will encounter is an underflow during inference, caused by some values being infinite or `NaN`.
This can happen for a variety of reasons, so here are a few leads worth investigating:

* Increase the duration of the sequence / the number of sequences to get more data
* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior like zero variance in a Gaussian
* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior (like zero variance in a Gaussian or zero probability in a Bernoulli)
* Reduce the number of states to make every one of them useful
* Pick a better initialization to start closer to the supposed ground truth
* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distributions.
* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll your own [Custom distributions](@ref).
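A self-contained illustration of why the log trick in the last bullet matters (toy numbers, independent of this package): the product of many small likelihoods underflows to zero, while the sum of their logarithms stays finite.

```python
import math

probs = [1e-30] * 20              # twenty very unlikely observations

naive = 1.0
for p in probs:
    naive *= p                    # underflows: 1e-600 is below float range

log_total = sum(math.log(p) for p in probs)   # stays representable

# naive == 0.0, but log_total == 20 * log(1e-30), a perfectly ordinary float
```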

## Method errors

This might be caused by:

* forgetting to define methods for your custom type
* omitting `control_seq` or `seq_ends` in some places

Check the [API reference](@ref).

## Performance

If your algorithms are too slow, the general advice always applies:
If your algorithms are too slow, you can leverage the existing [Interfaces](@ref) to improve the components of your model separately (first observation distributions, then fitting).
The usual advice always applies:

* Use [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) to establish a baseline
* Use profiling to see where you spend most of your time
* Use [JET.jl](https://github.com/aviatesk/JET.jl) to track down type instabilities
* Use [AllocCheck.jl](https://github.com/JuliaLang/AllocCheck.jl) to reduce allocations

34 changes: 18 additions & 16 deletions docs/src/formulas.md
@@ -4,9 +4,11 @@ Suppose we are given observations $Y_1, ..., Y_T$, with hidden states $X_1, ...,
Following [Rabiner1989](@cite), we use the following notations:

* let $\pi \in \mathbb{R}^N$ be the initial state distribution $\pi_i = \mathbb{P}(X_1 = i)$
* let $A \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j} = \mathbb{P}(X_{t+1}=j | X_t = i)$
* let $A_t \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j,t} = \mathbb{P}(X_{t+1}=j | X_t = i)$
* let $B \in \mathbb{R}^{N \times T}$ be the matrix of statewise observation likelihoods $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$

The conditioning on the known controls $U_{1:T}$ is implicit throughout.

## Vanilla forward-backward

### Recursion
@@ -33,8 +35,8 @@ and satisfy the dynamic programming equations

```math
\begin{align*}
\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j}\right) b_{j,t+1} \\
\beta_{i,t} & = \sum_{j=1}^N a_{i,j} b_{j,t+1} \beta_{j,t+1}
\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j,t}\right) b_{j,t+1} \\
\beta_{i,t} & = \sum_{j=1}^N a_{i,j,t} b_{j,t+1} \beta_{j,t+1}
\end{align*}
```
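These recursions translate directly into two loops of sums and products. The following Python sketch is purely illustrative (unscaled, so it underflows on long sequences, which is exactly what the scaled variant later on this page fixes); the function and argument names are invented.

```python
def forward_backward(init, trans, B):
    """Unscaled alpha/beta recursions. trans[t][i][j] may depend on t,
    and B[i][t] = P(Y_t | X_t = i). For exposition only."""
    N, T = len(init), len(B[0])
    alpha = [[0.0] * N for _ in range(T)]
    beta = [[1.0] * N for _ in range(T)]   # beta_{i,T} = 1
    for j in range(N):
        alpha[0][j] = init[j] * B[j][0]
    for t in range(T - 1):                 # forward pass
        for j in range(N):
            alpha[t + 1][j] = sum(alpha[t][i] * trans[t][i][j] for i in range(N)) * B[j][t + 1]
    for t in range(T - 2, -1, -1):         # backward pass
        for i in range(N):
            beta[t][i] = sum(trans[t][i][j] * B[j][t + 1] * beta[t + 1][j] for j in range(N))
    return alpha, beta
```

A useful sanity check is that $\sum_i \alpha_{i,t} \beta_{i,t}$ equals the likelihood $\mathcal{L}$ at every $t$.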

@@ -53,7 +55,7 @@ We notice that
```math
\begin{align*}
\alpha_{i,t} \beta_{i,t} & = \mathbb{P}(Y_{1:T}, X_t=i) \\
\alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j)
\alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j)
\end{align*}
```

@@ -62,7 +64,7 @@ Thus we deduce the one-state and two-state marginals
```math
\begin{align*}
\gamma_{i,t} & = \mathbb{P}(X_t=i | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} \beta_{i,t} \\
\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1}
\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1}
\end{align*}
```
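In other words, $\gamma_{\cdot,t}$ is the elementwise product of $\alpha_{\cdot,t}$ and $\beta_{\cdot,t}$, renormalized to sum to one. A tiny self-contained Python illustration with made-up $\alpha$ and $\beta$ values:

```python
def one_state_marginals(alpha_t, beta_t):
    """gamma_{i,t} = alpha_{i,t} * beta_{i,t} / L, where L = sum_i alpha * beta."""
    prod = [a * b for a, b in zip(alpha_t, beta_t)]
    L = sum(prod)
    return [p / L for p in prod]

gamma_t = one_state_marginals([0.02, 0.06], [0.5, 0.5])
# gamma_t sums to 1 by construction
```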

@@ -75,7 +77,7 @@ According to [Qin2000](@cite), derivatives of the likelihood can be obtained as
\frac{\partial \mathcal{L}}{\partial \pi_i} &= \beta_{i,1} b_{i,1} \\
\frac{\partial \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \alpha_{i,t} b_{j,t+1} \beta_{j,t+1} \\
\frac{\partial \mathcal{L}}{\partial b_{j,1}} &= \pi_j \beta_{j,1} \\
\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t}
\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t}
\end{align*}
```

@@ -98,8 +100,8 @@ and satisfy the dynamic programming equations

```math
\begin{align*}
\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\
\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t}
\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j,t}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\
\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t}
\end{align*}
```
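A hedged Python sketch of the scaled forward pass above (illustrative only, names invented): each step divides by $m_{t+1}$, renormalizes with $c_{t+1}$, and accumulates $\log \mathcal{L} = \sum_t (\log m_t - \log c_t)$, since $\mathcal{L} = \prod_t m_t / c_t$.

```python
import math

def scaled_forward(init, trans, B):
    """Scaled forward recursion: alpha_bar sums to 1 after every step.
    trans[t][i][j] is the transition matrix at time t, B[i][t] = P(Y_t | X_t = i)."""
    N, T = len(init), len(B[0])
    log_like = 0.0
    # Initialization with the first normalizations m_1 and c_1
    m = max(B[i][0] for i in range(N))
    alpha_hat = [init[i] * B[i][0] / m for i in range(N)]
    c = 1.0 / sum(alpha_hat)
    alpha_bar = [c * a for a in alpha_hat]
    log_like += math.log(m) - math.log(c)
    for t in range(T - 1):
        m = max(B[j][t + 1] for j in range(N))
        alpha_hat = [
            sum(alpha_bar[i] * trans[t][i][j] for i in range(N)) * B[j][t + 1] / m
            for j in range(N)
        ]
        c = 1.0 / sum(alpha_hat)
        alpha_bar = [c * a for a in alpha_hat]
        log_like += math.log(m) - math.log(c)
    return alpha_bar, log_like
```

On a two-state toy example this recovers the same likelihood as the unscaled recursion, but without risking underflow on long sequences.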

@@ -140,9 +142,9 @@ We can now express the marginals using scaled variables:
```math
\begin{align*}
\xi_{i,j,t} & = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} \\
&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\
&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\
&= \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1}
&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j,t} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\
&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\
&= \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1}
\end{align*}
```

@@ -179,10 +181,10 @@ And for the statewise observation likelihoods,

```math
\begin{align*}
\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t} \\
&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\
&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\
&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \\
\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t} \\
&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j,t-1} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\
&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\
&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \\
\end{align*}
```

@@ -199,7 +201,7 @@ To sum up,
\frac{\partial \log \mathcal{L}}{\partial \pi_i} &= \frac{b_{i,1}}{m_1} \bar{\beta}_{i,1} \\
\frac{\partial \log \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \bar{\alpha}_{i,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \\
\frac{\partial \log \mathcal{L}}{\partial \log b_{j,1}} &= \pi_j \frac{b_{j,1}}{m_1} \bar{\beta}_{j,1} = \frac{\bar{\alpha}_{j,1} \bar{\beta}_{j,1}}{c_1} = \gamma_{j,1} \\
\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t}
\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t}
\end{align*}
```

