From 965793977de25203b8caf516b8c42d5b555041de Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 12:41:38 +0100 Subject: [PATCH 1/5] Better docs and numerical stability --- Project.toml | 4 + README.md | 8 +- docs/Project.toml | 1 + docs/make.jl | 15 ++-- docs/src/alternatives.md | 59 +++++++++---- docs/src/api.md | 2 +- docs/src/debugging.md | 21 +++-- examples/autodiff.jl | 14 +-- examples/basics.jl | 113 ++++++++++++++++-------- examples/controlled.jl | 10 ++- examples/interfaces.jl | 59 +++++++------ examples/temporal.jl | 29 +++--- examples/types.jl | 111 +++++++++++++++++------ libs/HMMComparison/src/HMMComparison.jl | 1 + libs/HMMComparison/src/dynamax.jl | 8 +- libs/HMMComparison/src/hmmbase.jl | 21 +++-- libs/HMMTest/src/allocations.jl | 8 +- libs/HMMTest/src/coherence.jl | 35 +++----- libs/HMMTest/src/hmmbase.jl | 10 ++- libs/HMMTest/src/jet.jl | 4 +- src/HiddenMarkovModels.jl | 4 +- src/inference/baum_welch.jl | 14 ++- src/inference/forward.jl | 4 +- src/inference/forward_backward.jl | 7 +- src/inference/viterbi.jl | 25 +++--- src/types/abstract_hmm.jl | 16 ++-- src/types/hmm.jl | 7 +- src/utils/check.jl | 77 ---------------- src/utils/lightcategorical.jl | 6 +- src/utils/lightdiagnormal.jl | 28 +++--- src/utils/linalg.jl | 6 +- src/utils/probvec_transmat.jl | 21 ----- src/utils/valid.jl | 28 ++++++ test/Project.toml | 1 + test/correctness.jl | 42 ++++----- test/distributions.jl | 3 +- 36 files changed, 452 insertions(+), 370 deletions(-) delete mode 100644 src/utils/check.jl create mode 100644 src/utils/valid.jl diff --git a/Project.toml b/Project.toml index 7689d2bf..b9f25529 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Guillaume Dalle"] version = "0.4.0" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -13,6 +14,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -21,6 +23,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HiddenMarkovModelsDistributionsExt = "Distributions" [compat] +ArgCheck = "2.3" ChainRulesCore = "1.16" DensityInterface = "0.4" Distributions = "0.25" @@ -31,4 +34,5 @@ PrecompileTools = "1.1" Random = "1" SparseArrays = "1" StatsAPI = "1.6" +StatsFuns = "1.3" julia = "1.9" diff --git a/README.md b/README.md index 137ba1f5..9f78227e 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,8 @@ Take a look at the [documentation](https://gdalle.github.io/HiddenMarkovModels.j [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMMs) are a widely used modeling framework in signal processing, bioinformatics and plenty of other fields. They explain an observation sequence $(Y_t)$ by assuming the existence of a latent Markovian state sequence $(X_t)$ whose current value determines the distribution of observations. -In our framework, both the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$. -Each of the problems below has an efficient solution algorithm which our package implements: +In some scenarios, the state and the observation sequence are also allowed to depend on a known control sequence $(U_t)$. +Each of the problems below has an efficient solution algorithm, available here: | Problem | Goal | Algorithm | | ---------- | -------------------------------------- | ---------------- | @@ -47,6 +47,10 @@ Each of the problems below has an efficient solution algorithm which our package | Decoding | Most likely state sequence | Viterbi | | Learning | Maximum likelihood parameter | Baum-Welch | +Take a look at this tutorial to know more about the math: + +> [_A tutorial on hidden Markov models and selected applications in speech recognition_](https://ieeexplore.ieee.org/document/18626), Rabiner (1989) + ## Main features This package is **generic**. diff --git a/docs/Project.toml b/docs/Project.toml index 46f28707..01d37223 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,6 +10,7 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/docs/make.jl b/docs/make.jl index d30cfa25..8093e1d7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -42,21 +42,20 @@ function literate_title(path) end pages = [ - "Home" => "index.md", - "API reference" => "api.md", + "First steps" => [ + "Home" => "index.md", + "Alternatives" => "alternatives.md", + "API reference" => "api.md", + ], "Tutorials" => [ "Basics" => joinpath("examples", "basics.md"), "Types" => joinpath("examples", "types.md"), "Interfaces" => joinpath("examples", "interfaces.md"), - "Autodiff" => joinpath("examples", "autodiff.md"), "Time dependency" => joinpath("examples", "temporal.md"), "Control dependency" => joinpath("examples", "controlled.md"), + "Autodiff" => joinpath("examples", "autodiff.md"), ], - "Advanced" => [ - "Alternatives" => "alternatives.md", - "Debugging" => "debugging.md", - "Formulas" => "formulas.md", - ], + "Advanced" => ["Debugging" => "debugging.md", "Formulas" => "formulas.md"], ] fmt = Documenter.HTML(; diff --git a/docs/src/alternatives.md b/docs/src/alternatives.md index d1d48c8e..21aa3792 100644 --- a/docs/src/alternatives.md +++ b/docs/src/alternatives.md @@ -11,24 +11,49 @@ We compare features among the following Julia packages: We discard [MarkovModels.jl](https://github.com/FAST-ASR/MarkovModels.jl) because its focus is GPU computation. There are also more generic packages for probabilistic programming, which are able to perform MCMC or variational inference (eg. [Turing.jl](https://github.com/TuringLang/Turing.jl)) but we leave those aside. -| | HMMs.jl | HMMBase.jl | HMMGradients.jl | -| ------------------------- | ------------------- | ------------------- | --------------- | -| Algorithms | Sim, FB, Vit, BW | Sim, FB, Vit, BW | FB | -| Observation types | anything | `Number` / `Vector` | anything | -| Observation distributions | DensityInterface.jl | Distributions.jl | manual | -| Multiple sequences | yes | no | yes | -| Priors / structures | possible | no | possible | -| Temporal dependency | yes | no | no | -| Control dependency | yes | no | no | -| Number types | anything | `Float64` | `AbstractFloat` | -| Automatic differentiation | yes | no | yes | -| Linear algebra | yes | yes | no | -| Logarithmic probabilities | halfway | halfway | yes | - -Sim = Simulation, FB = Forward-Backward, Vit = Viterbi, BW = Baum-Welch +| | HMMs.jl | HMMBase.jl | HMMGradients.jl | +| ------------------------- | ------------------- | ---------------- | --------------- | +| Algorithms[^1] | V, FB, BW | V, FB, BW | FB | +| Number types | anything | `Float64` | `AbstractFloat` | +| Observation types | anything | number or vector | anything | +| Observation distributions | DensityInterface.jl | Distributions.jl | manual | +| Multiple sequences | yes | no | yes | +| Priors / structures | possible | no | possible | +| Temporal dependency | yes | no | no | +| Control dependency | yes | no | no | +| Automatic differentiation | yes | no | yes | +| Linear algebra speedup | yes | yes | no | +| Numerical stability | scaling+ | scaling+ | log | + + +!!! info "Very small probabilities" + In all HMM algorithms, we work with probabilities that may become very small as time progresses. + There are two main solutions for this problem: scaling and logarithmic computations. + This package implements the Viterbi algorithm in log scale, but the other algorithms use scaling to exploit BLAS operations. + As was done in HMMBase.jl, we enhance scaling with a division by the highest observation loglikelihood: instead of working with $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$, we use $b_{i,t} / \max_i b_{i,t}$. + See [Formulas](@ref) for details. ## Python +We compare features among the following Python packages: + * [hmmlearn](https://github.com/hmmlearn/hmmlearn) (based on NumPy) -* [pomegrnate](https://github.com/jmschrei/pomegranate) (based on PyTorch) -* [dynamax](https://github.com/probml/dynamax) (based on JAX) \ No newline at end of file +* [pomegranate](https://github.com/jmschrei/pomegranate) (based on PyTorch) +* [dynamax](https://github.com/probml/dynamax) (based on JAX) + +| | hmmlearn | pomegranate | dynamax | +| ------------------------- | -------------------- | --------------------- | -------------------- | +| Algorithms[^1] | V, FB, BW, VI | V, FB, BW | FB, V, BW, GD | +| Number types | NumPy format | PyTorch format | JAX format | +| Observation types | number or vector | number or vector | number or vector | +| Observation distributions | discrete or Gaussian | pomegranate catalogue | discrete or Gaussian | +| Multiple sequences | yes | yes | yes | +| Priors / structures | yes | no | ? | +| Temporal dependency | no | no | no | +| Control dependency | no | no | no | +| Automatic differentiation | no | yes | yes | +| Linear algebra speedup | yes | yes | yes | +| Logarithmic probabilities | scaling / log | log | log | + + +[^1]: V = Viterbi, FB = Forward-Backward, BW = Baum-Welch, VI = Variational Inference, GD = Gradient Descent diff --git a/docs/src/api.md b/docs/src/api.md index 4e93cef2..d627d1fa 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -6,7 +6,7 @@ HiddenMarkovModels ## Sequence formatting -Most algorithms below ingest the data with three (keyword) arguments: `obs_seq`, `control_seq` and `seq_ends`. +Most algorithms below ingest the data with two positional arguments `obs_seq` (mandatory) and `control_seq` (optional), and a keyword argument `seq_ends` (optional). - If the data consists of a single sequence, `obs_seq` and `control_seq` are the corresponding vectors of observations and controls, and you don't need to provide `seq_ends`. - If the data consists of multiple sequences, `obs_seq` and `control_seq` are concatenations of several vectors, whose end indices are given by `seq_ends`. Starting from separate sequences `obs_seqs` and `control_seqs`, you can run the following snippet: diff --git a/docs/src/debugging.md b/docs/src/debugging.md index 5e274991..a665fa25 100644 --- a/docs/src/debugging.md +++ b/docs/src/debugging.md @@ -1,22 +1,31 @@ # Debugging -## Numerical overflow +## Numerical underflow -The most frequent error you will encounter is an `OverflowError` during inference, telling you that some values are infinite or `NaN`. +The most frequent error you will encounter is an underflow during inference, caused by some values being infinite or `NaN`. This can happen for a variety of reasons, so here are a few leads worth investigating: * Increase the duration of the sequence / the number of sequences to get more data -* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior like zero variance in a Gaussian +* Add a prior to your transition matrix / observation distributions to avoid degenerate behavior (like zero variance in a Gaussian or zero probability in a Bernoulli) * Reduce the number of states to make every one of them useful * Pick a better initialization to start closer to the supposed ground truth -* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distributions. +* Use numerically stable number types (such as [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl)) in strategic places, but beware: these numbers don't play nicely with Distributions.jl, so you may have to roll out your own [Custom distributions](@ref). + +## Method errors + +This might be caused by: + +* forgetting to define methods for your custom type +* omitting `control_seq` or `seq_ends` in some places. + +Check the [API reference](@ref). ## Performance -If your algorithms are too slow, the general advice always applies: +If your algorithms are too slow, you can leverage the existing [Interfaces](@ref) to improve the components of your model separately (first observation distributions, then fitting). +The usual advice always applies: * Use [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) to establish a baseline * Use profiling to see where you spend most of your time * Use [JET.jl](https://github.com/aviatesk/JET.jl) to track down type instabilities * Use [AllocCheck.jl](https://github.com/JuliaLang/AllocCheck.jl) to reduce allocations - diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 1928facf..5928f4d9 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -22,11 +22,11 @@ Enzyme.API.runtimeActivity!(true) #- -rng = StableRNG(63) +rng = StableRNG(63); # ## Data generation -init = [0.8, 0.2] +init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] means = [-1.0, 1.0] dists = Normal.(means) @@ -41,9 +41,9 @@ seq_ends = cumsum(length.(obs_seqs)); # ## Forward mode #= -Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with ForwardDiff.jl. +Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). -Because this backend only accepts a single vector input, we wrap all parameters with ComponentArrays.jl, and define a new function to differentiate. +Because this backend only accepts a single vector input, we wrap all parameters with [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl), and define a new function to differentiate. =# params = ComponentVector(; init, trans, means) @@ -64,7 +64,7 @@ grad_f = ForwardDiff.gradient(f, params) #= In the presence of many parameters, reverse mode automatic differentiation of the loglikelihood will be much more efficient. -The package includes a chain rule for `logdensityof`, which means backends like Zygote.jl can be used out of the box. +The package includes a chain rule for `logdensityof`, which means backends like [Zygote.jl](https://github.com/FluxML/Zygote.jl) can be used out of the box. =# grad_z = Zygote.gradient(f, params)[1] @@ -74,7 +74,7 @@ grad_z = Zygote.gradient(f, params)[1] grad_f ≈ grad_z #= -Enzyme.jl also works natively but we have to avoid the type instability of global variables by providing more information. +The more efficient [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) also works natively but we have to avoid the type instability of global variables. =# function f_extended(params::ComponentVector, obs_seq, seq_ends) @@ -90,7 +90,7 @@ Enzyme.autodiff( Enzyme.Active, Enzyme.Duplicated(params, shadow_params), Enzyme.Const(obs_seq), - Enzyme.Const(seq_ends), + Enzyme.Duplicated(seq_ends, Enzyme.make_zero(seq_ends)), ) grad_e = shadow_params diff --git a/examples/basics.jl b/examples/basics.jl index 70f80c7e..79168429 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -14,7 +14,7 @@ using Test #src #- -rng = StableRNG(63) +rng = StableRNG(63); # ## Model @@ -24,13 +24,12 @@ The package provides a versatile [`HMM`](@ref) type with three attributes: - a matrix of state transition probabilities - a vector of observation distributions, one for each state -We keep it simple for now by leveraging Distributions.jl. +Any scalar- or vector-valued distribution from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) can be used for the last part, as well as [Custom distributions](@ref). =# -d = 3 -init = [0.8, 0.2] +init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] -dists = [MvNormal(-1.0 * ones(d), I), MvNormal(+1.0 * ones(d), I)] +dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)] hmm = HMM(init, trans, dists); # ## Simulation @@ -39,13 +38,25 @@ hmm = HMM(init, trans, dists); You can simulate a pair of state and observation sequences with [`rand`](@ref) by specifying how long you want them to be. =# -state_seq, obs_seq = rand(rng, hmm, 20); +T = 20 +state_seq, obs_seq = rand(rng, hmm, T); #= -Note that the observation sequence is a vector, whose elements have whatever type an observation distribution returns when sampled. +The state sequence is a vector of integers. =# -state_seq[1], obs_seq[1] +state_seq[1:3] + +#= +The observation sequence is a vector whose elements have whatever type an observation distribution returns when sampled. +Here we chose a multivariate normal distribution, so we get vectors at each time step. + +!!! warning "Difference from HMMBase.jl" + In the case of multivariate observations, HMMBase.jl works with matrices, whereas HiddenMarkovModels.jl works with vectors of vectors. + This allows us to accept more generic observations than just numbers or vectors inside the sequence. +=# + +obs_seq[1:3] #= In practical applications, the state sequence is not known, which is why we need inference algorithms to gather information about it. @@ -54,35 +65,44 @@ In practical applications, the state sequence is not known, which is why we need # ## Inference #= -The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$ (in a vector of size 1). +The **Viterbi algorithm** ([`viterbi`](@ref)) returns: +- the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, +- the joint loglikelihood $\mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$ (in a vector of size 1). =# best_state_seq, best_joint_loglikelihood = viterbi(hmm, obs_seq); +only(best_joint_loglikelihood) #= -As we can see, it is very close to the true state sequence, but not necessarily equal. +As we can see, the most likely state sequence is very close to the true state sequence, but not necessarily equal. =# -vcat(state_seq', best_state_seq') +(state_seq .== best_state_seq)' #= -The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1). +The **forward algorithm** ([`forward`](@ref)) returns: +- a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, +- the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1). =# -filtered_state_marginals, obs_seq_loglikelihood1 = forward(hmm, obs_seq); +filtered_state_marginals, obs_seq_loglikelihood_f = forward(hmm, obs_seq); +only(obs_seq_loglikelihood_f) #= -At each time $t$, it takes only the observations up to time $t$ into account. +At each time $t$, these filtered marginals take only the observations up to time $t$ into account. This is particularly useful to infer the marginal distribution of the last state. =# -filtered_state_marginals[:, end] +filtered_state_marginals[:, T] #= -Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1). +The forward-backward algorithm ([`forward_backward`](@ref)) returns: +- a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, +- the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1). =# -smoothed_state_marginals, obs_seq_loglikelihood2 = forward_backward(hmm, obs_seq); +smoothed_state_marginals, obs_seq_loglikelihood_fb = forward_backward(hmm, obs_seq); +only(obs_seq_loglikelihood_fb) #= At each time $t$, it takes all observations up to time $T$ into account. @@ -90,7 +110,11 @@ This is particularly useful during learning. Note that forward and forward-backward only coincide at the last time step. =# -collect(zip(filtered_state_marginals, smoothed_state_marginals)) +filtered_state_marginals[:, T - 1] ≈ smoothed_state_marginals[:, T - 1] + +#- + +filtered_state_marginals[:, T] ≈ smoothed_state_marginals[:, T] #= Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward algorithm for observation sequence loglikelihoods $\mathbb{P}(Y_{1:T})$. @@ -99,7 +123,7 @@ Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward a logdensityof(hmm, obs_seq) #= -Another function can compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ which take the states into account. +Another function ([`joint_logdensityof`](@ref)) can compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ which take the states into account. =# joint_logdensityof(hmm, obs_seq, state_seq) @@ -117,16 +141,16 @@ The Baum-Welch algorithm ([`baum_welch`](@ref)) is a variant of Expectation-Maxi Since it is a local optimization procedure, it requires a starting point that is close enough to the true model. =# -init_guess = [0.7, 0.3] +init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [MvNormal(-0.7 * ones(d), I), MvNormal(+0.7 * ones(d), I)] +dists_guess = [MvNormal([-0.6, -0.7], I), MvNormal([0.6, 0.7], I)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #= Let's estimate parameters based on a slightly longer sequence. =# -_, long_obs_seq = rand(rng, hmm, 100) +_, long_obs_seq = rand(rng, hmm, 200) hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, long_obs_seq); #= @@ -145,7 +169,7 @@ cat(transition_matrix(hmm_est), transition_matrix(hmm); dims=3) And so have the estimates for the observation distributions. =# -map(dist -> dist.μ, hcat(obs_distributions(hmm_est), obs_distributions(hmm))) +map(mean, hcat(obs_distributions(hmm_est), obs_distributions(hmm))) #= On the other hand, the initialization is concentrated on one state. @@ -165,9 +189,9 @@ This is important to keep in mind when testing new models. In many applications, we have access to various observation sequences of different lengths. =# -_, long_obs_seq2 = rand(rng, hmm, 300) -_, long_obs_seq3 = rand(rng, hmm, 200) -long_obs_seqs = [long_obs_seq, long_obs_seq2, long_obs_seq3]; +nb_seqs = 100 +long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs]; +typeof(long_obs_seqs) #= Every algorithm in the package accepts multiple sequences in a concatenated form. @@ -176,14 +200,28 @@ Otherwise, the input will be treated as a unique observation sequence, which is =# long_obs_seq_concat = reduce(vcat, long_obs_seqs) +typeof(long_obs_seq_concat) + +#- + seq_ends = cumsum(length.(long_obs_seqs)) +seq_ends' #= The outputs of inference algorithms are then concatenated, and the associated loglikelihoods are split by sequence (in a vector of size `length(seq_ends)`). =# -best_state_seq_concat, _ = viterbi(hmm, long_obs_seq_concat; seq_ends); -length(best_state_seq_concat) +best_state_seq_concat, best_joint_loglikelihood_concat = viterbi( + hmm, long_obs_seq_concat; seq_ends +); + +#- + +length(best_joint_loglikelihood_concat) == length(seq_ends) + +#- + +length(best_state_seq_concat) == last(seq_ends) #= The function [`seq_limits`](@ref) returns the begin and end of a given sequence in the concatenated vector. @@ -194,7 +232,7 @@ start2, stop2 = seq_limits(seq_ends, 2) #- -best_state_seq_concat[start2:stop2] == first(viterbi(hmm, long_obs_seq2)) +best_state_seq_concat[start2:stop2] == first(viterbi(hmm, long_obs_seqs[2])) #= While inference algorithms can also be run separately on each sequence without changing the results, considering multiple sequences together is nontrivial for Baum-Welch. @@ -211,14 +249,15 @@ cat(transition_matrix(hmm_est_concat), transition_matrix(hmm); dims=3) #- -map(dist -> dist.μ, hcat(obs_distributions(hmm_est_concat), obs_distributions(hmm))) +map(mean, hcat(obs_distributions(hmm_est_concat), obs_distributions(hmm))) -# ## Tests #src +#- -control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:500]; #src -control_seq = reduce(vcat, control_seqs); #src -seq_ends = cumsum(length.(control_seqs)); #src +hcat(initialization(hmm_est_concat), initialization(hmm)) + +# ## Tests #src -test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +control_seq = fill(nothing, last(seq_ends)); #src +test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/controlled.jl b/examples/controlled.jl index cdffafdd..a3c9ca72 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -16,13 +16,13 @@ using Test #src #- -rng = StableRNG(63) +rng = StableRNG(63); # ## Model #= A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM. -We can represent it with the following subtype of `AbstractHMM`, which has one vector of coefficients $\beta_i$ per state. +We can represent it with the following subtype of `AbstractHMM` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state. =# struct ControlledGaussianHMM{T} <: AbstractHMM @@ -148,5 +148,7 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2]) # ## Tests #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.08, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +@test hmm_est.dist_coeffs[1] ≈ hmm.dist_coeffs[1] atol = 0.05 #src +@test hmm_est.dist_coeffs[2] ≈ hmm.dist_coeffs[2] atol = 0.05 #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 3df60d58..703f8c36 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -17,7 +17,7 @@ using Test #src #- -rng = StableRNG(63) +rng = StableRNG(63); # ## Custom distributions @@ -28,7 +28,7 @@ They only need to implement three methods: - `DensityInterface.logdensityof(dist, obs)` for inference - `StatsAPI.fit!(dist, obs_seq, weight_seq)` for learning -In addition, the observation can be arbitrary Julia types. +In addition, the observations can be arbitrary Julia types. So let's construct a distribution that generates stuff. =# @@ -37,7 +37,7 @@ struct Stuff{T} end #= -The distribution will only be a wrapper for a normal distribution on the quantity. +The associated distribution will only be a wrapper for a normal distribution on the quantity. =# mutable struct StuffDist{T} @@ -54,7 +54,7 @@ function Random.rand(rng::AbstractRNG, dist::StuffDist) end #= -It is important to declare to DensityInterface.jl that the custom distribution has a density, thanks to the following trait. +It is important to declare to [DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl) that the custom distribution has a density, thanks to the following trait. The logdensity itself can be computed up to an additive constant without issue. =# @@ -81,7 +81,7 @@ end Let's put it to the test. =# -init = [0.8, 0.2] +init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] dists = [StuffDist(-1.0), StuffDist(+1.0)] hmm = HMM(init, trans, dists); @@ -103,7 +103,7 @@ viterbi(hmm, obs_seq) If we implement `fit!`, Baum-Welch also works seamlessly. =# -init_guess = [0.7, 0.3] +init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] dists_guess = [StuffDist(-0.5), StuffDist(+0.5)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); @@ -125,6 +125,14 @@ transition_matrix(hmm_est) If you want more sophisticated examples, check out [`HiddenMarkovModels.LightDiagNormal`](@ref) and [`HiddenMarkovModels.LightCategorical`](@ref), which are designed to be fast and allocation-free. =# +# ## Tests #src + +seq_ends = cumsum(rand(rng, 100:200, 100)); #src +control_seq = fill(nothing, last(seq_ends)); #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src +test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src + # ## Custom HMM structures #= @@ -145,9 +153,7 @@ struct PriorHMM{T,D} <: AbstractHMM end #= -The basic requirements for `AbstractHMM` are the following three functions. - -While [`initialization`](@ref) will always have the same signature, [`transition_matrix`](@ref) and [`obs_distributions`](@ref) can accept an additional `control` argument, as we will see later on. +The basic requirements for `AbstractHMM` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref). =# HiddenMarkovModels.initialization(hmm::PriorHMM) = hmm.init @@ -170,7 +176,7 @@ This function takes as inputs: - the `hmm` itself - a `fb_storage` of type [`HiddenMarkovModels.ForwardBackwardStorage`](@ref) containing the results of the forward-backward algorithm. -- the same inputs as `baum_welch` for multiple sequences (we haven't encountered `control_seq` yet but its role will become clear in other tutorials) +- the same inputs as `baum_welch` for multiple sequences The goal is to modify `hmm` in-place, updating parameters with their maximum likelihood estimates given current inference results. We will make use of the fields `fb_storage.γ` and `fb_storage.ξ`, which contain the state and transition marginals `γ[i, t]` and `ξ[t][i, j]` at each time step. @@ -179,8 +185,7 @@ We will make use of the fields `fb_storage.γ` and `fb_storage.ξ`, which contai function StatsAPI.fit!( hmm::PriorHMM, fb_storage::HiddenMarkovModels.ForwardBackwardStorage, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) ## initialize to defaults without observations @@ -207,19 +212,21 @@ function StatsAPI.fit!( end ## perform a few checks on the model - HMMs.check_hmm(hmm) + @assert HMMs.valid_hmm(hmm) return nothing end #= -Note that some distributions, such as those from Distributions.jl: -- do not support in-place fitting -- might expect different input formats, e.g. higher-order arrays instead of a vector of objects - -The function [`HiddenMarkovModels.fit_in_sequence!`](@ref) is a replacement for `fit!`, designed to handle Distributions.jl. -You can overload it for your own objects too if needed. +!!! warning "When distributions don't comply" + Note that some distributions, such as those from Distributions.jl: + - do not support in-place fitting + - expect different input formats, e.g. matrices instead of a vector of vectors + The function [`HiddenMarkovModels.fit_in_sequence!`](@ref) is a replacement for `fit!`, designed to handle Distributions.jl without committing type piracy. + Check out its source code, and overload it for your other distributions too if they do not support in-place fitting. +=# -Now let's see that everything works. +#= +Now let's see that everything works, even with our custom distribution from before. =# trans_prior_count = 10 @@ -236,12 +243,10 @@ As we can see, the transition matrix for our Bayesian version is slightly more s cat(transition_matrix(hmm_est), transition_matrix(prior_hmm_est); dims=3) -# ## Tests #src +#- -control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src -control_seq = reduce(vcat, control_seqs); #src -seq_ends = cumsum(length.(control_seqs)); #src +std(vec(transition_matrix(hmm_est))) < std(vec(transition_matrix(hmm))) + +# ## Tests #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src -test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src +@test std(vec(transition_matrix(hmm_est))) < std(vec(transition_matrix(hmm))) #src diff --git a/examples/temporal.jl b/examples/temporal.jl index 29252730..dee15c28 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -16,14 +16,14 @@ using Test #src #- -rng = StableRNG(63) +rng = StableRNG(63); # ## Model #= We focus on the particular case of a periodic HMM with period `L`. It has only one initialization vector, but `L` transition matrices and `L` vectors of observation distributions. -Once again we need to subtype `AbstractHMM`. +As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM`. =# struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM @@ -54,7 +54,7 @@ end # ## Simulation -init = [0.8, 0.2] +init = [0.6, 0.4] trans_per = ([0.7 0.3; 0.3 0.7], [0.3 0.7; 0.7 0.3]) dists_per = ([Normal(-1.0), Normal(-2.0)], [Normal(+1.0), Normal(+2.0)]) hmm = PeriodicHMM(init, trans_per, dists_per); @@ -65,16 +65,18 @@ Since the behavior of the model depends on control variables, we need to pass th control_seq = 1:10 state_seq, obs_seq = rand(rng, hmm, control_seq); -obs_seq' -@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src #= The observations mostly alternate between positive and negative values, which is coherent with negative observation means at odd times and positive observation means at even times. +=# +obs_seq' + +#= We now generate several sequences of variable lengths, for inference and learning tasks. =# -control_seqs = [1:rand(100:200) for k in 1:1000] +control_seqs = [1:rand(rng, 100:200) for k in 1:1000] obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)]; obs_seq = reduce(vcat, obs_seqs) @@ -84,7 +86,7 @@ seq_ends = cumsum(length.(obs_seqs)); # ## Inference #= -All three inference algorithms work in the same way, except that we need to provide the control sequence as a keyword argument. +All three inference algorithms work in the same way, except that we need to provide the control sequence as the last positional argument. =# best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends) @@ -98,7 +100,7 @@ vcat(obs_seq', best_state_seq') # ## Learning #= -When estimating parameters for a custom subtype of `AbstractHMM`, we have to override the fitting procedure after forward-backward (with an additional `control_seq` keyword argument). +When estimating parameters for a custom subtype of `AbstractHMM`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument. The key is to split the observations according to which periodic parameter they belong to. =# @@ -140,7 +142,7 @@ function StatsAPI.fit!( end for l in 1:L - HMMs.check_hmm(hmm; control=l) + @assert HMMs.valid_hmm(hmm, l) end return nothing end @@ -173,13 +175,14 @@ cat(transition_matrix(hmm_est, 2), transition_matrix(hmm, 2); dims=3) #- -hcat(obs_distributions(hmm_est, 1), obs_distributions(hmm, 1)) +map(mean, hcat(obs_distributions(hmm_est, 1), obs_distributions(hmm, 1))) #- -hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)) +map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.1, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.09, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/types.jl b/examples/types.jl index 6634a150..dc97c6eb 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -9,6 +9,7 @@ using HiddenMarkovModels using HMMTest #src using LinearAlgebra using LogarithmicNumbers +using Measurements using Random using SparseArrays using StableRNGs @@ -16,19 +17,70 @@ using Test #src #- -rng = StableRNG(63) +rng = StableRNG(63); -# ## Logarithmic numbers +# ## General principle #= -!!! warning - Work in progress +The whole package is agnostic with respect to types, it performs the right promotions automatically. +Therefore, the types we get in the output only depend only on the types present in the input HMM and the observation sequences. =# -# ## Sparse arrays +# ## Weird number types #= -Using sparse matrices is very useful for large models, because it means the memory and computational requirements will scale as the number of possible transitions. +A wide variety of number types can be plugged into HMM parameters to enhance precision or change inference behavior. +Some examples are: +- `BigFloat` for arbitrary precision +- [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl) to increase numerical stability +- [Measurements.jl](https://github.com/JuliaPhysics/Measurements.jl) to propagate uncertainties +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) for dual numbers in automatic differentiation + +To give an example, let us first generate some data from a vanilla HMM. +=# + +init = [0.6, 0.4] +trans = [0.7 0.3; 0.3 0.7] +dists = [Normal(-1.0), Normal(1.0)] +hmm = HMM(init, trans, dists) +state_seq, obs_seq = rand(rng, hmm, 100); + +#= +Now we construct a new HMM with some uncertainty on the observation means, using Measurements.jl. +Note that uncertainty on the transition parameters would throw an error because the matrix has to be stochastic. +=# + +dists_guess = [Normal(-1.0 ± 0.1), Normal(1.0 ± 0.2)] +hmm_uncertain = HMM(init, trans, dists_guess); + +#= +Every quantity we compute with this new HMM will have propagated uncertainties around it. +=# + +logdensityof(hmm_uncertain, obs_seq) + +#= +We can check that the interval is centered around the true value. +=# + +Measurements.value(logdensityof(hmm_uncertain, obs_seq)) ≈ logdensityof(hmm, obs_seq) + +#= +!!! warning "Number types in Baum-Welch" + For now, the Baum-Welch algorithm will generally fail with custom number types due to promotion. + The reason is that if some parameters have type `T1` and some `T2`, the forward-backward algorithm will compute quantities of type `T = promote_type(T1, T2)`. + These quantities may not be suited to the existing containers inside an HMM, and since updates happen in-place for performance, we cannot create a new one. + Suggestions are welcome to fix this issue. +=# + +# ## Tests #src + +@test Measurements.value(logdensityof(hmm_uncertain, obs_seq)) ≈ logdensityof(hmm, obs_seq) #src + +# ## Sparse matrices + +#= +[Sparse matrices](https://docs.julialang.org/en/v1/stdlib/SparseArrays/) are very useful for large models, because it means the memory and computational requirements will scale as the number of possible transitions. In general, this number is much smaller than the square of the number of states. =# @@ -36,12 +88,15 @@ In general, this number is much smaller than the square of the number of states. We can easily construct an HMM with a sparse transition matrix, where some transitions are structurally forbidden. =# -init = [0.2, 0.6, 0.2] trans = sparse([ - 0.8 0.2 0.0 - 0.0 0.8 0.2 - 0.2 0.0 0.8 + 0.7 0.3 0 + 0 0.7 0.3 + 0.3 0 0.7 ]) + +#- + +init = [0.2, 0.6, 0.2] dists = [Normal(-2.0), Normal(0.0), Normal(+2.0)] hmm = HMM(init, trans, dists); @@ -52,11 +107,15 @@ When we simulate it, the transitions outside of the nonzero coefficients simply state_seq, obs_seq = rand(rng, hmm, 1000) state_transitions = collect(zip(state_seq[1:(end - 1)], state_seq[2:end])); -#- +#= +For a possible transition: +=# count(isequal((2, 2)), state_transitions) -#- +#= +For an impossible transition: +=# count(isequal((2, 1)), state_transitions) @@ -66,9 +125,9 @@ Now we apply Baum-Welch from a guess with the right sparsity pattern. init_guess = [0.3, 0.4, 0.3] trans_guess = sparse([ - 0.7 0.3 0.0 - 0.0 0.7 0.3 - 0.3 0.0 0.7 + 0.6 0.4 0 + 0 0.6 0.4 + 0.4 0 0.6 ]) dists_guess = [Normal(-1.5), Normal(0.0), Normal(+1.5)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); @@ -79,23 +138,21 @@ hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq); first(loglikelihood_evolution), last(loglikelihood_evolution) #= -The estimated model has kept the same sparsity pattern. +The estimated model has kept the same sparsity pattern as the guess. =# transition_matrix(hmm_est) -#- - -transition_matrix(hmm) +#= +Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl), which reduces allocations for small state spaces. +=# # ## Tests #src -control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src -control_seq = reduce(vcat, control_seqs); #src -seq_ends = cumsum(length.(control_seqs)); #src - -test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +seq_ends = cumsum(rand(rng, 100:200, 100)); #src +control_seqs = fill(nothing, length(seq_ends)); #src +test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src -@test_skip test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src +@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/libs/HMMComparison/src/HMMComparison.jl b/libs/HMMComparison/src/HMMComparison.jl index 4b956de8..7e25f836 100644 --- a/libs/HMMComparison/src/HMMComparison.jl +++ b/libs/HMMComparison/src/HMMComparison.jl @@ -1,5 +1,6 @@ module HMMComparison +using Base.Threads: @threads using BenchmarkTools: BenchmarkGroup, @benchmarkable using CondaPkg: CondaPkg using Distributions: Normal, MvNormal diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl index 4a845911..faecdca3 100644 --- a/libs/HMMComparison/src/dynamax.jl +++ b/libs/HMMComparison/src/dynamax.jl @@ -47,7 +47,7 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["forward"] = @benchmarkable begin - $(filter_vmap)($dyn_params, $obs_tens_jax_py) + $(filter_vmap)($dyn_params, $obs_tens_jax_py).block_until_ready() end evals = 1 samples = 100 end @@ -56,13 +56,13 @@ function HMMBenchmark.build_benchmarkables( jax.vmap(hmm.most_likely_states; in_axes=pylist((pybuiltins.None, 0))) ) benchs["viterbi"] = @benchmarkable begin - $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py) + $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py).block_until_ready() end evals = 1 samples = 100 end if "forward_backward" in algos smoother_vmap = jax.jit( - jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0))) + jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0))).block_until_ready() ) benchs["forward_backward"] = @benchmarkable begin $(smoother_vmap)($dyn_params, $obs_tens_jax_py) @@ -77,7 +77,7 @@ function HMMBenchmark.build_benchmarkables( $obs_tens_jax_py; num_iters=$bw_iter, verbose=false, - ) + ).block_until_ready() end evals = 1 samples = 100 setup = ( tup = build_model($implem, $instance, $params); hmm_guess = tup[1]; diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl index 12b629a7..05881746 100644 --- a/libs/HMMComparison/src/hmmbase.jl +++ b/libs/HMMComparison/src/hmmbase.jl @@ -28,34 +28,43 @@ function HMMBenchmark.build_benchmarkables( hmm = build_model(implem, instance, params) if obs_dim == 1 - obs_mat = reduce(vcat, data[k, :, 1] for k in 1:nb_seqs) + obs_mats = [data[k, :, 1] for k in 1:nb_seqs] else - obs_mat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs) + obs_mats = [data[k, :, :] for k in 1:nb_seqs] end + obs_mat_concat = reduce(vcat, obs_mats) benchs = BenchmarkGroup() if "forward" in algos benchs["forward"] = @benchmarkable begin - HMMBase.forward($hmm, $obs_mat) + @threads for k in eachindex(obs_mats) + HMMBase.forward($hmm, $(obs_mats[k])) + end end evals = 1 samples = 100 end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin - HMMBase.viterbi($hmm, $obs_mat) + @threads for k in eachindex(obs_mats) + HMMBase.viterbi($hmm, $(obs_mats[k])) + end end evals = 1 samples = 100 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin - HMMBase.posteriors($hmm, $obs_mat) + @threads for k in eachindex(obs_mats) + HMMBase.posteriors($hmm, $(obs_mats[k])) + end end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin - HMMBase.fit_mle($hmm, $obs_mat; maxiter=$bw_iter, tol=-Inf) + @threads for k in eachindex(obs_mats) + HMMBase.fit_mle($hmm, $(obs_mats[k]); maxiter=$bw_iter, tol=-Inf) + end end evals = 1 samples = 100 end diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index 0dac7ce7..d3bb21b9 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -2,9 +2,9 @@ function test_allocations( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Allocations" begin obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k @@ -15,6 +15,7 @@ function test_allocations( t1, t2 = 1, seq_ends[1] ## Forward + f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends) allocs_f = @ballocated HMMs.forward!( $f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 @@ -22,6 +23,7 @@ function test_allocations( @test allocs_f == 0 ## Viterbi + v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) allocs_v = @ballocated HMMs.viterbi!( $v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 @@ -29,6 +31,7 @@ function test_allocations( @test allocs_v == 0 ## Forward-backward + fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) allocs_fb = @ballocated HMMs.forward_backward!( $fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 @@ -36,6 +39,7 @@ function test_allocations( @test allocs_fb == 0 ## Baum-Welch + if !isnothing(hmm_guess) fb_storage = HMMs.initialize_forward_backward( hmm_guess, obs_seq, control_seq; seq_ends diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index f95004dd..f5ca9147 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -1,12 +1,12 @@ infnorm(x) = maximum(abs, x) -function check_equal_hmms( +function are_equal_hmms( hmm1::AbstractHMM, - hmm2::AbstractHMM; - control_seq=[nothing], - atol::Real=0.1, - init::Bool=true, - test::Bool=true, + hmm2::AbstractHMM, + control_seq::AbstractVector; + atol::Real, + init::Bool, + test::Bool, ) equal_check = true @@ -43,24 +43,13 @@ function check_equal_hmms( return equal_check end -function test_equal_hmms( - hmm1::AbstractHMM, - hmm2::AbstractHMM; - control_seq=[nothing], - atol::Real=0.1, - init::Bool=true, -) - check_equal_hmms(hmm1, hmm2; control_seq, atol, init, test=true) - return nothing -end - function test_coherent_algorithms( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, - atol::Real=0.1, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, + atol::Real=0.05, init::Bool=true, ) @testset "Coherence" begin @@ -92,10 +81,8 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) @test all(>=(0), diff(logL_evolution)) - @test !check_equal_hmms( - hmm, hmm_guess; control_seq=control_seq[1:2], atol, test=false - ) - test_equal_hmms(hmm, hmm_est; control_seq=control_seq[1:2], atol, init) + @test !are_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, test=false) + are_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init, test=true) end end end diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index 687b72fb..1ecf76fd 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -2,9 +2,9 @@ function test_identical_hmmbase( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - T::Integer, + T::Integer; atol::Real=1e-5, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "HMMBase" begin sim = rand(rng, hmm, T) @@ -50,11 +50,13 @@ function test_identical_hmmbase( @test isapprox( logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) - test_equal_hmms( + are_equal_hmms( hmm_est, - HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B); + HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B), + [nothing]; atol, init=true, + test=true, ) end end diff --git a/libs/HMMTest/src/jet.jl b/libs/HMMTest/src/jet.jl index 95d82789..75820193 100644 --- a/libs/HMMTest/src/jet.jl +++ b/libs/HMMTest/src/jet.jl @@ -2,9 +2,9 @@ function test_type_stability( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Type stability" begin state_seq, obs_seq = rand(rng, hmm, control_seq) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index e12bcd6e..8c4fedce 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -9,6 +9,7 @@ $(EXPORTS) """ module HiddenMarkovModels +using ArgCheck: @argcheck using Base: RefValue using Base.Threads: @threads using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad @@ -20,6 +21,7 @@ using PrecompileTools: @compile_workload using Random: Random, AbstractRNG, default_rng using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange using StatsAPI: StatsAPI, fit, fit! +using StatsFuns: log2π export AbstractHMM, HMM export initialization, transition_matrix, obs_distributions @@ -30,7 +32,7 @@ export seq_limits include("types/abstract_hmm.jl") include("utils/linalg.jl") -include("utils/check.jl") +include("utils/valid.jl") include("utils/probvec_transmat.jl") include("utils/fit.jl") include("utils/lightdiagnormal.jl") diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 924e2bf1..3bec0ac2 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -27,7 +27,7 @@ function baum_welch!( max_iterations::Integer, loglikelihood_increasing::Bool, ) - for iteration in 1:max_iterations + for _ in 1:max_iterations forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) push!(logL_evolution, logdensityof(hmm) + sum(fb_storage.logL)) fit!(hmm, fb_storage, obs_seq, control_seq; seq_ends) @@ -77,3 +77,15 @@ function baum_welch( ) return hmm, logL_evolution end + +## Fallback + +function StatsAPI.fit!( + hmm::AbstractHMM, + fb_storage::ForwardBackwardStorage, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVector{Int}, +) + return fit!(hmm, fb_storage, obs_seq; seq_ends) +end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 0a49d745..2849f795 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -80,6 +80,7 @@ function forward!( logL += -log(c[t + 1]) + logm end + @argcheck isfinite(logL) return logL end @@ -93,12 +94,11 @@ function forward!( control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) - (; α, logL) = storage + (; logL) = storage @threads for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) end - check_finite(α) return nothing end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index eb6ec9d3..63ad978b 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -36,13 +36,13 @@ function initialize_forward_backward( N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) trans = transition_matrix(hmm, control_seq[1]) - M = typeof(mysimilar_mutable(trans, R)) + M = typeof(similar(trans, R)) γ = Matrix{R}(undef, N, T) ξ = Vector{M}(undef, T) if transition_marginals for t in 1:T - ξ[t] = mysimilar_mutable(transition_matrix(hmm, control_seq[t]), R) + ξ[t] = similar(transition_matrix(hmm, control_seq[t]), R) end end logL = Vector{R}(undef, K) @@ -107,14 +107,13 @@ function forward_backward!( seq_ends::AbstractVector{Int}, transition_marginals::Bool=true, ) where {R} - (; logL, γ) = storage + (; logL) = storage @threads for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) logL[k] = forward_backward!( storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals ) end - check_finite(γ) return nothing end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 4948261b..1795d772 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -13,7 +13,7 @@ struct ViterbiStorage{R} "one joint loglikelihood per pair of observation sequence and most likely state sequence" logL::Vector{R} logB::Matrix{R} - ϕ::Matrix{R} + logϕ::Matrix{R} ψ::Matrix{Int} end @@ -33,9 +33,9 @@ function initialize_viterbi( q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) logB = Matrix{R}(undef, N, T) - ϕ = Matrix{R}(undef, N, T) + logϕ = Matrix{R}(undef, N, T) ψ = Matrix{Int}(undef, N, T) - return ViterbiStorage(q, logL, logB, ϕ, ψ) + return ViterbiStorage(q, logL, logB, logϕ, ψ) end """ @@ -49,35 +49,37 @@ function viterbi!( t1::Integer, t2::Integer; ) where {R} - (; q, logB, ϕ, ψ) = storage + (; q, logB, logϕ, ψ) = storage obs_logdensities!(view(logB, :, t1), hmm, obs_seq[t1], control_seq[t1]) init = initialization(hmm) - ϕ[:, t1] .= log.(init) .+ view(logB, :, t1) + logϕ[:, t1] .= log.(init) .+ view(logB, :, t1) for t in (t1 + 1):t2 obs_logdensities!(view(logB, :, t), hmm, obs_seq[t], control_seq[t]) trans = transition_matrix(hmm, control_seq[t - 1]) for j in 1:length(hmm) i_max = 1 - score_max = ϕ[i_max, t - 1] + log(trans[i_max, j]) + score_max = logϕ[i_max, t - 1] + log(trans[i_max, j]) for i in 2:length(hmm) - score = ϕ[i, t - 1] + log(trans[i, j]) + score = logϕ[i, t - 1] + log(trans[i, j]) if score > score_max score_max, i_max = score, i end end ψ[j, t] = i_max - ϕ[j, t] = score_max + logB[j, t] + logϕ[j, t] = score_max + logB[j, t] end end - q[t2] = argmax(view(ϕ, :, t2)) + q[t2] = argmax(view(logϕ, :, t2)) + logL = logϕ[q[t2], t2] for t in (t2 - 1):-1:t1 q[t] = ψ[q[t + 1], t + 1] end - return ϕ[q[t2], t2] + @argcheck isfinite(logL) + return logL end """ @@ -90,12 +92,11 @@ function viterbi!( control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) where {R} - (; logL, ϕ) = storage + (; logL) = storage @threads for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) end - check_right_finite(ϕ) return nothing end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index f02d4ca5..d88b9a8b 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -82,27 +82,21 @@ obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) function obs_logdensities!(logb::AbstractVector, hmm::AbstractHMM, obs, control) dists = obs_distributions(hmm, control) - @inbounds for i in eachindex(logb, dists) + @inbounds @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end - check_right_finite(logb) + @argcheck all(<(typemax(eltype(logb))), logb) return nothing end """ - fit!( - hmm::AbstractHMM, - fb_storage::ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, - seq_ends::AbstractVector{Int}, - ) +$(SIGNATURES) Update `hmm` in-place based on information generated during forward-backward. This function is allowed to reuse `fb_storage` as a scratch space, so its contents should not be trusted afterwards. """ -StatsAPI.fit! # TODO: complete +StatsAPI.fit! ## Sampling @@ -161,4 +155,4 @@ end Return the prior loglikelihood associated with the parameters of `hmm`. """ -DensityInterface.logdensityof(hmm::AbstractHMM) = 0 +DensityInterface.logdensityof(hmm::AbstractHMM) = false diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 5dcd6ec4..19c51340 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -17,7 +17,7 @@ struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHM function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector) hmm = new{typeof(init),typeof(trans),typeof(dists)}(init, trans, dists) - check_hmm(hmm) + @argcheck valid_hmm(hmm) return hmm end end @@ -35,8 +35,7 @@ obs_distributions(hmm::HMM) = hmm.dists function StatsAPI.fit!( hmm::HMM, fb_storage::ForwardBackwardStorage, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) (; γ, ξ) = fb_storage @@ -64,6 +63,6 @@ function StatsAPI.fit!( fit_in_sequence!(hmm.dists, i, obs_seq, view(γ, i, :)) end # Safety check - check_hmm(hmm) + @argcheck valid_hmm(hmm) return nothing end diff --git a/src/utils/check.jl b/src/utils/check.jl deleted file mode 100644 index 502195e4..00000000 --- a/src/utils/check.jl +++ /dev/null @@ -1,77 +0,0 @@ -function check_finite(a) - if !all(isfinite, mynonzeros(a)) - throw(OverflowError("Some values are infinite or NaN")) - end -end - -function check_right_finite(a) - if !all(<(typemax(eltype(a))), mynonzeros(a)) - throw(OverflowError("Some values are positive infinite or NaN")) - end -end - -function check_no_nan(a) - if any(isnan, mynonzeros(a)) - throw(OverflowError("Some values are NaN")) - end -end - -function check_positive(a) - if !all(>(zero(eltype(a))), mynonzeros(a)) - throw(OverflowError("Some values are not positive")) - end -end - -function check_nonnegative(a) - if any(<(zero(eltype(a))), mynonzeros(a)) - throw(OverflowError("Some values are negative")) - end -end - -function check_prob_vec(p::AbstractVector) - check_finite(p) - if !valid_prob_vec(p) - throw(ArgumentError("Invalid probability distribution.")) - end -end - -function check_trans_mat(A::AbstractMatrix) - check_finite(A) - if !valid_trans_mat(A) - throw(ArgumentError("Invalid transition matrix.")) - end -end - -function check_dists(d::AbstractVector) - for i in eachindex(d) - if DensityKind(d[i]) == NoDensity() - throw( - ArgumentError( - "Invalid observation distributions (do not satisfy DensityInterface.jl)" - ), - ) - end - end - return true -end - -function check_hmm_sizes(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector) - if !(size(trans) == (length(init), length(init)) == (length(dists), length(dists))) - throw( - DimensionMismatch( - "Initialization, transition matrix and observation distributions have incompatible sizes.", - ), - ) - end -end - -function check_hmm(hmm::AbstractHMM; control=nothing) - init = initialization(hmm) - trans = transition_matrix(hmm, control) - dists = obs_distributions(hmm, control) - check_hmm_sizes(init, trans, dists) - check_prob_vec(init) - check_trans_mat(trans) - check_dists(dists) - return nothing -end diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index a7dfbe81..f0786717 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -16,8 +16,8 @@ struct LightCategorical{T1,T2,V1<:AbstractVector{T1},V2<:AbstractVector{T2}} logp::V2 end -function LightCategorical(p::AbstractVector) - check_prob_vec(p) +function LightCategorical(p::AbstractVector{T}) where {T} + @argcheck valid_prob_vec(p) return LightCategorical(p, log.(p)) end @@ -55,6 +55,6 @@ function StatsAPI.fit!(dist::LightCategorical{T1}, x, w) where {T1} end dist.p ./= w_tot dist.logp .= log.(dist.p) - check_prob_vec(dist.p) + @argcheck valid_prob_vec(dist.p) return nothing end diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 67ec07cc..1d0b05a3 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -20,9 +20,12 @@ struct LightDiagNormal{ logσ::V3 end -function LightDiagNormal(μ::AbstractVector, σ::AbstractVector) - check_positive(σ) - return LightDiagNormal(μ, σ, log.(σ)) +function LightDiagNormal(μ::AbstractVector{T1}, σ::AbstractVector{T2}) where {T1,T2} + logσ = log.(σ) + @argcheck all(isfinite, μ) + @argcheck all(isfinite, σ) + @argcheck all(isfinite, logσ) + return LightDiagNormal(μ, σ, logσ) end function Base.show(io::IO, dist::LightDiagNormal) @@ -38,14 +41,13 @@ function Base.rand(rng::AbstractRNG, dist::LightDiagNormal{T1,T2}) where {T1,T2} return dist.σ .* randn(rng, T, length(dist)) .+ dist.μ end -function DensityInterface.logdensityof(dist::LightDiagNormal, x) - b = -sum(dist.logσ) - c = - -sum( - abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i])) for - i in eachindex(x, dist.μ, dist.σ) - ) - return b + c +function DensityInterface.logdensityof(dist::LightDiagNormal{T1,T2,T3}, x) where {T1,T2,T3} + l = zero(promote_type(T1, T2, T3, eltype(x))) + l -= sum(dist.logσ) + log2π * length(x) / 2 + @inbounds @simd for i in eachindex(x, dist.μ, dist.σ) + l -= abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i])) + end + return l end function StatsAPI.fit!(dist::LightDiagNormal{T1,T2}, x, w) where {T1,T2} @@ -61,6 +63,8 @@ function StatsAPI.fit!(dist::LightDiagNormal{T1,T2}, x, w) where {T1,T2} dist.σ .-= min.(abs2.(dist.μ), dist.σ) dist.σ .= sqrt.(dist.σ) dist.logσ .= log.(dist.σ) - check_positive(dist.σ) + @argcheck all(isfinite, dist.μ) + @argcheck all(isfinite, dist.σ) + @argcheck all(isfinite, dist.logσ) return nothing end diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 72a7abeb..550c160b 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -1,7 +1,5 @@ sum_to_one!(x) = ldiv!(sum(x), x) -mysimilar_mutable(x::AbstractArray, ::Type{R}) where {R} = similar(x, R) - mynonzeros(x::AbstractArray) = x mynonzeros(x::AbstractSparseArray) = nonzeros(x) @@ -17,8 +15,8 @@ end function mul_rows_cols!( B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector ) - @assert size(B) == size(A) == (length(l), length(r)) - @assert nnz(B) == nnz(A) + @argcheck size(B) == size(A) == (length(l), length(r)) + @argcheck nnz(B) == nnz(A) for j in axes(B, 2) for k in nzrange(B, j) i = B.rowval[k] diff --git a/src/utils/probvec_transmat.jl b/src/utils/probvec_transmat.jl index b070ae9d..b8df74fc 100644 --- a/src/utils/probvec_transmat.jl +++ b/src/utils/probvec_transmat.jl @@ -1,24 +1,3 @@ -function valid_prob_vec(p::AbstractVector; atol=1e-2) - return (minimum(p) >= 0) && isapprox(sum(p), 1; atol=atol) -end - -function is_square(A::AbstractMatrix) - return size(A, 1) == size(A, 2) -end - -function valid_trans_mat(A::AbstractMatrix; atol=1e-2) - if !is_square(A) - return false - else - for row in eachrow(A) - if !valid_prob_vec(row; atol=atol) - return false - end - end - return true - end -end - """ rand_prob_vec([rng, ::Type{R},] N) diff --git a/src/utils/valid.jl b/src/utils/valid.jl new file mode 100644 index 00000000..ed147691 --- /dev/null +++ b/src/utils/valid.jl @@ -0,0 +1,28 @@ +function valid_prob_vec(p::AbstractVector{T}) where {T} + return minimum(p) >= zero(T) && sum(p) ≈ one(T) +end + +function valid_trans_mat(A::AbstractMatrix) + return size(A, 1) == size(A, 2) && all(valid_prob_vec, eachrow(A)) +end + +function valid_dists(d::AbstractVector) + for i in eachindex(d) + if DensityKind(d[i]) == NoDensity() + return false + end + end + return true +end + +function valid_hmm(hmm::AbstractHMM, control=nothing) + init = initialization(hmm) + trans = transition_matrix(hmm, control) + dists = obs_distributions(hmm, control) + return ( + length(init) == length(dists) == size(trans, 1) == size(trans, 2) && + valid_prob_vec(init) && + valid_trans_mat(trans) && + valid_dists(dists) + ) +end diff --git a/test/Project.toml b/test/Project.toml index 3840b29f..9cb468f6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/correctness.jl b/test/correctness.jl index c5f72d5a..639c3b7b 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -42,12 +42,10 @@ seq_ends = cumsum(length.(control_seqs)); hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, hmm_guess; T) - test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false - ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_identical_hmmbase(rng, hmm, T; hmm_guess) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end @testset "DiagNormal" begin @@ -59,11 +57,11 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, hmm_guess; T) - test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false + test_identical_hmmbase(rng, hmm, T; hmm_guess) + @test_skip test_coherent_algorithms( + rng, hmm, control_seq; seq_ends, hmm_guess, init=false ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) end @testset "LightCategorical" begin @@ -73,25 +71,21 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false - ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end -@test_skip @testset "LightDiagNormal" begin +@testset "LightDiagNormal" begin dists = [LightDiagNormal(μ[1], σ), LightDiagNormal(μ[2], σ)] dists_guess = [LightDiagNormal(μ_guess[1], σ), LightDiagNormal(μ_guess[2], σ)] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false - ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end # Controlled @@ -105,13 +99,11 @@ end HMMs.initialization(hmm::DiffusionHMM) = hmm.init function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number) - @assert 0 <= λ <= 1 N = length(hmm) return (1 - λ) * hmm.trans + λ * ones(N, N) / N end function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number) - @assert 0 <= λ <= 1 return [Normal((1 - λ) * hmm.means[i]) for i in 1:length(hmm)] end @@ -123,6 +115,6 @@ end control_seq = reduce(vcat, control_seqs) seq_ends = cumsum(length.(control_seqs)) - test_coherent_algorithms(rng, hmm; control_seq, seq_ends, atol=0.05, init=false) - test_type_stability(rng, hmm; control_seq, seq_ends) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends) end diff --git a/test/distributions.jl b/test/distributions.jl index c96bc657..544ba063 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -50,6 +50,5 @@ end @test dist_est.σ ≈ σ atol = 2e-2 test_fit_allocs(dist, x, w) # Logdensity - @test logdensityof(dist, x[1]) ≈ - logdensityof(MvNormal(μ, Diagonal(abs2.(σ))), x[1]) + length(x[1]) * log(sqrt(2π)) + @test logdensityof(dist, x[1]) ≈ logdensityof(MvNormal(μ, Diagonal(abs2.(σ))), x[1]) end From b56d1b253c7a04a5b164cb12fd2eac1b5366295d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:53:41 +0100 Subject: [PATCH 2/5] Fix leaky tests --- docs/src/formulas.md | 34 ++++++++++++++++-------------- examples/basics.jl | 4 ++-- examples/interfaces.jl | 2 +- examples/types.jl | 4 ++-- libs/HMMTest/src/allocations.jl | 4 ++-- libs/HMMTest/src/coherence.jl | 37 ++++++++++++++++++++------------- libs/HMMTest/src/hmmbase.jl | 10 ++------- src/types/hmm.jl | 7 +++++++ src/utils/linalg.jl | 1 + test/correctness.jl | 8 +++---- 10 files changed, 61 insertions(+), 50 deletions(-) diff --git a/docs/src/formulas.md b/docs/src/formulas.md index 7c20cb04..082fdfc6 100644 --- a/docs/src/formulas.md +++ b/docs/src/formulas.md @@ -4,9 +4,11 @@ Suppose we are given observations $Y_1, ..., Y_T$, with hidden states $X_1, ..., Following [Rabiner1989](@cite), we use the following notations: * let $\pi \in \mathbb{R}^N$ be the initial state distribution $\pi_i = \mathbb{P}(X_1 = i)$ -* let $A \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j} = \mathbb{P}(X_{t+1}=j | X_t = i)$ +* let $A_t \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j,t} = \mathbb{P}(X_{t+1}=j | X_t = i)$ * let $B \in \mathbb{R}^{N \times T}$ be the matrix of statewise observation likelihoods $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$ +The conditioning on the known controls $U_{1:T}$ is implicit throughout. + ## Vanilla forward-backward ### Recursion @@ -33,8 +35,8 @@ and satisfy the dynamic programming equations ```math \begin{align*} -\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j}\right) b_{j,t+1} \\ -\beta_{i,t} & = \sum_{j=1}^N a_{i,j} b_{j,t+1} \beta_{j,t+1} +\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j,t}\right) b_{j,t+1} \\ +\beta_{i,t} & = \sum_{j=1}^N a_{i,j,t} b_{j,t+1} \beta_{j,t+1} \end{align*} ``` @@ -53,7 +55,7 @@ We notice that ```math \begin{align*} \alpha_{i,t} \beta_{i,t} & = \mathbb{P}(Y_{1:T}, X_t=i) \\ -\alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j) +\alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j) \end{align*} ``` @@ -62,7 +64,7 @@ Thus we deduce the one-state and two-state marginals ```math \begin{align*} \gamma_{i,t} & = \mathbb{P}(X_t=i | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} \beta_{i,t} \\ -\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} +\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1} \end{align*} ``` @@ -75,7 +77,7 @@ According to [Qin2000](@cite), derivatives of the likelihood can be obtained as \frac{\partial \mathcal{L}}{\partial \pi_i} &= \beta_{i,1} b_{i,1} \\ \frac{\partial \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \alpha_{i,t} b_{j,t+1} \beta_{j,t+1} \\ \frac{\partial \mathcal{L}}{\partial b_{j,1}} &= \pi_j \beta_{j,1} \\ -\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t} +\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t} \end{align*} ``` @@ -98,8 +100,8 @@ and satisfy the dynamic programming equations ```math \begin{align*} -\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\ -\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t} +\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j,t}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\ +\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t} \end{align*} ``` @@ -140,9 +142,9 @@ We can now express the marginals using scaled variables: ```math \begin{align*} \xi_{i,j,t} & = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} \\ -&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\ -&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ -&= \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} +&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j,t} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\ +&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ +&= \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \end{align*} ``` @@ -179,10 +181,10 @@ And for the statewise observation likelihoods, ```math \begin{align*} -\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t} \\ -&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\ -&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ -&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \\ +\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t} \\ +&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j,t-1} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\ +&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ +&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \\ \end{align*} ``` @@ -199,7 +201,7 @@ To sum up, \frac{\partial \log \mathcal{L}}{\partial \pi_i} &= \frac{b_{i,1}}{m_1} \bar{\beta}_{i,1} \\ \frac{\partial \log \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \bar{\alpha}_{i,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \\ \frac{\partial \log \mathcal{L}}{\partial \log b_{j,1}} &= \pi_j \frac{b_{j,1}}{m_1} \bar{\beta}_{j,1} = \frac{\bar{\alpha}_{j,1} \bar{\beta}_{j,1}}{c_1} = \gamma_{j,1} \\ -\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t} +\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t} \end{align*} ``` diff --git a/examples/basics.jl b/examples/basics.jl index 79168429..034583c7 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -30,7 +30,7 @@ Any scalar- or vector-valued distribution from [Distributions.jl](https://github init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)] -hmm = HMM(init, trans, dists); +hmm = HMM(init, trans, dists) # ## Simulation @@ -143,7 +143,7 @@ Since it is a local optimization procedure, it requires a starting point that is init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [MvNormal([-0.6, -0.7], I), MvNormal([0.6, 0.7], I)] +dists_guess = [MvNormal([-0.4, -0.7], I), MvNormal([0.4, 0.7], I)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #= diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 703f8c36..18e0d7b5 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -105,7 +105,7 @@ If we implement `fit!`, Baum-Welch also works seamlessly. init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [StuffDist(-0.5), StuffDist(+0.5)] +dists_guess = [StuffDist(-0.7), StuffDist(+0.7)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #- diff --git a/examples/types.jl b/examples/types.jl index dc97c6eb..dc4c074a 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -51,7 +51,7 @@ Note that uncertainty on the transition parameters would throw an error because =# dists_guess = [Normal(-1.0 ± 0.1), Normal(1.0 ± 0.2)] -hmm_uncertain = HMM(init, trans, dists_guess); +hmm_uncertain = HMM(init, trans, dists_guess) #= Every quantity we compute with this new HMM will have propagated uncertainties around it. @@ -98,7 +98,7 @@ trans = sparse([ init = [0.2, 0.6, 0.2] dists = [Normal(-2.0), Normal(0.0), Normal(+2.0)] -hmm = HMM(init, trans, dists); +hmm = HMM(init, trans, dists) #= When we simulate it, the transitions outside of the nonzero coefficients simply cannot happen. diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index d3bb21b9..ea3aeef9 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -46,8 +46,8 @@ function test_allocations( ) HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) allocs_bw = @ballocated fit!( - $hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends - ) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm)) + hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends + ) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess)) @test_broken allocs_bw == 0 end end diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index f5ca9147..c8983eea 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -1,27 +1,31 @@ infnorm(x) = maximum(abs, x) -function are_equal_hmms( +function test_equal_hmms( hmm1::AbstractHMM, hmm2::AbstractHMM, control_seq::AbstractVector; atol::Real, init::Bool, - test::Bool, + flip::Bool=false, ) - equal_check = true - if init init1 = initialization(hmm1) init2 = initialization(hmm2) - test && @test isapprox(init1, init2; atol, norm=infnorm) - equal_check = equal_check && isapprox(init1, init2; atol, norm=infnorm) + if flip + @test !isapprox(init1, init2; atol, norm=infnorm) + else + @test isapprox(init1, init2; atol, norm=infnorm) + end end for control in control_seq trans1 = transition_matrix(hmm1, control) trans2 = transition_matrix(hmm2, control) - test && @test isapprox(trans1, trans2; atol, norm=infnorm) - equal_check = equal_check && isapprox(trans1, trans2; atol, norm=infnorm) + if flip + @test !isapprox(trans1, trans2; atol, norm=infnorm) + else + @test isapprox(trans1, trans2; atol, norm=infnorm) + end end for control in control_seq @@ -29,18 +33,23 @@ function are_equal_hmms( dists2 = obs_distributions(hmm2, control) for (dist1, dist2) in zip(dists1, dists2) for field in fieldnames(typeof(dist1)) - if startswith(string(field), "log") + if startswith(string(field), "log") || + contains("σ", string(field)) || + contains("Σ", string(field)) continue end x1 = getfield(dist1, field) x2 = getfield(dist2, field) - test && @test isapprox(x1, x2; atol, norm=infnorm) - equal_check = equal_check && isapprox(x1, x2; atol, norm=infnorm) + if flip + @test !isapprox(x1, x2; atol, norm=infnorm) + else + @test isapprox(x1, x2; atol, norm=infnorm) + end end end end - return equal_check + return nothing end function test_coherent_algorithms( @@ -81,8 +90,8 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) @test all(>=(0), diff(logL_evolution)) - @test !are_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, test=false) - are_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init, test=true) + test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) + test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) end end end diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index 1ecf76fd..fe1e48e3 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -50,14 +50,8 @@ function test_identical_hmmbase( @test isapprox( logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) - are_equal_hmms( - hmm_est, - HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B), - [nothing]; - atol, - init=true, - test=true, - ) + hmm_est_base_converted = HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B) + test_equal_hmms(hmm_est, hmm_est_base_converted, [nothing]; atol, init=true) end end end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 19c51340..4ebf8e65 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -26,6 +26,13 @@ function Base.copy(hmm::HMM) return HMM(copy(hmm.init), copy(hmm.trans), copy(hmm.dists)) end +function Base.show(io::IO, hmm::HMM) + return print( + io, + "Hidden Markov Model with:\n - initialization: $(hmm.init)\n - transition matrix: $(hmm.trans)\n - observation distributions: $(hmm.dists)", + ) +end + initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 550c160b..e3fe2b6f 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -18,6 +18,7 @@ function mul_rows_cols!( @argcheck size(B) == size(A) == (length(l), length(r)) @argcheck nnz(B) == nnz(A) for j in axes(B, 2) + @argcheck nzrange(B, j) == nzrange(A, j) for k in nzrange(B, j) i = B.rowval[k] B.nzval[k] = l[i] * A.nzval[k] * r[j] diff --git a/test/correctness.jl b/test/correctness.jl index 639c3b7b..c7dc84a8 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -18,8 +18,8 @@ T, K = 100, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] -trans = [0.8 0.2; 0.2 0.8] -trans_guess = [0.7 0.3; 0.3 0.7] +trans = [0.7 0.3; 0.3 0.7] +trans_guess = [0.6 0.4; 0.4 0.6] p = [[0.8, 0.2], [0.2, 0.8]] p_guess = [[0.7, 0.3], [0.3, 0.7]] @@ -58,9 +58,7 @@ end hmm_guess = HMM(init_guess, trans_guess, dists_guess) test_identical_hmmbase(rng, hmm, T; hmm_guess) - @test_skip test_coherent_algorithms( - rng, hmm, control_seq; seq_ends, hmm_guess, init=false - ) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) end From fa52fbea3d3484307121928a902f9fb275f57775 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 15:25:49 +0100 Subject: [PATCH 3/5] Better autodiff tuto --- examples/autodiff.jl | 200 +++++++++++++++++++++++++++------- examples/types.jl | 4 +- libs/HMMTest/src/coherence.jl | 6 +- src/inference/baum_welch.jl | 2 +- src/types/abstract_hmm.jl | 8 +- src/types/hmm.jl | 2 +- src/utils/lightcategorical.jl | 5 +- src/utils/lightdiagnormal.jl | 12 +- test/correctness.jl | 35 +----- 9 files changed, 183 insertions(+), 91 deletions(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 5928f4d9..afff758c 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -1,7 +1,7 @@ # # Autodiff #= -Here we show how to compute gradients of the observation sequence loglikelihood with respect to various parameters. +Here we show how to compute gradients of the observation sequence loglikelihood with respect to various inputs. =# using ComponentArrays @@ -18,90 +18,189 @@ using StatsAPI using Test #src using Zygote: Zygote -Enzyme.API.runtimeActivity!(true) - #- rng = StableRNG(63); -# ## Data generation +# ## Diffusion HMM + +#= +To play around with automatic differentiation, we define a simple controlled HMM. +=# + +struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: AbstractHMM + init::V1 + trans::M2 + means::V3 +end + +#= +Both its transition matrix and its vector of observation means result from a convex combination between the corresponding field and a base value (aka diffusion). +The coefficient $\lambda$ of this convex combination is given as a control. +=# + +HMMs.initialization(hmm::DiffusionHMM) = hmm.init + +function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number) + N = length(hmm) + return (1 - λ) * hmm.trans + λ * ones(N, N) / N +end + +function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number) + return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:length(hmm)] +end + +#= +We now construct an instance of this object and draw samples from it. +=# init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] means = [-1.0, 1.0] -dists = Normal.(means) -hmm = HMM(init, trans, dists); +hmm = DiffusionHMM(init, trans, means); -#- +#= +It is essential that the controls are taken between $0$ and $1$. +=# + +control_seqs = [rand(rng, 3), rand(rng, 5)]; +obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in 1:2]; -obs_seqs = [rand(rng, hmm, 10).obs_seq, rand(rng, hmm, 20).obs_seq]; +control_seq = reduce(vcat, control_seqs) obs_seq = reduce(vcat, obs_seqs) seq_ends = cumsum(length.(obs_seqs)); -# ## Forward mode +# ## What to differentiate? #= -Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +The key function we are interested in is the loglikelihood of the observation sequence. +We can differentiate it with respect to +- the model itself (`hmm`), or more precisely its parameters +- the observation sequence (`obs_seq`) +- the control sequence (`control_seq`). +- but not with respect to the sequence limits (`seq_ends`), which are discrete. +=# + +logdensityof(hmm, obs_seq, control_seq; seq_ends) -Because this backend only accepts a single vector input, we wrap all parameters with [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl), and define a new function to differentiate. +#= +To ensure compatibility with backends that only accept a single input, we wrap all parameters inside a `ComponentVector` from [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl), and define a new function to differentiate. =# -params = ComponentVector(; init, trans, means) +parameters = ComponentVector(; init, trans, means) -function f(params::ComponentVector) - new_hmm = HMM(params.init, params.trans, Normal.(params.means)) - return logdensityof(new_hmm, obs_seq; seq_ends) +function f(parameters::ComponentVector, obs_seq, control_seq; seq_ends) + new_hmm = DiffusionHMM(parameters.init, parameters.trans, parameters.means) + return logdensityof(new_hmm, obs_seq, control_seq; seq_ends) end; +f(parameters, obs_seq, control_seq; seq_ends) + +# ## Forward mode + #= -The gradient computation is now straightforward. -We will use this value as a source of truth to compare with reverse mode. +Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). + +Because ForwardDiff.jl only accepts a single input, we must compute derivatives one at a time. =# -grad_f = ForwardDiff.gradient(f, params) +∇parameters_forwarddiff = ForwardDiff.gradient( + _parameters -> f(_parameters, obs_seq, control_seq; seq_ends), parameters +) + +#- + +∇obs_forwarddiff = ForwardDiff.gradient( + _obs_seq -> f(parameters, _obs_seq, control_seq; seq_ends), obs_seq +) + +#- + +∇control_forwarddiff = ForwardDiff.gradient( + _control_seq -> f(parameters, obs_seq, _control_seq; seq_ends), control_seq +) + +#= +These values will serve as ground truth when we compare with reverse mode. +=# -# ## Reverse mode +# ## Reverse mode with Zygote.jl #= In the presence of many parameters, reverse mode automatic differentiation of the loglikelihood will be much more efficient. -The package includes a chain rule for `logdensityof`, which means backends like [Zygote.jl](https://github.com/FluxML/Zygote.jl) can be used out of the box. +The package includes a handwritten chain rule for `logdensityof`, which means backends like [Zygote.jl](https://github.com/FluxML/Zygote.jl) can be used out of the box. +Using it, we can compute all derivatives at once. +=# + +∇all_zygote = Zygote.gradient( + (_a, _b, _c) -> f(_a, _b, _c; seq_ends), parameters, obs_seq, control_seq +); + +∇parameters_zygote, ∇obs_zygote, ∇control_zygote = ∇all_zygote; + +#= +We can check the results to validate our chain rule. =# -grad_z = Zygote.gradient(f, params)[1] +∇parameters_zygote ≈ ∇parameters_forwarddiff #- -grad_f ≈ grad_z +∇obs_zygote ≈ ∇obs_forwarddiff + +#- + +∇control_zygote ≈ ∇control_forwarddiff + +# ## Reverse mode with Enzyme.jl #= -The more efficient [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) also works natively but we have to avoid the type instability of global variables. +The more efficient [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) also works natively as long as there are no type instabilities, which is why we avoid the closure and the keyword arguments with `f_aux`: =# -function f_extended(params::ComponentVector, obs_seq, seq_ends) - new_hmm = HMM(params.init, params.trans, Normal.(params.means)) - return logdensityof(new_hmm, obs_seq; seq_ends) -end; +function f_aux(parameters, obs_seq, control_seq, seq_ends) + return f(parameters, obs_seq, control_seq; seq_ends) +end + +#= +Enzyme.jl requires preallocated storage for the gradients, which we happily provide. +=# -shadow_params = Enzyme.make_zero(params) +∇parameters_enzyme = Enzyme.make_zero(parameters) +∇obs_enzyme = Enzyme.make_zero(obs_seq) +∇control_enzyme = Enzyme.make_zero(control_seq) + +#= +The syntax is a bit more complex, see the Enzyme.jl docs for details. +=# Enzyme.autodiff( Enzyme.Reverse, - f_extended, + f_aux, Enzyme.Active, - Enzyme.Duplicated(params, shadow_params), - Enzyme.Const(obs_seq), - Enzyme.Duplicated(seq_ends, Enzyme.make_zero(seq_ends)), + Enzyme.Duplicated(parameters, ∇parameters_enzyme), + Enzyme.Duplicated(obs_seq, ∇obs_enzyme), + Enzyme.Duplicated(control_seq, ∇control_enzyme), + Enzyme.Const(seq_ends), ) -grad_e = shadow_params +#= +Once again we can check the results. +=# + +∇parameters_enzyme ≈ ∇parameters_forwarddiff + +#- + +∇obs_enzyme ≈ ∇obs_forwarddiff #- -grad_e ≈ grad_f +∇control_enzyme ≈ ∇control_forwarddiff #= -For increased efficiency, one can provide temporary storage to Enzyme.jl in order to avoid allocations. -This requires going one level deeper, by leveraging the in-place [`HiddenMarkovModels.forward!`](@ref) function. +For increased efficiency, we could provide temporary storage to Enzyme.jl in order to avoid allocations. +This requires going one level deeper and leveraging the in-place [`HiddenMarkovModels.forward!`](@ref) function. =# # ## Gradient methods @@ -122,5 +221,30 @@ Still, first order optimization can be relevant when we lack explicit formulas f # ## Tests #src -@test grad_f ≈ grad_z #src -@test grad_e ≈ grad_f #src +@testset "Gradient correctness" begin #src + @testset "ForwardDiff" begin #src + @test all(!iszero, ∇parameters_forwarddiff) #src + @test all(!iszero, ∇obs_forwarddiff) #src + @test all(!iszero, ∇control_forwarddiff) #src + @test all(isfinite, ∇parameters_forwarddiff) #src + @test all(isfinite, ∇obs_forwarddiff) #src + @test all(isfinite, ∇control_forwarddiff) #src + end #src + @testset "Zygote" begin #src + @test ∇parameters_zygote ≈ ∇parameters_forwarddiff #src + @test ∇obs_zygote ≈ ∇obs_forwarddiff #src + @test ∇control_zygote ≈ ∇control_forwarddiff #src + end #src + @testset "Enzyme" begin #src + @test ∇parameters_enzyme ≈ ∇parameters_forwarddiff #src + @test ∇obs_enzyme ≈ ∇obs_forwarddiff #src + @test ∇control_enzyme ≈ ∇control_forwarddiff #src + end #src +end #src + +control_seqs = [rand(rng, rand(rng, 100:200)) for k in 1:100] #src +control_seq = reduce(vcat, control_seqs) #src +seq_ends = cumsum(length.(control_seqs)) #src + +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends) #src diff --git a/examples/types.jl b/examples/types.jl index dc4c074a..7160dd27 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -97,7 +97,7 @@ trans = sparse([ #- init = [0.2, 0.6, 0.2] -dists = [Normal(-2.0), Normal(0.0), Normal(+2.0)] +dists = [Normal(1.0), Normal(2.0), Normal(3.0)] hmm = HMM(init, trans, dists) #= @@ -129,7 +129,7 @@ trans_guess = sparse([ 0 0.6 0.4 0.4 0 0.6 ]) -dists_guess = [Normal(-1.5), Normal(0.0), Normal(+1.5)] +dists_guess = [Normal(1.2), Normal(2.2), Normal(3.2)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #- diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index c8983eea..2b4ce8d6 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -33,11 +33,7 @@ function test_equal_hmms( dists2 = obs_distributions(hmm2, control) for (dist1, dist2) in zip(dists1, dists2) for field in fieldnames(typeof(dist1)) - if startswith(string(field), "log") || - contains("σ", string(field)) || - contains("Σ", string(field)) - continue - end + string(field) in ("μ", "p") || continue x1 = getfield(dist1, field) x2 = getfield(dist2, field) if flip diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 3bec0ac2..1bc25665 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -73,7 +73,7 @@ function baum_welch( seq_ends, atol, max_iterations, - loglikelihood_increasing, + loglikelihood_increasing=false, ) return hmm, logL_evolution end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index d88b9a8b..c919e145 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -21,7 +21,7 @@ Any `AbstractHMM` which satisfies the interface can be given to the following fu - [`forward`](@ref) - [`viterbi`](@ref) - [`forward_backward`](@ref) -- [`baum_welch`](@ref) (if `fit!` is implemented) +- [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented) """ abstract type AbstractHMM end @@ -80,12 +80,14 @@ These distribution objects should implement """ obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) -function obs_logdensities!(logb::AbstractVector, hmm::AbstractHMM, obs, control) +function obs_logdensities!( + logb::AbstractVector{T}, hmm::AbstractHMM, obs, control +) where {T} dists = obs_distributions(hmm, control) @inbounds @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end - @argcheck all(<(typemax(eltype(logb))), logb) + @argcheck maximum(logb) < typemax(T) return nothing end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 4ebf8e65..955c8ace 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -29,7 +29,7 @@ end function Base.show(io::IO, hmm::HMM) return print( io, - "Hidden Markov Model with:\n - initialization: $(hmm.init)\n - transition matrix: $(hmm.trans)\n - observation distributions: $(hmm.dists)", + "Hidden Markov Model with:\n - initialization: $(hmm.init)\n - transition matrix: $(hmm.trans)\n - observation distributions: [$(join(hmm.dists, ", "))]", ) end diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index f0786717..17627bea 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -48,10 +48,11 @@ function DensityInterface.logdensityof(dist::LightCategorical, k::Integer) end function StatsAPI.fit!(dist::LightCategorical{T1}, x, w) where {T1} + @argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p) w_tot = sum(w) dist.p .= zero(T1) - for (xᵢ, wᵢ) in zip(x, w) - dist.p[xᵢ] += wᵢ + @inbounds @simd for i in eachindex(x, w) + dist.p[x[i]] += w[i] end dist.p ./= w_tot dist.logp .= log.(dist.p) diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 1d0b05a3..17d114cb 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -54,14 +54,14 @@ function StatsAPI.fit!(dist::LightDiagNormal{T1,T2}, x, w) where {T1,T2} w_tot = sum(w) dist.μ .= zero(T1) dist.σ .= zero(T2) - for (xᵢ, wᵢ) in zip(x, w) - dist.μ .+= xᵢ .* wᵢ - dist.σ .+= abs2.(xᵢ) .* wᵢ + @inbounds @simd for i in eachindex(x, w) + dist.μ .+= x[i] .* w[i] end dist.μ ./= w_tot - dist.σ ./= w_tot - dist.σ .-= min.(abs2.(dist.μ), dist.σ) - dist.σ .= sqrt.(dist.σ) + @inbounds @simd for i in eachindex(x, w) + dist.σ .+= abs2.(x[i] .- dist.μ) .* w[i] + end + dist.σ .= sqrt.(dist.σ ./ w_tot) dist.logσ .= log.(dist.σ) @argcheck all(isfinite, dist.μ) @argcheck all(isfinite, dist.σ) diff --git a/test/correctness.jl b/test/correctness.jl index c7dc84a8..291f6e23 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -13,7 +13,7 @@ rng = StableRNG(63) ## Settings -T, K = 100, 200 +T, K = 50, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] @@ -25,7 +25,7 @@ p = [[0.8, 0.2], [0.2, 0.8]] p_guess = [[0.7, 0.3], [0.3, 0.7]] μ = [-ones(2), +ones(2)] -μ_guess = [-0.7 * ones(2), +0.7 * ones(2)] +μ_guess = [-0.8 * ones(2), +0.8 * ones(2)] σ = ones(2) @@ -85,34 +85,3 @@ end test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end - -# Controlled - -struct DiffusionHMM{R1,R2,R3} <: AbstractHMM - init::Vector{R1} - trans::Matrix{R2} - means::Vector{R3} -end - -HMMs.initialization(hmm::DiffusionHMM) = hmm.init - -function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number) - N = length(hmm) - return (1 - λ) * hmm.trans + λ * ones(N, N) / N -end - -function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number) - return [Normal((1 - λ) * hmm.means[i]) for i in 1:length(hmm)] -end - -@testset "Controlled" begin - means = randn(rng, 2) - hmm = DiffusionHMM(init, trans, means) - - control_seqs = [[rand(rng) for t in 1:rand(T:(2T))] for k in 1:K] - control_seq = reduce(vcat, control_seqs) - seq_ends = cumsum(length.(control_seqs)) - - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends) -end From d6016150936f5c003148ab3486f56649ff70d684 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 15:44:53 +0100 Subject: [PATCH 4/5] Typos --- docs/make.jl | 13 +++++++------ docs/src/alternatives.md | 36 +++++++++++++++++------------------- examples/autodiff.jl | 2 +- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 8093e1d7..37f1c8bc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -42,11 +42,7 @@ function literate_title(path) end pages = [ - "First steps" => [ - "Home" => "index.md", - "Alternatives" => "alternatives.md", - "API reference" => "api.md", - ], + "Home" => "index.md", "Tutorials" => [ "Basics" => joinpath("examples", "basics.md"), "Types" => joinpath("examples", "types.md"), @@ -55,7 +51,12 @@ pages = [ "Control dependency" => joinpath("examples", "controlled.md"), "Autodiff" => joinpath("examples", "autodiff.md"), ], - "Advanced" => ["Debugging" => "debugging.md", "Formulas" => "formulas.md"], + "API reference" => "api.md", + "Advanced" => [ + "Alternatives" => "alternatives.md", + "Debugging" => "debugging.md", + "Formulas" => "formulas.md", + ], ] fmt = Documenter.HTML(; diff --git a/docs/src/alternatives.md b/docs/src/alternatives.md index 21aa3792..0369f5ed 100644 --- a/docs/src/alternatives.md +++ b/docs/src/alternatives.md @@ -4,26 +4,25 @@ We compare features among the following Julia packages: -* HiddenMarkovModels.jl (abbreviated to HMMs.jl) +* HiddenMarkovModels.jl * [HMMBase.jl](https://github.com/maxmouchet/HMMBase.jl) * [HMMGradients.jl](https://github.com/idiap/HMMGradients.jl) We discard [MarkovModels.jl](https://github.com/FAST-ASR/MarkovModels.jl) because its focus is GPU computation. There are also more generic packages for probabilistic programming, which are able to perform MCMC or variational inference (eg. [Turing.jl](https://github.com/TuringLang/Turing.jl)) but we leave those aside. -| | HMMs.jl | HMMBase.jl | HMMGradients.jl | -| ------------------------- | ------------------- | ---------------- | --------------- | -| Algorithms[^1] | V, FB, BW | V, FB, BW | FB | -| Number types | anything | `Float64` | `AbstractFloat` | -| Observation types | anything | number or vector | anything | -| Observation distributions | DensityInterface.jl | Distributions.jl | manual | -| Multiple sequences | yes | no | yes | -| Priors / structures | possible | no | possible | -| Temporal dependency | yes | no | no | -| Control dependency | yes | no | no | -| Automatic differentiation | yes | no | yes | -| Linear algebra speedup | yes | yes | no | -| Numerical stability | scaling+ | scaling+ | log | +| | HiddenMarkovModels.jl | HMMBase.jl | HMMGradients.jl | +| ------------------------- | --------------------- | ---------------- | --------------- | +| Algorithms[^1] | V, FB, BW | V, FB, BW | FB | +| Number types | anything | `Float64` | `AbstractFloat` | +| Observation types | anything | number or vector | anything | +| Observation distributions | DensityInterface.jl | Distributions.jl | manual | +| Multiple sequences | yes | no | yes | +| Priors / structures | possible | no | possible | +| Control dependency | yes | no | no | +| Automatic differentiation | yes | no | yes | +| Linear algebra speedup | yes | yes | no | +| Numerical stability | scaling+ | scaling+ | log | !!! info "Very small probabilities" @@ -43,17 +42,16 @@ We compare features among the following Python packages: | | hmmlearn | pomegranate | dynamax | | ------------------------- | -------------------- | --------------------- | -------------------- | -| Algorithms[^1] | V, FB, BW, VI | V, FB, BW | FB, V, BW, GD | -| Number types | NumPy format | PyTorch format | JAX format | +| Algorithms[^1] | V, FB, BW, VI | FB, BW | FB, V, BW, GD | +| Number types | NumPy formats | PyTorch formats | JAX formats | | Observation types | number or vector | number or vector | number or vector | | Observation distributions | discrete or Gaussian | pomegranate catalogue | discrete or Gaussian | | Multiple sequences | yes | yes | yes | -| Priors / structures | yes | no | ? | -| Temporal dependency | no | no | no | +| Priors / structures | yes | no | yes | | Control dependency | no | no | no | | Automatic differentiation | no | yes | yes | | Linear algebra speedup | yes | yes | yes | -| Logarithmic probabilities | scaling / log | log | log | +| Numerical stability | scaling / log | log | log | [^1]: V = Viterbi, FB = Forward-Backward, BW = Baum-Welch, VI = Variational Inference, GD = Gradient Descent diff --git a/examples/autodiff.jl b/examples/autodiff.jl index afff758c..a1d96d10 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -168,7 +168,7 @@ Enzyme.jl requires preallocated storage for the gradients, which we happily prov ∇parameters_enzyme = Enzyme.make_zero(parameters) ∇obs_enzyme = Enzyme.make_zero(obs_seq) -∇control_enzyme = Enzyme.make_zero(control_seq) +∇control_enzyme = Enzyme.make_zero(control_seq); #= The syntax is a bit more complex, see the Enzyme.jl docs for details. From e92f1a3c26ac9f1b46a045fb3c7d1465e51eb0cc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 16:02:05 +0100 Subject: [PATCH 5/5] Fix table --- docs/src/alternatives.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/alternatives.md b/docs/src/alternatives.md index 0369f5ed..80882346 100644 --- a/docs/src/alternatives.md +++ b/docs/src/alternatives.md @@ -45,10 +45,10 @@ We compare features among the following Python packages: | Algorithms[^1] | V, FB, BW, VI | FB, BW | FB, V, BW, GD | | Number types | NumPy formats | PyTorch formats | JAX formats | | Observation types | number or vector | number or vector | number or vector | -| Observation distributions | discrete or Gaussian | pomegranate catalogue | discrete or Gaussian | +| Observation distributions | hmmlearn catalogue | pomegranate catalogue | dynamax catalogue | | Multiple sequences | yes | yes | yes | | Priors / structures | yes | no | yes | -| Control dependency | no | no | no | +| Control dependency | no | no | yes | | Automatic differentiation | no | yes | yes | | Linear algebra speedup | yes | yes | yes | | Numerical stability | scaling / log | log | log |