
Better docs and numerical stability
gdalle committed Feb 22, 2024
1 parent 31de15c commit 9657939
Showing 36 changed files with 452 additions and 370 deletions.
4 changes: 4 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Guillaume Dalle"]
version = "0.4.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -13,6 +14,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -21,6 +23,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HiddenMarkovModelsDistributionsExt = "Distributions"

[compat]
ArgCheck = "2.3"
ChainRulesCore = "1.16"
DensityInterface = "0.4"
Distributions = "0.25"
@@ -31,4 +34,5 @@ PrecompileTools = "1.1"
Random = "1"
SparseArrays = "1"
StatsAPI = "1.6"
StatsFuns = "1.3"
julia = "1.9"
8 changes: 6 additions & 2 deletions README.md
@@ -36,8 +36,8 @@ Take a look at the [documentation](https://gdalle.github.io/HiddenMarkovModels.j

[Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMMs) are a widely used modeling framework in signal processing, bioinformatics and plenty of other fields.
They explain an observation sequence $(Y_t)$ by assuming the existence of a latent Markovian state sequence $(X_t)$ whose current value determines the distribution of observations.
In our framework, both the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$.
Each of the problems below has an efficient solution algorithm which our package implements:
In some scenarios, the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$.
Each of the problems below has an efficient solution algorithm, available here:

| Problem | Goal | Algorithm |
| ---------- | -------------------------------------- | ---------------- |
@@ -47,6 +47,10 @@ Each of the problems below has an efficient solution algorithm which our package
| Decoding | Most likely state sequence | Viterbi |
| Learning | Maximum likelihood parameter | Baum-Welch |
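
As a quick illustration, here is a minimal sketch of these tasks (the toy parameters are made up, and the function names are assumed from the package's exported API, so check the tutorials for the exact calls):

```julia
using Distributions, HiddenMarkovModels

init = [0.6, 0.4]                    # initial state distribution
trans = [0.7 0.3; 0.3 0.7]           # transition matrix
dists = [Normal(-1.0), Normal(1.0)]  # one observation distribution per state
hmm = HMM(init, trans, dists)

obs_seq = rand(hmm, 100).obs_seq     # simulation (return format may vary across versions)
logdensityof(hmm, obs_seq)           # evaluation: loglikelihood via the forward algorithm
viterbi(hmm, obs_seq)                # decoding: most likely state sequence
baum_welch(hmm, obs_seq)             # learning: parameter re-estimation
```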

Take a look at this tutorial to learn more about the math:

> [_A tutorial on hidden Markov models and selected applications in speech recognition_](https://ieeexplore.ieee.org/document/18626), Rabiner (1989)
## Main features

This package is **generic**.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,6 +10,7 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
15 changes: 7 additions & 8 deletions docs/make.jl
@@ -42,21 +42,20 @@ function literate_title(path)
end

pages = [
"Home" => "index.md",
"API reference" => "api.md",
"First steps" => [
"Home" => "index.md",
"Alternatives" => "alternatives.md",
"API reference" => "api.md",
],
"Tutorials" => [
"Basics" => joinpath("examples", "basics.md"),
"Types" => joinpath("examples", "types.md"),
"Interfaces" => joinpath("examples", "interfaces.md"),
"Autodiff" => joinpath("examples", "autodiff.md"),
"Time dependency" => joinpath("examples", "temporal.md"),
"Control dependency" => joinpath("examples", "controlled.md"),
"Autodiff" => joinpath("examples", "autodiff.md"),
],
"Advanced" => [
"Alternatives" => "alternatives.md",
"Debugging" => "debugging.md",
"Formulas" => "formulas.md",
],
"Advanced" => ["Debugging" => "debugging.md", "Formulas" => "formulas.md"],
]

fmt = Documenter.HTML(;
59 changes: 42 additions & 17 deletions docs/src/alternatives.md
@@ -11,24 +11,49 @@ We compare features among the following Julia packages:
We discard [MarkovModels.jl](https://github.com/FAST-ASR/MarkovModels.jl) because its focus is GPU computation.
There are also more generic packages for probabilistic programming, which can perform MCMC or variational inference (e.g. [Turing.jl](https://github.com/TuringLang/Turing.jl)), but we leave those aside.

| | HMMs.jl | HMMBase.jl | HMMGradients.jl |
| ------------------------- | ------------------- | ------------------- | --------------- |
| Algorithms | Sim, FB, Vit, BW | Sim, FB, Vit, BW | FB |
| Observation types | anything | `Number` / `Vector` | anything |
| Observation distributions | DensityInterface.jl | Distributions.jl | manual |
| Multiple sequences | yes | no | yes |
| Priors / structures | possible | no | possible |
| Temporal dependency | yes | no | no |
| Control dependency | yes | no | no |
| Number types | anything | `Float64` | `AbstractFloat` |
| Automatic differentiation | yes | no | yes |
| Linear algebra | yes | yes | no |
| Logarithmic probabilities | halfway | halfway | yes |

Sim = Simulation, FB = Forward-Backward, Vit = Viterbi, BW = Baum-Welch
| | HMMs.jl | HMMBase.jl | HMMGradients.jl |
| ------------------------- | ------------------- | ---------------- | --------------- |
| Algorithms[^1] | V, FB, BW | V, FB, BW | FB |
| Number types | anything | `Float64` | `AbstractFloat` |
| Observation types | anything | number or vector | anything |
| Observation distributions | DensityInterface.jl | Distributions.jl | manual |
| Multiple sequences | yes | no | yes |
| Priors / structures | possible | no | possible |
| Temporal dependency | yes | no | no |
| Control dependency | yes | no | no |
| Automatic differentiation | yes | no | yes |
| Linear algebra speedup | yes | yes | no |
| Numerical stability | scaling+ | scaling+ | log |


!!! info "Very small probabilities"
In all HMM algorithms, we work with probabilities that may become very small as time progresses.
There are two main solutions for this problem: scaling and logarithmic computations.
This package implements the Viterbi algorithm in log scale, but the other algorithms use scaling to exploit BLAS operations.
As was done in HMMBase.jl, we enhance scaling with a division by the highest observation likelihood: instead of working with $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$, we use $b_{i,t} / \max_j b_{j,t}$.
See [Formulas](@ref) for details.
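
To make the scaling idea concrete outside the package internals, here is a schematic sketch of a scaled forward pass in plain Julia (not the actual implementation): the running normalizers and the per-step maxima both contribute to the loglikelihood, so the result stays exact.

```julia
# Schematic sketch of a scaled forward pass (not the package's actual code).
# `init` is the initial distribution, `trans` the transition matrix,
# and `B[i, t]` the likelihood of observation t under state i.
function scaled_forward_loglikelihood(init, trans, B)
    T = size(B, 2)
    m = vec(maximum(B; dims=1))  # highest observation likelihood at each time step
    Bs = B ./ m'                 # rescaled likelihoods, all bounded by 1
    α = init .* Bs[:, 1]
    c = sum(α)                   # scaling factor of the state distribution
    α = α / c
    logL = log(c) + log(m[1])
    for t in 2:T
        α = (trans' * α) .* Bs[:, t]
        c = sum(α)
        α = α / c
        logL += log(c) + log(m[t])
    end
    return logL
end
```

Accumulating `log(c)` undoes the normalization of the state distribution and `log(m[t])` undoes the division of the observation likelihoods, which is why no precision is lost.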

## Python

We compare features among the following Python packages:

* [hmmlearn](https://github.com/hmmlearn/hmmlearn) (based on NumPy)
* [pomegrnate](https://github.com/jmschrei/pomegranate) (based on PyTorch)
* [dynamax](https://github.com/probml/dynamax) (based on JAX)
* [pomegranate](https://github.com/jmschrei/pomegranate) (based on PyTorch)
* [dynamax](https://github.com/probml/dynamax) (based on JAX)

| | hmmlearn | pomegranate | dynamax |
| ------------------------- | -------------------- | --------------------- | -------------------- |
| Algorithms[^1] | V, FB, BW, VI | V, FB, BW | FB, V, BW, GD |
| Number types | NumPy format | PyTorch format | JAX format |
| Observation types | number or vector | number or vector | number or vector |
| Observation distributions | discrete or Gaussian | pomegranate catalogue | discrete or Gaussian |
| Multiple sequences | yes | yes | yes |
| Priors / structures | yes | no | ? |
| Temporal dependency | no | no | no |
| Control dependency | no | no | no |
| Automatic differentiation | no | yes | yes |
| Linear algebra speedup | yes | yes | yes |
| Numerical stability       | scaling / log        | log                   | log                  |


[^1]: V = Viterbi, FB = Forward-Backward, BW = Baum-Welch, VI = Variational Inference, GD = Gradient Descent
2 changes: 1 addition & 1 deletion docs/src/api.md
@@ -6,7 +6,7 @@ HiddenMarkovModels

## Sequence formatting

Most algorithms below ingest the data with three (keyword) arguments: `obs_seq`, `control_seq` and `seq_ends`.
Most algorithms below ingest the data with two positional arguments `obs_seq` (mandatory) and `control_seq` (optional), and a keyword argument `seq_ends` (optional).

- If the data consists of a single sequence, `obs_seq` and `control_seq` are the corresponding vectors of observations and controls, and you don't need to provide `seq_ends`.
- If the data consists of multiple sequences, `obs_seq` and `control_seq` are concatenations of several vectors, whose end indices are given by `seq_ends`. Starting from separate sequences `obs_seqs` and `control_seqs`, you can run the following snippet:
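
The snippet itself falls outside this diff hunk; a plausible sketch of it (variable names assumed, consistent with the examples) is:

```julia
obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs))
```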
21 changes: 15 additions & 6 deletions docs/src/debugging.md
@@ -1,22 +1,31 @@
# Debugging

## Numerical overflow
## Numerical underflow

The most frequent error you will encounter is an `OverflowError` during inference, telling you that some values are infinite or `NaN`.
The most frequent error you will encounter is a numerical underflow during inference, which manifests as some values becoming infinite or `NaN`.
This can happen for a variety of reasons, so here are a few leads worth investigating:

* Increase the duration of the sequence / the number of sequences to get more data
* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior like zero variance in a Gaussian
* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior (like zero variance in a Gaussian or zero probability in a Bernoulli)
* Reduce the number of states to make every one of them useful
* Pick a better initialization to start closer to the supposed ground truth
* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distributions.
* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll your own [Custom distributions](@ref) (see the sketch below).
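
For that last point, here is a hypothetical sketch of what "strategic places" can mean, storing probabilities as unsigned logarithmic floats so that products of many small values no longer underflow:

```julia
using LogarithmicNumbers

init = ULogFloat64[0.6, 0.4]            # parameters stored in log space
trans = ULogFloat64[0.7 0.3; 0.3 0.7]
prod(fill(ULogFloat64(1e-300), 10))     # representable, whereas the Float64 product underflows to 0.0
```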

## Method errors

This might be caused by:

* forgetting to define methods for your custom type
* omitting `control_seq` or `seq_ends` in some places.

Check the [API reference](@ref).

## Performance

If your algorithms are too slow, the general advice always applies:
If your algorithms are too slow, you can leverage the existing [Interfaces](@ref) to improve the components of your model separately (first observation distributions, then fitting).
The usual advice always applies:

* Use [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) to establish a baseline (see the sketch below)
* Use profiling to see where you spend most of your time
* Use [JET.jl](https://github.com/aviatesk/JET.jl) to track down type instabilities
* Use [AllocCheck.jl](https://github.com/JuliaLang/AllocCheck.jl) to reduce allocations
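
For instance, a hypothetical baseline measurement (assuming an `hmm` and an `obs_seq` defined as in the tutorials) might look like:

```julia
using BenchmarkTools

# Time the loglikelihood computation before optimizing anything
@benchmark logdensityof($hmm, $obs_seq)
```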

14 changes: 7 additions & 7 deletions examples/autodiff.jl
@@ -22,11 +22,11 @@ Enzyme.API.runtimeActivity!(true)

#-

rng = StableRNG(63)
rng = StableRNG(63);

# ## Data generation

init = [0.8, 0.2]
init = [0.6, 0.4]
trans = [0.7 0.3; 0.3 0.7]
means = [-1.0, 1.0]
dists = Normal.(means)
@@ -41,9 +41,9 @@ seq_ends = cumsum(length.(obs_seqs));
# ## Forward mode

#=
Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with ForwardDiff.jl.
Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
Because this backend only accepts a single vector input, we wrap all parameters with ComponentArrays.jl, and define a new function to differentiate.
Because this backend only accepts a single vector input, we wrap all parameters with [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl), and define a new function to differentiate.
=#

params = ComponentVector(; init, trans, means)
Expand All @@ -64,7 +64,7 @@ grad_f = ForwardDiff.gradient(f, params)

#=
In the presence of many parameters, reverse mode automatic differentiation of the loglikelihood will be much more efficient.
The package includes a chain rule for `logdensityof`, which means backends like Zygote.jl can be used out of the box.
The package includes a chain rule for `logdensityof`, which means backends like [Zygote.jl](https://github.com/FluxML/Zygote.jl) can be used out of the box.
=#

grad_z = Zygote.gradient(f, params)[1]
@@ -74,7 +74,7 @@ grad_z = Zygote.gradient(f, params)[1]
grad_f ≈ grad_z

#=
Enzyme.jl also works natively but we have to avoid the type instability of global variables by providing more information.
The more efficient [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) also works natively but we have to avoid the type instability of global variables.
=#

function f_extended(params::ComponentVector, obs_seq, seq_ends)
@@ -90,7 +90,7 @@ Enzyme.autodiff(
Enzyme.Active,
Enzyme.Duplicated(params, shadow_params),
Enzyme.Const(obs_seq),
Enzyme.Const(seq_ends),
Enzyme.Duplicated(seq_ends, Enzyme.make_zero(seq_ends)),
)

grad_e = shadow_params