diff --git a/Project.toml b/Project.toml
index 29a61cce..690b4dc2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,18 +1,19 @@
 name = "HiddenMarkovModels"
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 authors = ["Guillaume Dalle"]
-version = "0.3.1"
+version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 
 [weakdeps]
@@ -27,6 +28,7 @@ HiddenMarkovModelsHMMBaseExt = "HMMBase"
 ChainRulesCore = "1.16"
 DensityInterface = "0.4"
 Distributions = "0.25"
+DocStringExtensions = "0.9"
 LinearAlgebra = "1.6"
 PrecompileTools = "1.1"
 Random = "1.6"
diff --git a/benchmark/utils/hmms.jl b/benchmark/utils/hmms.jl
index 027505c1..972e1981 100644
--- a/benchmark/utils/hmms.jl
+++ b/benchmark/utils/hmms.jl
@@ -1,5 +1,5 @@
 using BenchmarkTools
-using HiddenMarkovModels: HMMs
+using HiddenMarkovModels
 using SimpleUnPack
 
 function rand_params_hmms(; N, D)
@@ -15,9 +15,9 @@ function rand_model_hmms(; N, D)
     if D == 1
         dists = [Normal(μ[n, 1], σ[n, 1]) for n in 1:N]
     else
-        dists = [HMMs.LightDiagNormal(μ[n, :], σ[n, :]) for n in 1:N]
+        dists = [HiddenMarkovModels.LightDiagNormal(μ[n, :], σ[n, :]) for n in 1:N]
     end
-    model = HMMs.HMM(p, A, dists)
+    model = HiddenMarkovModels.HMM(p, A, dists)
     return model
 end
 
@@ -30,22 +30,22 @@ function benchmarkables_hmms(; algos, N, D, T, K, I)
     end
     benchs = Dict()
     if "logdensity" in algos
-        benchs["logdensity"] = @benchmarkable HMMs.logdensityof(model, $obs_seqs, $K) setup = (
-            model = rand_model_hmms(; N=$N, D=$D)
-        )
+        benchs["logdensity"] = @benchmarkable HiddenMarkovModels.logdensityof(
+            model, $obs_seqs, $K
+        ) setup = (model = rand_model_hmms(; N=$N, D=$D))
     end
     if "viterbi" in algos
-        benchs["viterbi"] = @benchmarkable HMMs.viterbi(model, $obs_seqs, $K) setup = (
+        benchs["viterbi"] = @benchmarkable HiddenMarkovModels.viterbi(model, $obs_seqs, $K) setup = (
             model = rand_model_hmms(; N=$N, D=$D)
        )
     end
    if "forward_backward" in algos
-        benchs["forward_backward"] = @benchmarkable HMMs.forward_backward(
+        benchs["forward_backward"] = @benchmarkable HiddenMarkovModels.forward_backward(
             model, $obs_seqs, $K
         ) setup = (model = rand_model_hmms(; N=$N, D=$D))
     end
     if "baum_welch" in algos
-        benchs["baum_welch"] = @benchmarkable HMMs.baum_welch(
+        benchs["baum_welch"] = @benchmarkable HiddenMarkovModels.baum_welch(
             model, $obs_seqs, $K; max_iterations=$I, atol=-Inf
         ) setup = (model = rand_model_hmms(; N=$N, D=$D))
     end
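Editor's note (not part of the patch): since the `HMMs` alias is gone, the benchmark now spells out the package name. A minimal sketch of the multi-sequence calls it times, assuming the v0.4 API introduced in this diff and Distributions.jl:

```julia
# Sketch only: mirrors what benchmarkables_hmms times, outside of BenchmarkTools.
using Distributions
using HiddenMarkovModels

N, T, K = 3, 100, 2
p = ones(N) / N
A = rand_trans_mat(N)
dists = [Normal(n, 1.0) for n in 1:N]
model = HiddenMarkovModels.HMM(p, A, dists)

obs_seqs = [rand(model, T).obs_seq for _ in 1:K]
HiddenMarkovModels.logdensityof(model, obs_seqs, K)  # total loglikelihood over K sequences
```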
benchmarks_done
+        ["Features" => "features.md", "Benchmarks" => "benchmarks.md"]
+    else
+        ["Features" => "features.md"]
+    end,
     "Advanced" => ["Formulas" => "formulas.md", "Roadmap" => "roadmap.md"],
 ]
 
@@ -54,7 +57,7 @@ makedocs(;
     format=fmt,
     pages=pages,
     plugins=[bib],
-    warnonly=!benchmarks_done,
+    pagesonly=true,
 )
 
 deploydocs(; repo="github.com/gdalle/HiddenMarkovModels.jl", devbranch="main")
diff --git a/docs/src/api.md b/docs/src/api.md
index 3f58652e..c027e18c 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -6,22 +6,10 @@ HiddenMarkovModels
 
 ## Types
 
-### Markov chains
-
-```@docs
-AbstractMarkovChain
-MarkovChain
-AbstractMC
-MC
-```
-
-### Hidden Markov Models
-
 ```@docs
 AbstractHiddenMarkovModel
 HiddenMarkovModel
 AbstractHMM
-HiddenMarkovModels.PermutedHMM
 HMM
 ```
 
@@ -30,9 +18,10 @@ HMM
 ```@docs
 rand
 length
-initial_distribution
+eltype
+initialization
 transition_matrix
-obs_distribution
+obs_distributions
 ```
 
 ## Inference
 
@@ -42,22 +31,28 @@ logdensityof
 forward
 viterbi
 forward_backward
+baum_welch
+fit!
 ```
 
-## Learning
+## Misc
 
 ```@docs
-fit!
-fit
-baum_welch
+check_hmm
+rand_prob_vec
+rand_trans_mat
 ```
 
 ## Internals
 
 ```@docs
-HMMs.ForwardBackwardStorage
-HMMs.fit_element_from_sequence!
-HMMs.LightDiagNormal
+HiddenMarkovModels.ForwardStorage
+HiddenMarkovModels.ViterbiStorage
+HiddenMarkovModels.ForwardBackwardStorage
+HiddenMarkovModels.BaumWelchStorage
+HiddenMarkovModels.fit_element_from_sequence!
+HiddenMarkovModels.LightDiagNormal
+HiddenMarkovModels.PermutedHMM
 ```
 
 ## Notations
 
@@ -71,12 +66,13 @@ HMMs.LightDiagNormal
 
 ### Models and simulations
 
-- `p` or `init`: initial_distribution (vector of state probabilities)
+- `p` or `init`: initialization (vector of state probabilities)
 - `A` or `trans`: transition_matrix (matrix of transition probabilities)
-- `dists`: observation distribution (vector of `rand`-able and `logdensityof`-able objects)
+- `d` or `dists`: observation distributions (vector of `rand`-able and `logdensityof`-able objects)
 - `state_seq`: a sequence of states (vector of integers)
 - `obs_seq`: a sequence of observations (vector of individual observations)
 - `obs_seqs`: several sequences of observations
+- `nb_seqs`: number of observation sequences
 
 ### Forward backward
 
@@ -84,9 +80,9 @@ HMMs.LightDiagNormal
 - `(log)B`: matrix of observation (log)likelihoods by state for a sequence of observations
 - `α`: scaled forward variables
 - `β`: scaled backward variables
-- `γ`: one-state marginals
-- `ξ`: two-state marginals
+- `γ`: state marginals
+- `ξ`: transition marginals
 - `logL`: loglikelihood of a sequence of observations
 
 ## Index
diff --git a/docs/src/alt_performance.md b/docs/src/benchmarks.md
similarity index 92%
rename from docs/src/alt_performance.md
rename to docs/src/benchmarks.md
index 88165e1b..bcf618e3 100644
--- a/docs/src/alt_performance.md
+++ b/docs/src/benchmarks.md
@@ -15,10 +15,6 @@ The test case is an HMM with diagonal multivariate normal observations.
 - `K`: number of trajectories
 - `I`: number of Baum-Welch iterations
 
-!!! danger "Missing benchmarks?"
-    The benchmarks are computationally expensive and we only run them once for each new release.
-    If you don't see any plots below and the links are broken, you are probably on the dev documentation: go to the [stable documentation](https://gdalle.github.io/HiddenMarkovModels.jl/stable/) instead.
-
 ## Reproducibility
 
 These benchmarks were generated in the following environment: [`setup.txt`](./assets/benchmark/results/setup.txt).
diff --git a/docs/src/tuto_builtin.md b/docs/src/builtin.md
similarity index 81%
rename from docs/src/tuto_builtin.md
rename to docs/src/builtin.md
index d8801142..afa429f1 100644
--- a/docs/src/tuto_builtin.md
+++ b/docs/src/builtin.md
@@ -15,8 +15,8 @@ Creating a model:
 function gaussian_hmm(N; noise=0)
     p = ones(N) / N  # initial distribution
     A = rand_trans_mat(N)  # transition matrix
-    dists = [Normal(i + noise * randn(), 0.5) for i in 1:N]  # observation distributions
-    return HMM(p, A, dists)
+    d = [Normal(i + noise * randn(), 0.5) for i in 1:N]  # observation distributions
+    return HMM(p, A, d)
 end
 ```
 
@@ -29,7 +29,7 @@ transition_matrix(hmm)
 ```
 
 ```@example tuto
-[obs_distribution(hmm, i) for i in 1:N]
+obs_distributions(hmm)
 ```
 
 Simulating a sequence:
@@ -70,15 +70,15 @@ first(logL_evolution), last(logL_evolution)
 Correcting state order because we know observation means are increasing in the true model:
 
 ```@example tuto
-[obs_distribution(hmm_est, i) for i in 1:N]
+d_est = obs_distributions(hmm_est)
 ```
 
 ```@example tuto
-perm = sortperm(1:3, by=i->obs_distribution(hmm_est, i).μ)
+perm = sortperm(1:3, by=i->d_est[i].μ)
 ```
 
 ```@example tuto
-hmm_est = PermutedHMM(hmm_est, perm)
+hmm_est = HiddenMarkovModels.PermutedHMM(hmm_est, perm)
 ```
 
 Evaluating errors:
diff --git a/docs/src/custom.md b/docs/src/custom.md
new file mode 100644
index 00000000..c21154f1
--- /dev/null
+++ b/docs/src/custom.md
@@ -0,0 +1,13 @@
+# Tutorial - custom HMM
+
+```@example tuto
+using Distributions
+using HiddenMarkovModels
+
+using Random; Random.seed!(63)
+```
+
+Here we demonstrate how to build your own HMM structure satisfying the interface.
+
+!!! danger
+    Work in progress.
\ No newline at end of file
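Editor's note (not part of the patch): since the new `custom.md` page is still a stub, here is a hedged sketch of what the reworked interface asks of a custom type, using the method names this PR introduces (`initialization`, `transition_matrix`, `obs_distributions`). The type `MyHMM` is hypothetical:

```julia
# Minimal sketch, assuming HiddenMarkovModels v0.4 as defined in this diff.
using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs

struct MyHMM <: AbstractHMM
    init::Vector{Float64}
    trans::Matrix{Float64}
    dists::Vector{Normal{Float64}}
end

Base.length(hmm::MyHMM) = length(hmm.init)
HMMs.initialization(hmm::MyHMM) = hmm.init
HMMs.transition_matrix(hmm::MyHMM) = hmm.trans
HMMs.obs_distributions(hmm::MyHMM) = hmm.dists

hmm = MyHMM([0.5, 0.5], [0.9 0.1; 0.2 0.8], [Normal(-1.0), Normal(1.0)])
state_seq, obs_seq = rand(hmm, 50)
viterbi(hmm, obs_seq)  # inference works without implementing the optional fit!
```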
diff --git a/docs/src/debugging.md b/docs/src/debugging.md
index 88a3e66d..0f759552 100644
--- a/docs/src/debugging.md
+++ b/docs/src/debugging.md
@@ -9,4 +9,4 @@ This can happen for a variety of reasons, so here are a few leads worth investig
 * Reduce the number of states (to make every one of them useful)
 * Add a prior to your transition matrix / observation distributions (to avoid degenerate behavior like zero variance in a Gaussian)
 * Pick a better initialization (to start closer to the supposed ground truth)
-* Use [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl) in strategic places (to guarantee numerical stability). Note that these numbers don't play nicely with Distributions.jl, so you may have to roll out your own observation distribution.
\ No newline at end of file
+* Use [LogarithmicNumbers.jl](https://github.com/cjdoris/LogarithmicNumbers.jl) in strategic places (to guarantee numerical stability). Note that these numbers don't play nicely with Distributions.jl, so you may have to roll your own observation distribution.
diff --git a/docs/src/alt_features.md b/docs/src/features.md
similarity index 100%
rename from docs/src/alt_features.md
rename to docs/src/features.md
diff --git a/docs/src/tuto_custom.md b/docs/src/tuto_custom.md
deleted file mode 100644
index 7a6edc95..00000000
--- a/docs/src/tuto_custom.md
+++ /dev/null
@@ -1,135 +0,0 @@
-# Tutorial - custom HMM
-
-The built-in HMM is perfect when the initial state distribution, transition matrix and emission distributions can be updated independently with Maximum Likelihood Estimation.
-But some of these parameters might be correlated, or fixed.
-Or they might come with a prior, which forces you to use Maximum A Posteriori instead.
-
-In such cases, it is necessary to implement a new subtype of [`AbstractHMM`](@ref) with all its required methods.
-
-```@example tuto
-using Distributions
-using HiddenMarkovModels
-using LinearAlgebra
-using RequiredInterfaces
-using StatsAPI
-
-using Random; Random.seed!(63)
-```
-
-## Interface
-
-To ascertain that a type indeed satisfies the interface, you can use [RequiredInterfaces.jl](https://github.com/Seelengrab/RequiredInterfaces.jl).
-
-```@example tuto
-RequiredInterfaces.check_interface_implemented(AbstractHMM, HMM)
-```
-
-If your implementation is insufficient, the test will list missing methods.
-
-```@example tuto
-struct EmptyHMM end
-RequiredInterfaces.check_interface_implemented(AbstractHMM, EmptyHMM)
-```
-
-Note that this test does not check the `StatsAPI.fit!` method.
-Since it is only used in the Baum-Welch algorithm, it is an optional part of the `AbstractHMM` interface.
-
-## Example
-
-We show how to implement an HMM whose initial distribution is always the equilibrium distribution of the underlying Markov chain.
-The code that follows is not efficient (it leads to a lot of allocations), but it would be fairly easy to optimize if needed.
-
-The equilibrium distribution of a Markov chain is the (only) left eigenvector associated with the left eigenvalue $1$.
-
-```@example tuto
-function markov_equilibrium(A)
-    p = real.(eigvecs(A')[:, end])
-    return p ./ sum(p)
-end
-```
-
-We now define our custom HMM by taking inspiration from `src/types/hmm.jl` and making a few modifications:
-
-```@example tuto
-struct EquilibriumHMM{R,D} <: AbstractHMM
-    trans::Matrix{R}
-    dists::Vector{D}
-end
-```
-
-The interface is only different as far as the initialization is concerned.
-
-```@example tuto
-Base.length(hmm::EquilibriumHMM) = length(hmm.dists)
-HMMs.initial_distribution(hmm::EquilibriumHMM) = markov_equilibrium(hmm.trans)  # this is new
-HMMs.transition_matrix(hmm::EquilibriumHMM) = hmm.trans
-HMMs.obs_distribution(hmm::EquilibriumHMM, i::Integer) = hmm.dists[i]
-```
-
-As for fitting, we simply ignore the initialization count and copy the rest of the original code (with a few simplifications):
-
-```@example tuto
-function StatsAPI.fit!(
-    hmm::EquilibriumHMM{R,D}, init_count, trans_count, obs_seq, state_marginals
-) where {R,D}
-    hmm.trans .= trans_count ./ sum(trans_count, dims=2)
-    for i in 1:N
-        hmm.dists[i] = fit(D, obs_seq, state_marginals[i, :])
-    end
-end
-```
-
-Let's take our new model for a spin:
-
-```@example tuto
-function gaussian_equilibrium_hmm(N; noise=0)
-    A = rand_trans_mat(N)
-    dists = [Normal(i + noise * randn(), 0.5) for i in 1:N]
-    return EquilibriumHMM(A, dists)
-end;
-```
-
-```@example tuto
-N = 3
-hmm = gaussian_equilibrium_hmm(N);
-transition_matrix(hmm)
-```
-
-```@example tuto
-[obs_distribution(hmm, i) for i in 1:N]
-```
-
-We can estimate parameters based on several observation sequences.
-Note that as soon as we tamper with the re-estimation procedure, the loglikelihood is no longer guaranteed to increase during Baum-Welch, which is why we turn off the corresponding check.
-
-```@example tuto
-T = 1000
-nb_seqs = 10
-obs_seqs = [rand(hmm, T).obs_seq for _ in 1:nb_seqs]
-
-hmm_init = gaussian_equilibrium_hmm(N; noise=1)
-hmm_est, logL_evolution = baum_welch(
-    hmm_init, obs_seqs, nb_seqs; check_loglikelihood_increasing=false
-);
-first(logL_evolution), last(logL_evolution)
-```
-
-Let's correct the state order:
-
-```@example tuto
-[obs_distribution(hmm_est, i) for i in 1:N]
-```
-
-```@example tuto
-perm = sortperm(1:3, by=i->obs_distribution(hmm_est, i).μ)
-```
-
-```@example tuto
-hmm_est = PermutedHMM(hmm_est, perm)
-```
-
-And finally evaluate the errors:
-
-```@example tuto
-cat(transition_matrix(hmm_est), transition_matrix(hmm), dims=3)
-```
\ No newline at end of file
diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl
index c5a8c451..a0bfd1b1 100644
--- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl
+++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl
@@ -4,12 +4,14 @@ using ChainRulesCore:
     ChainRulesCore, NoTangent, ZeroTangent, RuleConfig, rrule_via_ad, @not_implemented
 using DensityInterface: logdensityof
 using HiddenMarkovModels
+import HiddenMarkovModels as HMMs
 using SimpleUnPack
 
 function _params_and_loglikelihoods(hmm::AbstractHMM, obs_seq)
-    p = initial_distribution(hmm)
+    p = initialization(hmm)
     A = transition_matrix(hmm)
-    logB = HiddenMarkovModels.loglikelihoods(hmm, obs_seq)
+    d = obs_distributions(hmm)
+    logB = reduce(hcat, logdensityof.(d, (obs,)) for obs in obs_seq)
     return p, A, logB
 end
 
@@ -17,16 +19,16 @@ function ChainRulesCore.rrule(
     rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq
 )
     (p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq)
-    fb = forward_backward(p, A, logB)
-    logL = HiddenMarkovModels.loglikelihood(fb)
-    @unpack α, β, γ, c, Bscaled, Bβscaled = fb
+    fb = HMMs.initialize_forward_backward(hmm, obs_seq)
+    HMMs.forward_backward!(fb, hmm, obs_seq)
+    @unpack α, β, γ, c, B̃β = fb
     T = length(obs_seq)
 
     function logdensityof_hmm_pullback(ΔlogL)
-        Δp = ΔlogL .* Bβscaled[:, 1]
-        ΔA = ΔlogL .* α[:, 1] .* Bβscaled[:, 2]'
+        Δp = ΔlogL .* B̃β[:, 1]
+        ΔA = ΔlogL .* α[:, 1] .* B̃β[:, 2]'
         @views for t in 2:(T - 1)
-            ΔA .+= ΔlogL .* α[:, t] .* Bβscaled[:, t + 1]'
+            ΔA .+= ΔlogL .* α[:, t] .* B̃β[:, t + 1]'
         end
         ΔlogB = ΔlogL .* γ
 
@@ -35,7 +37,7 @@ function ChainRulesCore.rrule(
         return Δlogdensityof, Δhmm, Δobs_seq
     end
 
-    return logL, logdensityof_hmm_pullback
+    return fb.logL[], logdensityof_hmm_pullback
 end
 
 end
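Editor's note (not part of the patch): the rrule above is what lets a ChainRules-aware AD engine differentiate `logdensityof` with respect to the HMM parameters. A hedged sketch with Zygote, assuming gradients flow end-to-end through the backend (this is an assumption, not something the diff guarantees):

```julia
# Sketch only: gradient of the sequence loglikelihood with respect to the model.
using Distributions
using HiddenMarkovModels
using Zygote

hmm = HMM([0.4, 0.6], [0.7 0.3; 0.2 0.8], [Normal(-1.0), Normal(1.0)])
obs_seq = rand(hmm, 20).obs_seq
grad = Zygote.gradient(h -> logdensityof(h, obs_seq), hmm)
```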
diff --git a/ext/HiddenMarkovModelsHMMBaseExt.jl b/ext/HiddenMarkovModelsHMMBaseExt.jl
index 397ba3bc..9488107f 100644
--- a/ext/HiddenMarkovModelsHMMBaseExt.jl
+++ b/ext/HiddenMarkovModelsHMMBaseExt.jl
@@ -11,9 +11,9 @@ function HiddenMarkovModels.HMM(hmm_base::HMMBase.HMM)
 end
 
 function HMMBase.HMM(hmm::HiddenMarkovModels.HMM)
-    a = deepcopy(initial_distribution(hmm))
+    a = deepcopy(initialization(hmm))
     A = deepcopy(transition_matrix(hmm))
-    B = deepcopy(obs_distribution.(Ref(hmm), 1:length(hmm)))
+    B = deepcopy(obs_distributions(hmm))
     return HMMBase.HMM(a, A, B)
 end
diff --git a/src/HMMTest.jl b/src/HMMTest.jl
new file mode 100644
index 00000000..d218f1a5
--- /dev/null
+++ b/src/HMMTest.jl
@@ -0,0 +1,49 @@
+module HMMTest
+
+using Distributions
+using Distributions: PDiagMat
+using HiddenMarkovModels
+using HiddenMarkovModels: LightDiagNormal, sum_to_one!
+using LinearAlgebra
+using SparseArrays
+
+export rand_categorical_hmm
+export rand_gaussian_hmm_1d
+export rand_gaussian_hmm_2d
+export rand_gaussian_hmm_2d_light
+
+function sparse_trans_mat(N)
+    A = sparse(SymTridiagonal(ones(N), ones(N - 1)))
+    foreach(sum_to_one!, eachrow(A))
+    return A
+end
+
+function rand_categorical_hmm(N, D; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [Categorical(rand_prob_vec(D)) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+function rand_gaussian_hmm_1d(N; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [Normal(randn(), 1) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+function rand_gaussian_hmm_2d(N, D; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+function rand_gaussian_hmm_2d_light(N, D; sparse_trans=false)
+    p = ones(N) / N
+    A = sparse_trans ? sparse_trans_mat(N) : rand_trans_mat(N)
+    d = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
+    return HMM(p, A, d)
+end
+
+end
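Editor's note (not part of the patch): a hedged usage sketch for the new `HMMTest` generators, assuming the submodule stays loaded with the package as the `include` below suggests:

```julia
# Sketch only: HMMTest is an internal submodule, reached through the parent package.
using HiddenMarkovModels
using HiddenMarkovModels.HMMTest

hmm_dense = rand_gaussian_hmm_1d(4)
hmm_sparse = rand_gaussian_hmm_1d(4; sparse_trans=true)  # tridiagonal sparse transitions
transition_matrix(hmm_sparse)
```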
diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl
index c40acf3e..1fd68ac7 100644
--- a/src/HiddenMarkovModels.jl
+++ b/src/HiddenMarkovModels.jl
@@ -3,12 +3,13 @@
 
 A Julia package for HMM modeling, simulation, inference and learning.
 
-The alias `HMMs` is exported for the package name.
+# Exports
+
+$(EXPORTS)
 """
 module HiddenMarkovModels
 
-const HMMs = HiddenMarkovModels
-
+using Base: RefValue
 using Base.Threads: @threads
 using DensityInterface:
     DensityInterface, DensityKind, HasDensity, NoDensity, densityof, logdensityof
@@ -19,43 +20,42 @@ using Distributions:
     UnivariateDistribution,
     MultivariateDistribution,
     MatrixDistribution
+using DocStringExtensions
 using LinearAlgebra: Diagonal, dot, mul!
 using PrecompileTools: @compile_workload, @setup_workload
-using Random: AbstractRNG, default_rng
-using RequiredInterfaces: @required
+using Random: Random, AbstractRNG, default_rng
 using Requires: @require
 using SimpleUnPack: @unpack
+using SparseArrays: SparseMatrixCSC, nzrange, nnz
 using StatsAPI: StatsAPI, fit, fit!
 
-export HMMs
-export AbstractMarkovChain, AbstractMC
-export MarkovChain, MC
-export AbstractHiddenMarkovModel, AbstractHMM, PermutedHMM
+export AbstractHiddenMarkovModel, AbstractHMM
 export HiddenMarkovModel, HMM
 export rand_prob_vec, rand_trans_mat
-export initial_distribution, transition_matrix, obs_distribution
+export initialization, transition_matrix, obs_distributions
 export logdensityof, viterbi, forward, forward_backward, baum_welch
-export fit, fit!
-export LightDiagNormal
+export fit!
+export check_hmm
 
-include("types/abstract_mc.jl")
-include("types/mc.jl")
 include("types/abstract_hmm.jl")
-include("types/hmm.jl")
+include("types/permuted_hmm.jl")
 
 include("utils/check.jl")
 include("utils/probvec.jl")
 include("utils/transmat.jl")
 include("utils/fit.jl")
 include("utils/lightdiagnormal.jl")
+include("utils/mul.jl")
 
-include("inference/loglikelihoods.jl")
 include("inference/forward.jl")
 include("inference/viterbi.jl")
 include("inference/forward_backward.jl")
-include("inference/sufficient_stats.jl")
 include("inference/baum_welch.jl")
 
+include("types/hmm.jl")
+
+include("HMMTest.jl")
+
 if !isdefined(Base, :get_extension)
     function __init__()
         @require HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" include(
@@ -68,19 +68,18 @@ if !isdefined(Base, :get_extension)
 end
 
 @compile_workload begin
-    N, D, T = 5, 3, 100
+    N, D, T = 3, 2, 100
     p = rand_prob_vec(N)
     A = rand_trans_mat(N)
     dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
     hmm = HMM(p, A, dists)
+    obs_seq = rand(hmm, T).obs_seq
 
-    obs_seqs = [last(rand(hmm, T)) for _ in 1:3]
-    nb_seqs = 3
-    logdensityof(hmm, obs_seqs, nb_seqs)
-    forward(hmm, obs_seqs, nb_seqs)
-    viterbi(hmm, obs_seqs, nb_seqs)
-    forward_backward(hmm, obs_seqs, nb_seqs)
-    baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf)
+    logdensityof(hmm, obs_seq)
+    forward(hmm, obs_seq)
+    viterbi(hmm, obs_seq)
+    forward_backward(hmm, obs_seq)
+    baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf)
 end
 
 end
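Editor's note (not part of the patch): the new precompile workload doubles as a summary of the reworked API, where every inference function now also accepts a single observation sequence. A hedged end-to-end sketch, assuming the v0.4 API defined in this diff:

```julia
# Sketch only: single-sequence workflow mirroring the precompile workload above.
using Distributions
using HiddenMarkovModels

hmm = HMM([0.3, 0.7], [0.8 0.2; 0.1 0.9], [Normal(0.0), Normal(3.0)])
obs_seq = rand(hmm, 200).obs_seq

logdensityof(hmm, obs_seq)       # loglikelihood of the sequence
viterbi(hmm, obs_seq)            # most likely state sequence
forward(hmm, obs_seq)            # (α, logL)
forward_backward(hmm, obs_seq)   # (γ, logL)
hmm_est, logL_evolution = baum_welch(hmm, obs_seq; max_iterations=10)
```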
diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl
index 87bc8ac3..c4ef69e8 100644
--- a/src/inference/baum_welch.jl
+++ b/src/inference/baum_welch.jl
@@ -1,58 +1,119 @@
-function baum_welch!(
-    hmm::AbstractHMM, obs_seqs; atol, max_iterations, check_loglikelihood_increasing
+"""
+$(TYPEDEF)
+
+Store Baum-Welch quantities with element type `R`.
+
+This storage is relative to several sequences.
+
+# Fields
+
+These fields (except `limits`) are passed to the `fit!(hmm, ...)` method along with `obs_seqs_concat`.
+
+$(TYPEDFIELDS)
+"""
+struct BaumWelchStorage{R,M<:AbstractMatrix{R}}
+    "posterior initialization counts for each state"
+    init_count::Vector{R}
+    "posterior transition counts for each pair of states"
+    trans_count::M
+    "concatenation along time of the state marginal matrices `γ[i,t] = ℙ(X[t]=i | Y[1:T])` for all observation sequences"
+    state_marginals_concat::Matrix{R}
+    "temporal separations between observation sequences: `state_marginals_concat[:, limits[k]+1:limits[k+1]]` refers to sequence `k`"
+    limits::Vector{Int}
+end
+
+function initialize_baum_welch(
+    hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
 )
-    # Pre-allocate nearly all necessary memory
-    logB = loglikelihoods(hmm, obs_seqs[1])
-    fb = initialize_forward_backward(hmm, logB)
-
-    logBs = Vector{typeof(logB)}(undef, length(obs_seqs))
-    fbs = Vector{typeof(fb)}(undef, length(obs_seqs))
-    @threads for k in eachindex(obs_seqs)
-        logBs[k] = loglikelihoods(hmm, obs_seqs[k])
-        fbs[k] = forward_backward(hmm, logBs[k])
-    end
+    check_lengths(obs_seqs, nb_seqs)
+    N, T_concat = length(hmm), sum(length, obs_seqs)
+    A = transition_matrix(hmm)
+    R = eltype(hmm, obs_seqs[1][1])
+    init_count = Vector{R}(undef, N)
+    trans_count = similar(A, R)
+    state_marginals_concat = Matrix{R}(undef, N, T_concat)
+    limits = vcat(0, cumsum(length.(obs_seqs)))
+    return BaumWelchStorage(init_count, trans_count, state_marginals_concat, limits)
+end
 
-    init_count, trans_count = initialize_states_stats(fbs)
-    state_marginals_concat = initialize_observations_stats(fbs)
-    obs_seqs_concat = reduce(vcat, obs_seqs)
-    logL = loglikelihood(fbs)
-    logL_evolution = [logL]
-
-    for iteration in 1:max_iterations
-        # E step
-        if iteration > 1
-            @threads for k in eachindex(obs_seqs, logBs, fbs)
-                loglikelihoods!(logBs[k], hmm, obs_seqs[k])
-                forward_backward!(fbs[k], hmm, logBs[k])
-            end
-            logL = loglikelihood(fbs)
-            push!(logL_evolution, logL)
+function initialize_logL_evolution(
+    hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer; max_iterations::Integer
+)
+    check_lengths(obs_seqs, nb_seqs)
+    R = eltype(hmm, obs_seqs[1][1])
+    logL_evolution = R[]
+    sizehint!(logL_evolution, max_iterations)
+    return logL_evolution
+end
+
+function update_sufficient_statistics!(
+    bw::BaumWelchStorage{R}, fbs::Vector{<:ForwardBackwardStorage}
+) where {R}
+    @unpack init_count, trans_count, state_marginals_concat, limits = bw
+    init_count .= zero(R)
+    trans_count .= zero(R)
+    state_marginals_concat .= zero(R)
+    for k in eachindex(fbs)
+        @unpack γ, ξ, B̃β = fbs[k]
+        init_count .+= view(γ, :, 1)
+        for t in eachindex(ξ)
+            trans_count .+= ξ[t]
         end
+        state_marginals_concat[:, (limits[k] + 1):limits[k + 1]] .= γ
+    end
+    return nothing
+end
 
-        # M step
-        update_states_stats!(init_count, trans_count, fbs)
-        update_observations_stats!(state_marginals_concat, fbs)
-        fit!(hmm, init_count, trans_count, obs_seqs_concat, state_marginals_concat)
-
-        # Stopping criterion
-        if iteration > 1
-            progress = logL_evolution[end] - logL_evolution[end - 1]
-            if check_loglikelihood_increasing && progress < 0
-                error("Loglikelihood decreased in Baum-Welch")
-            elseif progress < atol
-                break
-            end
+function baum_welch_has_converged(
+    logL_evolution::Vector; atol::Real, loglikelihood_increasing::Bool
+)
+    if length(logL_evolution) >= 2
+        logL, logL_prev = logL_evolution[end], logL_evolution[end - 1]
+        progress = logL - logL_prev
+        if loglikelihood_increasing && progress < 0
+            error("Loglikelihood decreased in Baum-Welch")
+        elseif progress < atol
+            return true
         end
     end
+    return false
+end
 
-    return logL_evolution
+function StatsAPI.fit!(hmm::AbstractHMM, bw::BaumWelchStorage, obs_seqs_concat::Vector)
+    return fit!(
+        hmm, bw.init_count, bw.trans_count, obs_seqs_concat, bw.state_marginals_concat
+    )
+end
+
+function baum_welch!(
+    fbs::Vector{<:ForwardBackwardStorage},
+    bw::BaumWelchStorage,
+    logL_evolution::Vector,
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    obs_seqs_concat::Vector;
+    atol::Real,
+    max_iterations::Integer,
+    loglikelihood_increasing::Bool,
+)
+    for _ in 1:max_iterations
+        @threads for k in eachindex(obs_seqs, fbs)
+            forward_backward!(fbs[k], hmm, obs_seqs[k])
+        end
+        update_sufficient_statistics!(bw, fbs)
+        push!(logL_evolution, sum(fb.logL[] for fb in fbs))
+        fit!(hmm, bw, obs_seqs_concat)
+        check_hmm(hmm)
+        if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing)
+            break
+        end
+    end
+    return nothing
 end
 
 """
-    baum_welch(
-        hmm_init, obs_seq;
-        atol, max_iterations, check_loglikelihood_increasing
-    )
+    baum_welch(hmm_init, obs_seq; kwargs...)
+    baum_welch(hmm_init, obs_seqs, nb_seqs; kwargs...)
 
 Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`.
 
@@ -60,57 +121,46 @@ Return a tuple `(hmm_est, logL_evolution)`.
 
 # Keyword arguments
 
-- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged)
-- `max_iterations`: Maximum number of iterations of the algorithm
-- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases
+- `atol`: minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged)
+- `max_iterations`: maximum number of iterations of the algorithm
+- `loglikelihood_increasing`: whether to throw an error if the loglikelihood decreases
 """
 function baum_welch(
     hmm_init::AbstractHMM,
-    obs_seq;
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer;
     atol=1e-5,
     max_iterations=100,
-    check_loglikelihood_increasing=true,
+    loglikelihood_increasing=true,
 )
+    check_lengths(obs_seqs, nb_seqs)
     hmm = deepcopy(hmm_init)
-    logL_evolution = baum_welch!(
-        hmm, [obs_seq]; atol, max_iterations, check_loglikelihood_increasing
+    fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
+    bw = initialize_baum_welch(hmm, obs_seqs, nb_seqs)
+    logL_evolution = initialize_logL_evolution(hmm, obs_seqs, nb_seqs; max_iterations)
+    obs_seqs_concat = reduce(vcat, obs_seqs)
+    baum_welch!(
+        fbs,
+        bw,
+        logL_evolution,
+        hmm,
+        obs_seqs,
+        obs_seqs_concat;
+        atol,
+        max_iterations,
+        loglikelihood_increasing,
     )
     return hmm, logL_evolution
 end
 
-"""
-    baum_welch(
-        hmm_init, obs_seqs, nb_seqs;
-        atol, max_iterations, check_loglikelihood_increasing
-    )
-
-Apply the Baum-Welch algorithm to estimate the parameters of an HMM starting from `hmm_init`, based on `nb_seqs` observation sequences.
-
-Return a tuple `(hmm_est, logL_evolution)`.
-
-!!! warning "Multithreading"
-    This function is parallelized across sequences.
-
-# Keyword arguments
-
-- `atol`: Minimum loglikelihood increase at an iteration of the algorithm (otherwise the algorithm is deemed to have converged)
-- `max_iterations`: Maximum number of iterations of the algorithm
-- `check_loglikelihood_increasing`: Whether to throw an error if the loglikelihood decreases
-"""
 function baum_welch(
     hmm_init::AbstractHMM,
-    obs_seqs,
-    nb_seqs::Integer;
+    obs_seq::Vector;
     atol=1e-5,
     max_iterations=100,
-    check_loglikelihood_increasing=true,
+    loglikelihood_increasing=true,
 )
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
-    hmm = deepcopy(hmm_init)
-    logL_evolution = baum_welch!(
-        hmm, obs_seqs; atol, max_iterations, check_loglikelihood_increasing
+    return baum_welch(
+        hmm_init, [obs_seq], 1; atol, max_iterations, loglikelihood_increasing
    )
-    return hmm, logL_evolution
 end
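Editor's note (not part of the patch): a hedged sketch of the renamed keyword, since `check_loglikelihood_increasing` becomes `loglikelihood_increasing` in this diff:

```julia
# Sketch only: multi-sequence Baum-Welch with the new keyword names.
using Distributions
using HiddenMarkovModels

hmm_init = HMM([0.5, 0.5], [0.6 0.4; 0.3 0.7], [Normal(-1.0), Normal(1.0)])
obs_seqs = [rand(hmm_init, 100).obs_seq for _ in 1:4]
hmm_est, logL_evolution = baum_welch(
    hmm_init, obs_seqs, 4; atol=1e-5, max_iterations=50, loglikelihood_increasing=false
)
```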
diff --git a/src/inference/forward.jl b/src/inference/forward.jl
index 3815c77d..02f1a3d3 100644
--- a/src/inference/forward.jl
+++ b/src/inference/forward.jl
@@ -1,103 +1,117 @@
-function forward!(αₜ, αₜ₊₁, logb, p, A, hmm::AbstractHMM, obs_seq)
+"""
+$(TYPEDEF)
+
+Store forward quantities with element type `R`.
+
+This storage is relative to a single sequence.
+
+# Fields
+
+The only fields useful outside of the algorithm are `α` and `logL`.
+
+$(TYPEDFIELDS)
+"""
+struct ForwardStorage{R}
+    "total loglikelihood"
+    logL::RefValue{R}
+    "observation loglikelihoods `logbₜ[i] = log ℙ(Y[t] | X[t]=i)`"
+    logb::Vector{R}
+    "scaled forward messages for a given time step"
+    α::Vector{R}
+    "same as `α` but for the next time step"
+    α_next::Vector{R}
+end
+
+function initialize_forward(hmm::AbstractHMM, obs_seq::Vector)
+    N = length(hmm)
+    R = eltype(hmm, obs_seq[1])
+
+    logL = RefValue{R}(zero(R))
+    logb = Vector{R}(undef, N)
+    α = Vector{R}(undef, N)
+    α_next = Vector{R}(undef, N)
+    f = ForwardStorage(logL, logb, α, α_next)
+    return f
+end
+
+function forward!(f::ForwardStorage, hmm::AbstractHMM, obs_seq::Vector)
     T = length(obs_seq)
-    loglikelihoods_vec!(logb, hmm, obs_seq[1])
+    p = initialization(hmm)
+    A = transition_matrix(hmm)
+    d = obs_distributions(hmm)
+    @unpack logL, logb, α, α_next = f
+
+    logb .= logdensityof.(d, (obs_seq[1],))
     logm = maximum(logb)
-    αₜ .= p .* exp.(logb .- logm)
-    c = inv(sum(αₜ))
-    αₜ .*= c
-    logL = -log(c) + logm
+    α .= p .* exp.(logb .- logm)
+    c = inv(sum(α))
+    α .*= c
+    logL[] = -log(c) + logm
     for t in 1:(T - 1)
-        loglikelihoods_vec!(logb, hmm, obs_seq[t + 1])
+        logb .= logdensityof.(d, (obs_seq[t + 1],))
         logm = maximum(logb)
-        mul!(αₜ₊₁, A', αₜ)
-        αₜ₊₁ .*= exp.(logb .- logm)
-        c = inv(sum(αₜ₊₁))
-        αₜ₊₁ .*= c
-        αₜ .= αₜ₊₁
-        logL += -log(c) + logm
+        mul!(α_next, A', α)
+        α_next .*= exp.(logb .- logm)
+        c = inv(sum(α_next))
+        α_next .*= c
+        α .= α_next
+        logL[] += -log(c) + logm
+    end
+    return nothing
+end
+
+function forward!(
+    fs::Vector{<:ForwardStorage},
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer,
+)
+    check_lengths(obs_seqs, nb_seqs)
+    @threads for k in eachindex(fs, obs_seqs)
+        forward!(fs[k], hmm, obs_seqs[k])
     end
-    return logL
+    return nothing
 end
 
 """
     forward(hmm, obs_seq)
+    forward(hmm, obs_seqs, nb_seqs)
 
-Apply the forward algorithm to an HMM.
+Run the forward algorithm to infer the current state of an HMM.
 
-Return a tuple `(α, logL)` where
+When applied on a single sequence, this function returns a tuple `(α, logL)` where
 
+- `α[i]` is the posterior probability of state `i` at the end of the sequence
 - `logL` is the loglikelihood of the sequence
-- `α[i]` is the posterior probability of state `i` at the end of the sequence.
-"""
-function forward(hmm::AbstractHMM, obs_seq)
-    N = length(hmm)
-    p = initial_distribution(hmm)
-    A = transition_matrix(hmm)
-    logb = loglikelihoods_vec(hmm, obs_seq[1])
-
-    R = promote_type(eltype(p), eltype(A), eltype(logb))
-    αₜ = Vector{R}(undef, N)
-    αₜ₊₁ = Vector{R}(undef, N)
-    logL = forward!(αₜ, αₜ₊₁, logb, p, A, hmm, obs_seq)
-    return αₜ, logL
-end
 
-"""
-    forward(hmm, obs_seqs, nb_seqs)
-
-Apply the forward algorithm to an HMM, based on multiple observation sequences.
-
-Return a vector of tuples `(αₖ, logLₖ)`, where
-
-- `logLₖ` is the loglikelihood of sequence `k`
-- `αₖ[i]` is the posterior probability of state `i` at the end of sequence `k`
+When applied on multiple sequences, this function returns a vector of tuples.
+"""
+function forward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
+    check_lengths(obs_seqs, nb_seqs)
+    fs = [initialize_forward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
+    forward!(fs, hmm, obs_seqs, nb_seqs)
+    return [(f.α, f.logL[]) for f in fs]
+end
 
-!!! warning "Multithreading"
-    This function is parallelized across sequences.
-"""
-function forward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
-    f1 = forward(hmm, first(obs_seqs))
-    fs = Vector{typeof(f1)}(undef, nb_seqs)
-    fs[1] = f1
-    @threads for k in 2:nb_seqs
-        fs[k] = forward(hmm, obs_seqs[k])
-    end
-    return fs
+function forward(hmm::AbstractHMM, obs_seq::Vector)
+    return only(forward(hmm, [obs_seq], 1))
 end
 
 """
     logdensityof(hmm, obs_seq)
+    logdensityof(hmm, obs_seqs, nb_seqs)
 
-Apply the forward algorithm to compute the loglikelihood of a single observation sequence for an HMM.
+Run the forward algorithm to compute the loglikelihood of observations for an HMM.
 
-Return a number.
+Whether it is applied on one or multiple sequences, this function returns a number.
 """
-function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq)
-    return last(forward(hmm, obs_seq))
+function DensityInterface.logdensityof(
+    hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer
+)
+    logαs_and_logLs = forward(hmm, obs_seqs, nb_seqs)
+    return sum(last, logαs_and_logLs)
 end
 
-"""
-    logdensityof(hmm, obs_seqs, nb_seqs)
-
-Apply the forward algorithm to compute the total loglikelihood of multiple observation sequences for an HMM.
-
-Return a number.
-
-!!! warning "Multithreading"
-    This function is parallelized across sequences.
-"""
-function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
-    logL1 = logdensityof(hmm, first(obs_seqs))
-    logLs = Vector{typeof(logL1)}(undef, nb_seqs)
-    logLs[1] = logL1
-    @threads for k in 2:nb_seqs
-        logLs[k] = logdensityof(hmm, obs_seqs[k])
-    end
-    return sum(logLs)
+function DensityInterface.logdensityof(hmm::AbstractHMM, obs_seq::Vector)
+    return logdensityof(hmm, [obs_seq], 1)
 end
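Editor's note (not part of the patch): a hedged sketch of the two return conventions of `forward` introduced above, and of how `logdensityof` aggregates them:

```julia
# Sketch only: assumes the v0.4 API defined in this diff.
using Distributions
using HiddenMarkovModels

hmm = HMM([0.5, 0.5], [0.9 0.1; 0.1 0.9], [Normal(0.0), Normal(2.0)])
obs_seqs = [rand(hmm, 50).obs_seq for _ in 1:3]

α, logL = forward(hmm, obs_seqs[1])   # single sequence: one tuple
results = forward(hmm, obs_seqs, 3)   # several sequences: vector of tuples
sum(last, results) ≈ logdensityof(hmm, obs_seqs, 3)  # true
```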
diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl
index 165b7b54..3aefcdbf 100644
--- a/src/inference/forward_backward.jl
+++ b/src/inference/forward_backward.jl
@@ -1,173 +1,169 @@
 """
-    ForwardBackwardStorage{R}
+$(TYPEDEF)
 
 Store forward-backward quantities with element type `R`.
 
-# Fields
-
-Let `X` denote the vector of hidden states and `Y` denote the vector of observations. The following fields are part of the API:
+This storage is relative to a single sequence.
 
-- `γ::Matrix{R}`: posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`
-- `ξ::Array{R,3}`: posterior two-state marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`
+# Fields
 
-The following fields are internals and subject to change:
+The only fields useful outside of the algorithm are `γ`, `ξ` and `logL`.
 
-- `α::Matrix{R}`: scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)
-- `β::Matrix{R}`: scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)
-- `c::Vector{R}`: forward variable inverse normalizations `c[t] = 1 / sum(α[:,t])`
-- `logm::Vector{R}`: maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])`
-- `Bscaled::Matrix{R}`: numerically stabilized observation likelihoods `Bscaled[i,t] = exp.(logB[i,t] - logm[t])`
-- `Bβscaled::Matrix{R}`: numerically stabilized product `Bβscaled[i,t] = Bscaled[i,t] * β[i,t]`
+$(TYPEDFIELDS)
 """
-struct ForwardBackwardStorage{R}
+struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
+    "total loglikelihood"
+    logL::RefValue{R}
+    "scaled forward messages `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)"
     α::Matrix{R}
+    "scaled backward messages `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)"
     β::Matrix{R}
+    "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
     γ::Matrix{R}
-    ξ::Array{R,3}
+    "posterior transition marginals `ξ[t][i,j] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`"
+    ξ::Vector{M}
+    "forward message inverse normalizations `c[t] = 1 / sum(α[:,t])`"
     c::Vector{R}
+    "observation loglikelihoods `logB[i,t] = log ℙ(Y[t] | X[t]=i)`"
+    logB::Matrix{R}
+    "maximum of the observation loglikelihoods `logm[t] = maximum(logB[:, t])`"
     logm::Vector{R}
-    Bscaled::Matrix{R}
-    Bβscaled::Matrix{R}
+    "numerically stabilized observation likelihoods `B̃[i,t] = exp.(logB[i,t] - logm[t])`"
+    B̃::Matrix{R}
+    "product `B̃β[i,t] = B̃[i,t] * β[i,t]`"
+    B̃β::Matrix{R}
 end
 
-Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1)
+Base.eltype(::ForwardBackwardStorage{R}) where {R} = R
 duration(fb::ForwardBackwardStorage) = size(fb.α, 2)
 
-function loglikelihood(fb::ForwardBackwardStorage{R}) where {R}
-    logL = -sum(log, fb.c) + sum(fb.logm)
-    return logL
-end
-
-function loglikelihood(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
-    logL = zero(R)
-    for fb in fbs
-        logL += loglikelihood(fb)
-    end
-    return logL
-end
+function initialize_forward_backward(hmm::AbstractHMM, obs_seq::Vector)
+    N, T = length(hmm), length(obs_seq)
+    A = transition_matrix(hmm)
+    R = eltype(hmm, obs_seq[1])
+    M = typeof(similar(A, R))
 
-function initialize_forward_backward(p, A, logB)
-    N, T = size(logB)
-    R = promote_type(eltype(p), eltype(A), eltype(logB))
+    logL = RefValue{R}(zero(R))
     α = Matrix{R}(undef, N, T)
     β = Matrix{R}(undef, N, T)
     γ = Matrix{R}(undef, N, T)
-    ξ = Array{R,3}(undef, N, N, T - 1)
+    ξ = Vector{M}(undef, T - 1)
+    for t in 1:(T - 1)
+        ξ[t] = similar(A, R)
+    end
     c = Vector{R}(undef, T)
+    logB = Matrix{R}(undef, N, T)
     logm = Vector{R}(undef, T)
-    Bscaled = Matrix{R}(undef, N, T)
-    Bβscaled = Matrix{R}(undef, N, T)
-    return ForwardBackwardStorage(α, β, γ, ξ, c, logm, Bscaled, Bβscaled)
-end
+    B̃ = Matrix{R}(undef, N, T)
+    B̃β = Matrix{R}(undef, N, T)
 
-function initialize_forward_backward(hmm::AbstractHMM, logB)
-    p = initial_distribution(hmm)
-    A = transition_matrix(hmm)
-    return initialize_forward_backward(p, A, logB)
+    return ForwardBackwardStorage{R,M}(logL, α, β, γ, ξ, c, logB, logm, B̃, B̃β)
 end
 
-function forward!(fb::ForwardBackwardStorage, p, A, logB)
-    @unpack α, c, logm, Bscaled = fb
-    T = size(α, 2)
+function update_likelihoods!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::Vector)
+    d = obs_distributions(hmm)
+    @unpack logB, logm, B̃ = fb
+    N, T = length(hmm), duration(fb)
+
+    for t in 1:T
+        logB[:, t] .= logdensityof.(d, (obs_seq[t],))
+    end
+    check_no_nan(logB)
     maximum!(logm', logB)
-    Bscaled .= exp.(logB .- logm')
+    B̃ .= exp.(logB .- logm')
+    return nothing
+end
+
+function forward!(fb::ForwardBackwardStorage, hmm::AbstractHMM)
+    p = initialization(hmm)
+    A = transition_matrix(hmm)
+    @unpack α, c, B̃ = fb
+    N, T = length(hmm), duration(fb)
+
     @views begin
-        α[:, 1] .= p .* Bscaled[:, 1]
+        α[:, 1] .= p .* B̃[:, 1]
         c[1] = inv(sum(α[:, 1]))
         α[:, 1] .*= c[1]
     end
     @views for t in 1:(T - 1)
         mul!(α[:, t + 1], A', α[:, t])
-        α[:, t + 1] .*= Bscaled[:, t + 1]
+        α[:, t + 1] .*= B̃[:, t + 1]
         c[t + 1] = inv(sum(α[:, t + 1]))
         α[:, t + 1] .*= c[t + 1]
     end
-    check_no_nan(α)
+    fb.logL[] = -sum(log, fb.c) + sum(fb.logm)
     return nothing
 end
 
-function backward!(fb::ForwardBackwardStorage{R}, A, logB) where {R}
-    @unpack β, c, Bscaled, Bβscaled = fb
-    T = size(β, 2)
+function backward!(fb::ForwardBackwardStorage{R}, hmm::AbstractHMM) where {R}
+    A = transition_matrix(hmm)
+    @unpack β, c, B̃, B̃β = fb
+    N, T = length(hmm), duration(fb)
+
     β[:, T] .= c[T]
     @views for t in (T - 1):-1:1
-        Bβscaled[:, t + 1] .= Bscaled[:, t + 1] .* β[:, t + 1]
-        mul!(β[:, t], A, Bβscaled[:, t + 1])
+        B̃β[:, t + 1] .= B̃[:, t + 1] .* β[:, t + 1]
+        mul!(β[:, t], A, B̃β[:, t + 1])
         β[:, t] .*= c[t]
     end
-    @views Bβscaled[:, 1] .= Bscaled[:, 1] .* β[:, 1]
-    check_no_nan(β)
+    @views B̃β[:, 1] .= B̃[:, 1] .* β[:, 1]
     return nothing
 end
 
-function marginals!(fb::ForwardBackwardStorage, A)
-    @unpack α, β, c, Bβscaled, γ, ξ = fb
-    N, T = size(γ)
+function marginals!(fb::ForwardBackwardStorage, hmm::AbstractHMM)
+    A = transition_matrix(hmm)
+    @unpack α, β, c, B̃β, γ, ξ = fb
+    N, T = length(hmm), duration(fb)
+
     γ .= α .* β ./ c'
     check_no_nan(γ)
     @views for t in 1:(T - 1)
-        ξ[:, :, t] .= α[:, t] .* A .* Bβscaled[:, t + 1]'
+        mul_rows_cols!(ξ[t], α[:, t], A, B̃β[:, t + 1])
     end
-    check_no_nan(ξ)
     return nothing
 end
 
-function forward_backward!(fb::ForwardBackwardStorage, p, A, logB)
-    forward!(fb, p, A, logB)
-    backward!(fb, A, logB)
-    marginals!(fb, A)
+function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::Vector)
+    update_likelihoods!(fb, hmm, obs_seq)
+    forward!(fb, hmm)
+    backward!(fb, hmm)
+    marginals!(fb, hmm)
     return nothing
 end
 
-function forward_backward!(fb::ForwardBackwardStorage, hmm::AbstractHMM, logB)
-    p = initial_distribution(hmm)
-    A = transition_matrix(hmm)
-    return forward_backward!(fb, p, A, logB)
-end
-
-function forward_backward(p, A, logB)
-    fb = initialize_forward_backward(p, A, logB)
-    forward_backward!(fb, p, A, logB)
-    return fb
-end
-
-function forward_backward(hmm::AbstractHMM, logB::Matrix)
-    p = initial_distribution(hmm)
-    A = transition_matrix(hmm)
-    return forward_backward(p, A, logB)
+function forward_backward!(
+    fbs::Vector{<:ForwardBackwardStorage},
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer,
+)
+    check_lengths(obs_seqs, nb_seqs)
+    @threads for k in eachindex(fbs, obs_seqs)
+        forward_backward!(fbs[k], hmm, obs_seqs[k])
+    end
+    return nothing
 end
 
 """
     forward_backward(hmm, obs_seq)
-
-Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM.
-
-Return a [`ForwardBackwardStorage`](@ref).
-"""
-function forward_backward(hmm::AbstractHMM, obs_seq)
-    logB = loglikelihoods(hmm, obs_seq)
-    return forward_backward(hmm, logB)
-end
-
-"""
     forward_backward(hmm, obs_seqs, nb_seqs)
 
-Apply the forward-backward algorithm to estimate the posterior state marginals of an HMM, based on multiple observation sequences.
+Run the forward-backward algorithm to infer the posterior state and transition marginals of an HMM.
 
-Return a vector of [`ForwardBackwardStorage`](@ref) objects.
+When applied on a single sequence, this function returns a tuple `(γ, logL)` where
 
-!!! warning "Multithreading"
-    This function is parallelized across sequences.
+- `γ` is a matrix containing the posterior state marginals `γ[i, t]`
+- `logL` is the loglikelihood of the sequence
+
+When applied on multiple sequences, it returns a vector of tuples.
 """
-function forward_backward(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
-    fb1 = forward_backward(hmm, first(obs_seqs))
-    fbs = Vector{typeof(fb1)}(undef, nb_seqs)
-    fbs[1] = fb1
-    @threads for k in 2:nb_seqs
-        fbs[k] = forward_backward(hmm, obs_seqs[k])
-    end
-    return fbs
+function forward_backward(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
+    check_lengths(obs_seqs, nb_seqs)
+    fbs = [initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
+    forward_backward!(fbs, hmm, obs_seqs, nb_seqs)
+    return [(fb.γ, fb.logL[]) for fb in fbs]
+end
+
+function forward_backward(hmm::AbstractHMM, obs_seq::Vector)
+    return only(forward_backward(hmm, [obs_seq], 1))
end
diff --git a/src/inference/loglikelihoods.jl b/src/inference/loglikelihoods.jl
deleted file mode 100644
index 26592b73..00000000
--- a/src/inference/loglikelihoods.jl
+++ /dev/null
@@ -1,38 +0,0 @@
-## Vector
-
-function loglikelihoods_vec!(logb, hmm::AbstractHMM, obs)
-    for i in 1:length(hmm)
-        logb[i] = logdensityof(obs_distribution(hmm, i), obs)
-    end
-    check_no_nan(logb)
-    check_no_inf(logb)
-    return nothing
-end
-
-function loglikelihoods_vec(hmm::AbstractHMM, obs)
-    logb = [logdensityof(obs_distribution(hmm, i), obs) for i in 1:length(hmm)]
-    check_no_nan(logb)
-    check_no_inf(logb)
-    return logb
-end
-
-## Matrix
-
-function loglikelihoods!(logB, hmm::AbstractHMM, obs_seq)
-    T, N = length(obs_seq), length(hmm)
-    for t in 1:T, i in 1:N
-        logB[i, t] = logdensityof(obs_distribution(hmm, i), obs_seq[t])
-    end
-    check_no_nan(logB)
-    check_no_inf(logB)
-    return nothing
-end
-
-function loglikelihoods(hmm::AbstractHMM, obs_seq)
-    T, N = length(obs_seq), length(hmm)
-    dists = obs_distribution.(Ref(hmm), 1:N)
-    logB = [logdensityof(dists[i], obs_seq[t]) for i in 1:N, t in 1:T]
-    check_no_nan(logB)
-    check_no_inf(logB)
-    return logB
-end
diff --git a/src/inference/sufficient_stats.jl b/src/inference/sufficient_stats.jl
deleted file mode 100644
index 9bf55c58..00000000
--- a/src/inference/sufficient_stats.jl
+++ /dev/null
@@ -1,40 +0,0 @@
-
-function initialize_states_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
-    N = length(first(fbs))
-    init_count = Vector{R}(undef, N)
-    trans_count = Matrix{R}(undef, N, N)
-    return init_count, trans_count
-end
-
-function initialize_observations_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
-    N = length(first(fbs))
-    T_total = sum(duration, fbs)
-    state_marginals_concat = Matrix{R}(undef, N, T_total)
-    return state_marginals_concat
-end
-
-function update_states_stats!(
-    init_count, trans_count, fbs::Vector{ForwardBackwardStorage{R}}
-) where {R}
-    init_count .= zero(R)
-    for k in eachindex(fbs)
-        @views init_count .+= fbs[k].γ[:, 1]
-    end
-    trans_count .= zero(R)
-    for k in eachindex(fbs)
-        sum!(trans_count, fbs[k].ξ; init=false)
-    end
-    return nothing
-end
-
-function update_observations_stats!(
-    state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}}
-) where {R}
-    T = 1
-    for k in eachindex(fbs)
-        Tk = duration(fbs[k])
-        @views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ
-        T += Tk
-    end
-    return nothing
-end
diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl
index 06b744cf..6a65f92c 100644
--- a/src/inference/viterbi.jl
+++ b/src/inference/viterbi.jl
@@ -1,69 +1,103 @@
-function viterbi!(q, δₜ, δₜ₋₁, δA_tmp, ψ, logb, p, A, hmm::AbstractHMM, obs_seq)
+"""
+$(TYPEDEF)
+
+Store Viterbi quantities with element type `R`.
+
+This storage is relative to a single sequence.
+
+# Fields
+
+The only field useful outside of the algorithm is `q`.
+
+$(TYPEDFIELDS)
+"""
+struct ViterbiStorage{R}
+    "observation loglikelihoods at a given time step"
+    logb::Vector{R}
+    "highest path scores when accounting for the first `t` observations and ending at a given state"
+    δ::Vector{R}
+    "same as `δ` but for the previous time step"
+    δ_prev::Vector{R}
+    "temporary variable used to store products `δ_prev .* A[:, j]`"
+    δA::Vector{R}
+    "penultimate state maximizing the path score"
+    ψ::Matrix{Int}
+    "most likely state sequence `q = argmaxₓ ℙ(X[1:T]=x | Y[1:T])`"
+    q::Vector{Int}
+end
+
+function initialize_viterbi(hmm::AbstractHMM, obs_seq::Vector)
+    T, N = length(obs_seq), length(hmm)
+    R = eltype(hmm, obs_seq[1])
+
+    logb = Vector{R}(undef, N)
+    δ = Vector{R}(undef, N)
+    δ_prev = Vector{R}(undef, N)
+    δA = Vector{R}(undef, N)
+    ψ = Matrix{Int}(undef, N, T)
+    q = Vector{Int}(undef, T)
+    return ViterbiStorage(logb, δ, δ_prev, δA, ψ, q)
+end
+
+function viterbi!(v::ViterbiStorage, hmm::AbstractHMM, obs_seq::Vector)
     N, T = length(hmm), length(obs_seq)
-    loglikelihoods_vec!(logb, hmm, obs_seq[1])
+    p = initialization(hmm)
+    A = transition_matrix(hmm)
+    d = obs_distributions(hmm)
+    @unpack logb, δ, δ_prev, δA, ψ, q = v
+
+    logb .= logdensityof.(d, (obs_seq[1],))
     logm = maximum(logb)
-    δₜ .= p .* exp.(logb .- logm)
-    δₜ₋₁ .= δₜ
+    δ .= p .* exp.(logb .- logm)
+    δ_prev .= δ
     @views ψ[:, 1] .= zero(eltype(ψ))
     for t in 2:T
-        loglikelihoods_vec!(logb, hmm, obs_seq[t])
+        logb .= logdensityof.(d, (obs_seq[t],))
         logm = maximum(logb)
         for j in 1:N
-            @views δA_tmp .= δₜ₋₁ .* A[:, j]
-            i_max = argmax(δA_tmp)
+            @views δA .= δ_prev .* A[:, j]
+            i_max = argmax(δA)
             ψ[j, t] = i_max
-            δₜ[j] = δA_tmp[i_max] * exp(logb[j] - logm)
+            δ[j] = δA[i_max] * exp(logb[j] - logm)
         end
-        δₜ₋₁ .= δₜ
+        δ_prev .= δ
     end
-    @views q[T] = argmax(δₜ)
+    q[T] = argmax(δ)
     for t in (T - 1):-1:1
         q[t] = ψ[q[t + 1], t + 1]
     end
     return nothing
 end
 
-"""
-    viterbi(hmm, obs_seq)
-
-Apply the Viterbi algorithm to compute the most likely state sequence of an HMM.
-
-Return a vector of integers.
-"""
-function viterbi(hmm::AbstractHMM, obs_seq)
-    T, N = length(obs_seq), length(hmm)
-    p = initial_distribution(hmm)
-    A = transition_matrix(hmm)
-    logb = loglikelihoods_vec(hmm, obs_seq[1])
-
-    R = promote_type(eltype(p), eltype(A), eltype(logb))
-    δₜ = Vector{R}(undef, N)
-    δₜ₋₁ = Vector{R}(undef, N)
-    δA_tmp = Vector{R}(undef, N)
-    ψ = Matrix{Int}(undef, N, T)
-    q = Vector{Int}(undef, T)
-
-    viterbi!(q, δₜ, δₜ₋₁, δA_tmp, ψ, logb, p, A, hmm, obs_seq)
-    return q
+function viterbi!(
+    vs::Vector{<:ViterbiStorage},
+    hmm::AbstractHMM,
+    obs_seqs::Vector{<:Vector},
+    nb_seqs::Integer,
+)
+    check_lengths(obs_seqs, nb_seqs)
+    @threads for k in eachindex(vs, obs_seqs)
+        viterbi!(vs[k], hmm, obs_seqs[k])
+    end
+    return nothing
 end
 
 """
+    viterbi(hmm, obs_seq)
     viterbi(hmm, obs_seqs, nb_seqs)
 
-Apply the Viterbi algorithm to compute the most likely state sequences of an HMM, based on multiple observation sequences.
-
-Return a vector of vectors of integers.
+Apply the Viterbi algorithm to infer the most likely state sequence of an HMM.
 
-!!! warning "Multithreading"
-    This function is parallelized across sequences.
+When applied on a single sequence, this function returns a vector of integers.
+When applied on multiple sequences, it returns a vector of vectors of integers.
 """
-function viterbi(hmm::AbstractHMM, obs_seqs, nb_seqs::Integer)
-    if nb_seqs != length(obs_seqs)
-        throw(ArgumentError("nb_seqs != length(obs_seqs)"))
-    end
-    qs = Vector{Vector{Int}}(undef, nb_seqs)
-    @threads for k in 1:nb_seqs
-        qs[k] = viterbi(hmm, obs_seqs[k])
-    end
-    return qs
+function viterbi(hmm::AbstractHMM, obs_seqs::Vector{<:Vector}, nb_seqs::Integer)
+    check_lengths(obs_seqs, nb_seqs)
+    vs = [initialize_viterbi(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)]
+    viterbi!(vs, hmm, obs_seqs, nb_seqs)
+    return [v.q for v in vs]
+end
+
+function viterbi(hmm::AbstractHMM, obs_seq::Vector)
+    return only(viterbi(hmm, [obs_seq], 1))
 end
diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl
index 18c2f558..658dabbb 100644
--- a/src/types/abstract_hmm.jl
+++ b/src/types/abstract_hmm.jl
@@ -1,25 +1,31 @@
 """
-    AbstractHiddenMarkovModel <: AbstractMarkovChain
+    AbstractHiddenMarkovModel
 
 Abstract supertype for an HMM amenable to simulation, inference and learning.
 
-# Required interface
+# Interface
 
-- `initial_distribution(hmm)`
-- `transition_matrix(hmm)`
-- `obs_distribution(hmm, i)`
-- `fit!(hmm, init_count, trans_count, obs_seq, state_marginals)` (optional)
+To create your own subtype of `AbstractHiddenMarkovModel`, you need to implement the following methods:
 
-# Applicable methods
+- [`length(hmm)`](@ref)
+- [`eltype(hmm, obs)`](@ref)
+- [`initialization(hmm)`](@ref)
+- [`transition_matrix(hmm)`](@ref)
+- [`obs_distributions(hmm)`](@ref)
+- [`fit!(hmm, init_count, trans_count, obs_seq, state_marginals)`](@ref) (optional)
 
-- `rand([rng,] hmm, T)`
-- `logdensityof(hmm, obs_seq)` / `logdensityof(hmm, obs_seqs, nb_seqs)`
-- `forward(hmm, obs_seq)` / `forward(hmm, obs_seqs, nb_seqs)`
-- `viterbi(hmm, obs_seq)` / `viterbi(hmm, obs_seqs, nb_seqs)`
-- `forward_backward(hmm, obs_seq)` / `forward_backward(hmm, obs_seqs, nb_seqs)`
-- `baum_welch(hmm, obs_seq)` / `baum_welch(hmm, obs_seqs, nb_seqs)` if `fit!` is implemented
+# Applicable functions
+
+Any HMM object which satisfies the interface can be given as input to the following functions:
+
+- [`rand(rng, hmm, T)`](@ref)
+- [`logdensityof(hmm, obs_seq)`](@ref)
+- [`forward(hmm, obs_seq)`](@ref)
+- [`viterbi(hmm, obs_seq)`](@ref)
+- [`forward_backward(hmm, obs_seq)`](@ref)
+- [`baum_welch(hmm, obs_seq)`](@ref) (if the optional `fit!` is implemented)
 """
-abstract type AbstractHiddenMarkovModel <: AbstractMarkovChain end
+abstract type AbstractHiddenMarkovModel end
 
 """
     AbstractHMM
@@ -30,65 +36,100 @@ const AbstractHMM = AbstractHiddenMarkovModel
 
 @inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity()
 
-@required AbstractHMM begin
-    Base.length(::AbstractHMM)
-    initial_distribution(::AbstractHMM)
-    transition_matrix(::AbstractHMM)
-    obs_distribution(::AbstractHMM, ::Integer)
-end
+## Interface
 
 """
-    obs_distribution(hmm::AbstractHMM, i)
+    length(hmm)
 
-Return the observation distribution of `hmm` associated with state `i`.
+Return the number of states of `hmm`.
+"""
+Base.length
 
-The returned object `dist` must implement
-- `rand(rng, dist)`
-- `DensityInterface.logdensityof(dist, x)`
 """
-function obs_distribution end
+    eltype(hmm, obs)
 
-function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
-    mc = MarkovChain(hmm)
-    state_seq = rand(rng, mc, T)
-    first_obs = rand(rng, obs_distribution(hmm, first(state_seq)))
-    obs_seq = Vector{typeof(first_obs)}(undef, T)
-    obs_seq[1] = first_obs
-    for t in 2:T
-        obs_seq[t] = rand(rng, obs_distribution(hmm, state_seq[t]))
-    end
-    return (; state_seq=state_seq, obs_seq=obs_seq)
+Return a type that can accommodate forward-backward computations on observations similar to `obs`.
+It is typically a promotion between the element type of the initialization, the element type of the transition matrix, and the type of an observation logdensity evaluated at `obs`.
+"""
+function Base.eltype(hmm::AbstractHMM, obs)
+    init_type = eltype(initialization(hmm))
+    trans_type = eltype(transition_matrix(hmm))
+    logdensity_type = typeof(logdensityof(obs_distributions(hmm)[1], obs))
+    return promote_type(init_type, trans_type, logdensity_type)
 end
 
-function MarkovChain(hmm::AbstractHMM)
-    return MarkovChain(initial_distribution(hmm), transition_matrix(hmm))
-end
+"""
+    initialization(hmm)
+
+Return the vector of initial state probabilities for `hmm`.
+"""
+function initialization end
 
 """
-    PermutedHMM{H<:AbstractHMM}
+    transition_matrix(hmm)
 
-Wrapper around an `AbstractHMM` that permutes its states.
+Return the matrix of state transition probabilities for `hmm`.
+"""
+function transition_matrix end
 
-This is computationally inefficient and mostly useful for evaluation.
+"""
+    obs_distributions(hmm)
 
-# Fields
+Return a vector of observation distributions for `hmm`.
 
-- `hmm:H`: the old HMM
-- `perm::Vector{Int}`: a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old.
+Each element `dist` of this vector must implement
+- `rand(rng, dist)`
+- `DensityInterface.logdensityof(dist, obs)`
 """
-struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM
-    hmm::H
-    perm::Vector{Int}
-end
+function obs_distributions end
 
-Base.length(p::PermutedHMM) = length(p.hmm)
+"""
+    fit!(hmm, init_count, trans_count, obs_seq, state_marginals)
 
-HMMs.initial_distribution(p::PermutedHMM) = initial_distribution(p.hmm)[p.perm]
+Update `hmm` in-place based on information generated during forward-backward.
 
-function HMMs.transition_matrix(p::PermutedHMM)
-    return transition_matrix(p.hmm)[p.perm, :][:, p.perm]
-end
+This method is only necessary for the Baum-Welch algorithm.
+
+# Arguments
+
+- `init_count::Vector`: posterior initialization counts for each state (size `N`)
+- `trans_count::AbstractMatrix`: posterior transition counts for each pair of states (size `(N, N)`)
+- `obs_seq::Vector`: sequence of observations, possibly concatenated (size `T`)
+- `state_marginals::Matrix`: posterior probabilities of being in each state at each time, to be used as weights during maximum likelihood fitting of the observation distributions (size `(N, T)`).
+
+# See also
+
+- [`BaumWelchStorage`](@ref)
+- [`ForwardBackwardStorage`](@ref)
+"""
+StatsAPI.fit!  # TODO: complete
+
+## Sampling
+
+"""
+    rand(hmm, T)
+    rand(rng, hmm, T)
 
-function HMMs.obs_distribution(p::PermutedHMM, i::Integer)
-    return obs_distribution(p.hmm, p.perm[i])
+Simulate `hmm` for `T` time steps.
+"""
+function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer)
+    p = initialization(hmm)
+    A = transition_matrix(hmm)
+    d = obs_distributions(hmm)
+
+    first_state = rand(rng, Categorical(p; check_args=false))
+    state_seq = Vector{Int}(undef, T)
+    state_seq[1] = first_state
+    @views for t in 2:T
+        state_seq[t] = rand(rng, Categorical(A[state_seq[t - 1], :]; check_args=false))
+    end
+    first_obs = rand(rng, d[state_seq[1]])
+    obs_seq = Vector{typeof(first_obs)}(undef, T)
+    obs_seq[1] = first_obs
+    for t in 2:T
+        obs_seq[t] = rand(rng, d[state_seq[t]])
+    end
+    return (; state_seq=state_seq, obs_seq=obs_seq)
 end
+
+Base.rand(hmm::AbstractHMM, T::Integer) = rand(default_rng(), hmm, T)
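Editor's note (not part of the patch): a hedged sketch of what the new `eltype(hmm, obs)` computes; the promotion rule comes straight from the method above:

```julia
# Sketch only: Float32 parameters with a Float64 observation logdensity promote to Float64.
using Distributions
using HiddenMarkovModels

hmm = HMM(Float32[0.5, 0.5], Float32[0.9 0.1; 0.1 0.9], [Normal(-1.0), Normal(1.0)])
eltype(hmm, 0.0)  # Float64, since logdensityof(Normal(-1.0), 0.0) isa Float64
```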
-""" -function StatsAPI.fit(mc::AbstractMC, state_seq_or_seqs) - mc_est = deepcopy(mc) - fit!(mc_est, state_seq_or_seqs) - return mc_est -end - -function Base.rand(rng::AbstractRNG, mc::AbstractMC, T::Integer) - init = initial_distribution(mc) - trans = transition_matrix(mc) - first_state = rand(rng, Categorical(init; check_args=false)) - state_seq = Vector{Int}(undef, T) - state_seq[1] = first_state - @views for t in 2:T - state_seq[t] = rand(rng, Categorical(trans[state_seq[t - 1], :]; check_args=false)) - end - return state_seq -end - -""" - logdensityof(mc, state_seq) - -Compute the loglikelihood of a single state sequence for a Markov chain. -""" -function DensityInterface.logdensityof(mc::AbstractMC, state_seq::Vector{<:Integer}) - init = initial_distribution(mc) - trans = transition_matrix(mc) - logL = log(init[first(state_seq)]) - for t in 1:(length(state_seq) - 1) - logL += log(trans[state_seq[t], state_seq[t + 1]]) - end - return logL -end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 17845bb9..2a27f0d7 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -1,24 +1,23 @@ """ - HiddenMarkovModel{D} <: AbstractHiddenMarkovModel +$(TYPEDEF) Basic implementation of an HMM. # Fields -- `init::AbstractVector`: initial state probabilities -- `trans::AbstractMatrix`: state transition matrix -- `dists::AbstractVector{D}`: observation distributions +$(TYPEDFIELDS) """ -struct HiddenMarkovModel{D,U<:AbstractVector,M<:AbstractMatrix,V<:AbstractVector{D}} <: +struct HiddenMarkovModel{I<:AbstractVector,T<:AbstractMatrix,D<:AbstractVector} <: AbstractHMM - init::U - trans::M - dists::V - - function HiddenMarkovModel( - init::U, trans::M, dists::V - ) where {D,U<:AbstractVector,M<:AbstractMatrix,V<:AbstractVector{D}} - hmm = new{D,U,M,V}(init, trans, dists) + "initial state probabilities" + init::I + "state transition matrix" + trans::T + "observation distributions" + dists::D + + function HiddenMarkovModel(init::I, trans::T, dists::D) where {I,T,D} + hmm = new{I,T,D}(init, trans, dists) check_hmm(hmm) return hmm end @@ -36,22 +35,26 @@ function Base.copy(hmm::HMM) end Base.length(hmm::HMM) = length(hmm.init) -initial_distribution(hmm::HMM) = hmm.init +initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans -obs_distribution(hmm::HMM, i::Integer) = hmm.dists[i] - -""" - fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals) - -Update `hmm` in-place based on information generated during forward-backward. -""" -function StatsAPI.fit!(hmm::HMM, init_count, trans_count, obs_seq, state_marginals) +obs_distributions(hmm::HMM) = hmm.dists + +function StatsAPI.fit!( + hmm::HMM, + init_count::Vector, + trans_count::AbstractMatrix, + obs_seq::Vector, + state_marginals::Matrix, +) + # Initialization hmm.init .= init_count sum_to_one!(hmm.init) + # Transition matrix hmm.trans .= trans_count foreach(sum_to_one!, eachrow(hmm.trans)) - @views for i in eachindex(hmm.dists) - fit_element_from_sequence!(hmm.dists, i, obs_seq, state_marginals[i, :]) + # Observation distributions + for i in eachindex(hmm.dists) + fit_element_from_sequence!(hmm.dists, i, obs_seq, view(state_marginals, i, :)) end check_hmm(hmm) return nothing diff --git a/src/types/mc.jl b/src/types/mc.jl deleted file mode 100644 index 9c35bd56..00000000 --- a/src/types/mc.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - MarkovChain <: AbstractMarkovChain - -Basic implementation of a discrete-state Markov chain. 
- -# Fields - -- `init::AbstractVector`: initial state probabilities -- `trans::AbstractMatrix`: state transition matrix -""" -struct MarkovChain{U<:AbstractVector,M<:AbstractMatrix} <: AbstractMarkovChain - init::U - trans::M - - function MarkovChain(init::U, trans::M) where {U<:AbstractVector,M<:AbstractMatrix} - mc = new{U,M}(init, trans) - check_mc(mc) - return mc - end -end - -""" - MC - -Alias for the type `MarkovChain`. -""" -const MC = MarkovChain - -Base.length(mc::MC) = length(mc.init) -initial_distribution(mc::MC) = mc.init -transition_matrix(mc::MC) = mc.trans - -""" - fit!(mc::MC, init_count, trans_count) - -Update `mc` in-place based on information generated from a state sequence. -""" -function StatsAPI.fit!(mc::MC, init_count, trans_count) - mc.init .= init_count - sum_to_one!(mc.init) - mc.trans .= trans_count - foreach(sum_to_one!, eachrow(mc.trans)) - return nothing -end diff --git a/src/types/permuted_hmm.jl b/src/types/permuted_hmm.jl new file mode 100644 index 00000000..4d1f53de --- /dev/null +++ b/src/types/permuted_hmm.jl @@ -0,0 +1,29 @@ +""" +$(TYPEDEF) + +Wrapper around an `AbstractHMM` that permutes its states. + +This is computationally inefficient and mostly useful for evaluation. + +# Fields + +$(TYPEDFIELDS) +""" +struct PermutedHMM{H<:AbstractHMM} <: AbstractHMM + "the old HMM" + hmm::H + "a permutation such that state `i` in the new HMM corresponds to state `perm[i]` in the old" + perm::Vector{Int} +end + +Base.length(p::PermutedHMM) = length(p.hmm) + +initialization(p::PermutedHMM) = initialization(p.hmm)[p.perm] + +function transition_matrix(p::PermutedHMM) + return transition_matrix(p.hmm)[p.perm, :][:, p.perm] +end + +function obs_distributions(p::PermutedHMM) + return obs_distributions(p.hmm)[p.perm] +end diff --git a/src/utils/check.jl b/src/utils/check.jl index c72bad29..a0c34838 100644 --- a/src/utils/check.jl +++ b/src/utils/check.jl @@ -40,32 +40,34 @@ function check_coherent_sizes(p::AbstractVector, A::AbstractMatrix) end end -function check_dists(dists) - for i in eachindex(dists) - if DensityKind(dists[i]) == NoDensity() +function check_dists(d) + for i in eachindex(d) + if DensityKind(d[i]) == NoDensity() throw(ArgumentError("Observation is not a density")) end end end -function check_mc(mc::MarkovChain) - init = initial_distribution(mc) - trans = transition_matrix(mc) - if !(length(init) == size(trans, 1) == size(trans, 2)) +""" + check_hmm(hmm::AbstractHMM) + +Verify that `hmm` satisfies basic assumptions. 
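+
+Throw an error if an assumption is violated, and return `nothing` otherwise.
+
+A minimal sketch (the two-state Gaussian HMM below is purely illustrative):
+
+```julia
+using Distributions: Normal
+hmm = HMM(rand_prob_vec(2), rand_trans_mat(2), [Normal(0.0), Normal(1.0)])
+check_hmm(hmm)  # returns nothing: sizes and probabilities are coherent
+```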
+""" +function check_hmm(hmm::AbstractHMM) + p = initialization(hmm) + A = transition_matrix(hmm) + d = obs_distributions(hmm) + if !all(==(length(hmm)), (length(p), size(A, 1), size(A, 2), length(d))) throw(DimensionMismatch("Incoherent sizes")) end - check_prob_vec(init) - check_trans_mat(trans) + check_prob_vec(p) + check_trans_mat(A) + check_dists(d) return nothing end -function check_hmm(hmm::AbstractHMM) - mc = MarkovChain(hmm) - dists = [obs_distribution(hmm, i) for i in 1:length(hmm)] - if length(mc) != length(dists) - throw(DimensionMismatch("Incoherent sizes")) +function check_lengths(obs_seqs::Vector{<:Vector}, nb_seqs::Integer) + if nb_seqs != length(obs_seqs) + throw(ArgumentError("nb_seqs != length(obs_seqs)")) end - check_mc(mc) - check_dists(dists) - return nothing end diff --git a/src/utils/fit.jl b/src/utils/fit.jl index 6e749b99..7321af9b 100644 --- a/src/utils/fit.jl +++ b/src/utils/fit.jl @@ -14,23 +14,23 @@ end function fit_element_from_sequence!( dists::AbstractVector{D}, i::Integer, x::AbstractVector, w::AbstractVector ) where {D<:Distribution} - dists[i] = _fit_from_sequence(D, x, w) + dists[i] = fit_from_sequence(D, x, w) return nothing end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_nums::AbstractVector, w::AbstractVector ) where {D<:UnivariateDistribution} return fit(D, x_nums, w) end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_vecs::AbstractVector{<:AbstractVector}, w::AbstractVector ) where {D<:MultivariateDistribution} return fit(D, reduce(hcat, x_vecs), w) end -function _fit_from_sequence( +function fit_from_sequence( ::Type{D}, x_mats::AbstractVector{<:AbstractMatrix}, w::AbstractVector ) where {D<:MatrixDistribution} return fit(D, reduce(dcat, x_mats), w) diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 850f0089..a2b36a31 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -1,15 +1,22 @@ """ - LightDiagNormal +$(TYPEDEF) An HMMs-compatible implementation of a multivariate normal distribution with diagonal covariance, enabling allocation-free estimation. This is not part of the public API and is expected to change. 
+ +# Fields + +$(TYPEDFIELDS) """ struct LightDiagNormal{ T1,T2,T3,V1<:AbstractVector{T1},V2<:AbstractVector{T2},V3<:AbstractVector{T3} } + "vector of means" μ::V1 + "vector of standard deviations" σ::V2 + "vector of log standard deviations" logσ::V3 end @@ -19,6 +26,10 @@ function LightDiagNormal(μ, σ) return LightDiagNormal(μ, σ, log.(σ)) end +function Base.show(io::IO, dist::LightDiagNormal) + return print(io, "LightDiagNormal($(dist.μ), $(dist.σ))") +end + @inline DensityInterface.DensityKind(::LightDiagNormal) = HasDensity() Base.length(dist::LightDiagNormal) = length(dist.μ) diff --git a/src/utils/mul.jl b/src/utils/mul.jl new file mode 100644 index 00000000..5b66fc9e --- /dev/null +++ b/src/utils/mul.jl @@ -0,0 +1,20 @@ +function mul_rows_cols!( + B::AbstractMatrix, l::AbstractVector, A::AbstractMatrix, r::AbstractVector +) + B .= l .* A .* r' + return nothing +end + +function mul_rows_cols!( + B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector +) + @assert size(B) == size(A) == (length(l), length(r)) + B .= A + for j in axes(B, 2) + for k in nzrange(B, j) + i = B.rowval[k] + B.nzval[k] *= l[i] * r[j] + end + end + return nothing +end diff --git a/src/utils/probvec.jl b/src/utils/probvec.jl index 50c6d368..7925b56c 100644 --- a/src/utils/probvec.jl +++ b/src/utils/probvec.jl @@ -4,6 +4,12 @@ end sum_to_one!(x) = x ./= sum(x) +""" + rand_prob_vec(N) + rand_prob_vec(rng, N) + +Generate a random probability distribution of size `N`. +""" function rand_prob_vec(rng::AbstractRNG, N) p = rand(rng, N) sum_to_one!(p) diff --git a/src/utils/transmat.jl b/src/utils/transmat.jl index 71cb86d5..1f30e5f8 100644 --- a/src/utils/transmat.jl +++ b/src/utils/transmat.jl @@ -15,6 +15,12 @@ function is_trans_mat(A::AbstractMatrix; atol=1e-2) end end +""" + rand_trans_mat(N) + rand_trans_mat(rng, N) + +Generate a random transition matrix of size `(N, N)`. 
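+
+Each row of the result is a probability distribution over the `N` states. A quick sketch:
+
+```julia
+A = rand_trans_mat(3)
+all(sum(A; dims=2) .≈ 1)  # true: every row sums to one
+```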
+""" function rand_trans_mat(rng::AbstractRNG, N) A = rand(rng, N, N) foreach(sum_to_one!, eachrow(A)) diff --git a/test/allocations.jl b/test/allocations.jl index 5e75177f..667e4306 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -1,49 +1,53 @@ using BenchmarkTools -using Distributions -using Distributions: PDiagMat using HiddenMarkovModels -using SimpleUnPack +using HiddenMarkovModels.HMMTest +import HiddenMarkovModels as HMMs using Test function test_allocations(hmm; T) - p = initial_distribution(hmm) - A = transition_matrix(hmm) - @unpack state_seq, obs_seq = rand(hmm, T) + obs_seq = rand(hmm, T).obs_seq + nb_seqs = 2 + obs_seqs = [rand(hmm, T).obs_seq for _ in 1:nb_seqs] ## Forward - logb = HMMs.loglikelihoods_vec(hmm, obs_seq[1]) - αₜ = zeros(N) - αₜ₊₁ = zeros(N) - allocs = @ballocated HMMs.forward!($αₜ, $αₜ₊₁, $logb, $p, $A, $hmm, $obs_seq) + f = HMMs.initialize_forward(hmm, obs_seq) + allocs = @ballocated HiddenMarkovModels.forward!($f, $hmm, $obs_seq) @test allocs == 0 ## Viterbi - logb = HMMs.loglikelihoods_vec(hmm, obs_seq[1]) - δₜ = zeros(N) - δₜ₋₁ = zeros(N) - δA_tmp = zeros(N) - ψ = zeros(Int, N, T) - q = zeros(Int, T) - allocs = @ballocated HMMs.viterbi!( - $q, $δₜ, $δₜ₋₁, $δA_tmp, $ψ, $logb, $p, $A, $hmm, $obs_seq - ) + v = HMMs.initialize_viterbi(hmm, obs_seq) + allocs = @ballocated HMMs.viterbi!($v, $hmm, $obs_seq) @test allocs == 0 ## Forward-backward - logB = HMMs.loglikelihoods(hmm, obs_seq) - fb = HMMs.initialize_forward_backward(p, A, logB) - allocs = @ballocated HMMs.forward_backward!($fb, $p, $A, $logB) + fb = HMMs.initialize_forward_backward(hmm, obs_seq) + allocs = @ballocated HMMs.forward_backward!($fb, $hmm, $obs_seq) + @test allocs == 0 + + ## Baum-Welch + fbs = [HMMs.initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs)] + bw = HMMs.initialize_baum_welch(hmm, obs_seqs, nb_seqs) + obs_seqs_concat = reduce(vcat, obs_seqs) + HMMs.forward_backward!(fbs, hmm, obs_seqs, nb_seqs) + HMMs.update_sufficient_statistics!(bw, fbs) + fit!(hmm, bw, obs_seqs_concat) + allocs = @ballocated HMMs.update_sufficient_statistics!($bw, $fbs) + @test allocs == 0 + allocs = @ballocated fit!($hmm, $bw, $obs_seqs_concat) @test allocs == 0 end -N = 5 -D = 3 -T = 100 +N, D, T = 3, 2, 100 -p = rand_prob_vec(N) -A = rand_trans_mat(N) -dists = [Normal(randn(), 1.0) for i in 1:N] +@testset "Normal" begin + test_allocations(rand_gaussian_hmm_1d(N); T) +end -hmm = HMM(p, A, dists) +@testset "Normal sparse" begin + # see https://discourse.julialang.org/t/why-does-mul-u-a-v-allocate-when-a-is-sparse-and-u-v-are-views/105995 + @test_skip test_allocations(rand_gaussian_hmm_1d(N; sparse_trans=true); T) +end -test_allocations(hmm; T) +@testset "LightDiagNormal" begin + test_allocations(rand_gaussian_hmm_2d_light(N, D); T) +end diff --git a/test/arrays.jl b/test/arrays.jl new file mode 100644 index 00000000..316f019a --- /dev/null +++ b/test/arrays.jl @@ -0,0 +1,49 @@ +using Distributions +using HiddenMarkovModels +using HiddenMarkovModels: sum_to_one! 
+using LinearAlgebra +using SparseArrays +using StaticArrays +using SimpleUnPack +using Test + +N, T = 3, 1000 + +## Sparse + +p = ones(N) / N; +A = SparseMatrixCSC(SymTridiagonal(ones(N), ones(N - 1))); +foreach(sum_to_one!, eachrow(A)); +d = [Normal(i + randn(), 1.0) for i in 1:N]; +d_init = [Normal(i + randn(), 1.0) for i in 1:N]; + +hmm = HMM(p, A, d); +hmm_init = HMM(p, A, d_init); + +obs_seq = rand(hmm, T).obs_seq; + +γ, logL = forward_backward(hmm, obs_seq); +hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); + +@testset "Sparse" begin + @test typeof(hmm_est) == typeof(hmm_init) + @test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm)) +end + +## Static + +p = MVector{N}(ones(N) / N); +A = MMatrix{N,N}(rand_trans_mat(N)); +d = MVector{N}([Normal(randn(), 1.0) for i in 1:N]); +d_init = MVector{N}([Normal(randn(), 1.0) for i in 1:N]); + +hmm = HMM(p, A, d); +hmm_init = HMM(p, A, d_init); +obs_seq = rand(hmm, T).obs_seq; + +γ, logL = forward_backward(hmm, obs_seq); +hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); + +@testset "Static" begin + @test typeof(hmm_est) == typeof(hmm_init) +end diff --git a/test/autodiff.jl b/test/autodiff.jl index e7ba7f7d..eca642ca 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -12,13 +12,13 @@ T = 100 p = rand_prob_vec(N) A = rand_trans_mat(N) μ = rand(N) -dists = [Normal(μ[i], 1.0) for i in 1:N] -hmm = HMM(p, A, dists) +d = [Normal(μ[i], 1.0) for i in 1:N] +hmm = HMM(p, A, d) -@unpack state_seq, obs_seq = rand(hmm, T); +obs_seq = rand(hmm, T).obs_seq; function f_init(_p) - hmm = HMM(_p, A, dists) + hmm = HMM(_p, A, d) return logdensityof(hmm, obs_seq) end @@ -27,7 +27,7 @@ g2 = Zygote.gradient(f_init, p)[1] @test isapprox(g1, g2) function f_trans(_A) - hmm = HMM(p, _A, dists) + hmm = HMM(p, _A, d) return logdensityof(hmm, obs_seq) end @@ -35,13 +35,13 @@ g1 = ForwardDiff.gradient(f_trans, A) g2 = Zygote.gradient(f_trans, A)[1] @test isapprox(g1, g2) -function f_dists(_μ) +function f_d(_μ) hmm = HMM(p, A, [Normal(_μ[i], 1.0) for i in 1:N]) return logdensityof(hmm, obs_seq) end -g0 = FiniteDifferences.grad(central_fdm(5, 1), f_dists, μ)[1] -g1 = ForwardDiff.gradient(f_dists, μ) -g2 = Zygote.gradient(f_dists, μ)[1] +g0 = FiniteDifferences.grad(central_fdm(5, 1), f_d, μ)[1] +g1 = ForwardDiff.gradient(f_d, μ) +g2 = Zygote.gradient(f_d, μ)[1] @test isapprox(g0, g1) @test isapprox(g0, g2) diff --git a/test/correctness.jl b/test/correctness.jl index 1fd9fa7e..29b1b55a 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -1,92 +1,90 @@ using Distributions -using Distributions: PDiagMat using HMMBase: HMMBase using HiddenMarkovModels +using HiddenMarkovModels.HMMTest using SimpleUnPack using Test function test_correctness(hmm, hmm_init; T) - @unpack state_seq, obs_seq = rand(hmm, T) - obs_mat = collect(reduce(hcat, obs_seq)') + obs_seq1 = rand(hmm, T).obs_seq + obs_seq2 = rand(hmm, T).obs_seq + obs_mat1 = collect(reduce(hcat, obs_seq1)') + obs_mat2 = collect(reduce(hcat, obs_seq2)') + + nb_seqs = 2 + obs_seqs = [obs_seq1, obs_seq2] hmm_base = HMMBase.HMM(deepcopy(hmm)) hmm_init_base = HMMBase.HMM(deepcopy(hmm_init)) @testset "Logdensity" begin - _, logL_base = HMMBase.forward(hmm_base, obs_mat) - logL = @inferred logdensityof(hmm, obs_seq) - @test logL ≈ logL_base + logL1_base = HMMBase.forward(hmm_base, obs_mat1)[2] + logL2_base = HMMBase.forward(hmm_base, obs_mat2)[2] + logL = logdensityof(hmm, obs_seqs, nb_seqs) + @test logL ≈ logL1_base + logL2_base end @testset "Forward" begin - α_base, 
logL_base = HMMBase.forward(hmm_base, obs_mat) - α, logL = @inferred forward(hmm, obs_seq) - @test isapprox(α, α_base[end, :]) - @test logL ≈ logL_base + (α1_base, logL1_base), (α2_base, logL2_base) = [ + HMMBase.forward(hmm_base, obs_mat1), HMMBase.forward(hmm_base, obs_mat2) + ] + (α1, logL1), (α2, logL2) = forward(hmm, obs_seqs, nb_seqs) + @test isapprox(α1, α1_base[end, :]) + @test isapprox(α2, α2_base[end, :]) + @test logL1 ≈ logL1_base + @test logL2 ≈ logL2_base end @testset "Viterbi" begin - best_state_seq_base = HMMBase.viterbi(hmm_base, obs_mat) - best_state_seq = @inferred viterbi(hmm, obs_seq) - @test isequal(best_state_seq, best_state_seq_base) + q1_base = HMMBase.viterbi(hmm_base, obs_mat1) + q2_base = HMMBase.viterbi(hmm_base, obs_mat2) + q1, q2 = viterbi(hmm, obs_seqs, nb_seqs) + @test isequal(q1, q1_base) + @test isequal(q2, q2_base) end @testset "Forward-backward" begin - γ_base = HMMBase.posteriors(hmm_base, obs_mat) - fb = @inferred forward_backward(hmm, obs_seq) - @test isapprox(fb.γ, γ_base') + γ1_base = HMMBase.posteriors(hmm_base, obs_mat1) + γ2_base = HMMBase.posteriors(hmm_base, obs_mat2) + (γ1, _), (γ2, _) = forward_backward(hmm, obs_seqs, nb_seqs) + @test isapprox(γ1, γ1_base') + @test isapprox(γ2, γ2_base') end @testset "Baum-Welch" begin hmm_est_base, hist_base = HMMBase.fit_mle( - hmm_init_base, obs_mat; maxiter=10, tol=-Inf + hmm_init_base, obs_mat1; maxiter=10, tol=-Inf ) logL_evolution_base = hist_base.logtots - hmm_est, logL_evolution = @inferred baum_welch( - hmm_init, obs_seq; max_iterations=10, atol=-Inf + hmm_est, logL_evolution = baum_welch( + hmm_init, [obs_seq1, obs_seq1], 2; max_iterations=10, atol=-Inf ) @test isapprox( - logL_evolution[(begin + 1):end], logL_evolution_base[begin:(end - 1)] + logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) - @test isapprox(initial_distribution(hmm_est), hmm_est_base.a) + @test isapprox(initialization(hmm_est), hmm_est_base.a) @test isapprox(transition_matrix(hmm_est), hmm_est_base.A) for (dist, dist_base) in zip(hmm.dists, hmm_base.B) - @test isapprox(dist.μ, dist_base.μ) + if hasfield(typeof(dist), :μ) + @test isapprox(dist.μ, dist_base.μ) + elseif hasfield(typeof(dist), :p) + @test isapprox(dist.p, dist_base.p) + end end end end -N = 5 -D = 3 -T = 100 - -p = rand_prob_vec(N) -p_init = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_init = rand_trans_mat(N) - -# Normal - -dists_norm = [Normal(randn(), 1.0) for i in 1:N] -dists_norm_init = [Normal(randn(), 1) for i in 1:N] +N, D, T = 3, 2, 100 -hmm_norm = HMM(p, A, dists_norm) -hmm_norm_init = HMM(p_init, A_init, dists_norm_init) +@testset "Categorical" begin + test_correctness(rand_categorical_hmm(N, 10), rand_categorical_hmm(N, 10); T) +end @testset "Normal" begin - test_correctness(hmm_norm, hmm_norm_init; T) + test_correctness(rand_gaussian_hmm_1d(N), rand_gaussian_hmm_1d(N); T) end -# DiagNormal - -dists_diagnorm = [DiagNormal(randn(D), PDiagMat(ones(D))) for i in 1:N] -dists_diagnorm_init = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] - -hmm_diagnorm = HMM(p, A, dists_diagnorm) -hmm_diagnorm_init = HMM(p, A, dists_diagnorm_init) - @testset "DiagNormal" begin - test_correctness(hmm_diagnorm, hmm_diagnorm_init; T) + test_correctness(rand_gaussian_hmm_2d(N, D), rand_gaussian_hmm_2d(N, D); T) end diff --git a/test/dna.jl b/test/dna.jl index b6956ed5..6279d122 100644 --- a/test/dna.jl +++ b/test/dna.jl @@ -1,7 +1,6 @@ using DensityInterface using HiddenMarkovModels using Random: AbstractRNG -using RequiredInterfaces: 
check_interface_implemented using SimpleUnPack using StatsAPI using Test @@ -74,11 +73,11 @@ get_state(coding, nucleotide) = 4(coding - 1) + nucleotide Base.length(dchmm::DNACodingHMM) = 8 -function HMMs.initial_distribution(dchmm::DNACodingHMM) +function HiddenMarkovModels.initialization(dchmm::DNACodingHMM) return repeat(dchmm.cod_init; inner=4) .* repeat(dchmm.nuc_init; outer=2) end -function HMMs.transition_matrix(dchmm::DNACodingHMM) +function HiddenMarkovModels.transition_matrix(dchmm::DNACodingHMM) @unpack cod_trans, nuc_trans = dchmm A = Matrix{Float64}(undef, 8, 8) for c1 in 1:2, n1 in 1:4, c2 in 1:2, n2 in 1:4 @@ -88,12 +87,16 @@ function HMMs.transition_matrix(dchmm::DNACodingHMM) return A end -function HMMs.obs_distribution(::DNACodingHMM, s::Integer) - return Dirac(get_nucleotide(s)) +function HiddenMarkovModels.obs_distributions(hmm::DNACodingHMM) + return [Dirac(get_nucleotide(s)) for s in 1:length(hmm)] end function StatsAPI.fit!( - dchmm::DNACodingHMM, init_count, trans_count, obs_seq, state_marginals + dchmm::DNACodingHMM, + init_count::Vector, + trans_count::Matrix, + obs_seq::Vector, + state_marginals::Matrix, ) # Initializations for c in 1:2 @@ -102,8 +105,8 @@ function StatsAPI.fit!( for n in 1:4 dchmm.nuc_init[n] = sum(init_count[get_state(c, n)] for c in 1:2) end - HMMs.sum_to_one!(dchmm.cod_init) - HMMs.sum_to_one!(dchmm.nuc_init) + HiddenMarkovModels.sum_to_one!(dchmm.cod_init) + HiddenMarkovModels.sum_to_one!(dchmm.nuc_init) # Transitions for c1 in 1:2, c2 in 1:2 @@ -116,15 +119,13 @@ function StatsAPI.fit!( trans_count[get_state(c1, n1), get_state(c2, n2)] for c2 in 1:2 ) end - foreach(HMMs.sum_to_one!, eachrow(dchmm.cod_trans)) - foreach(HMMs.sum_to_one!, eachrow(@view dchmm.nuc_trans[1, :, :])) - foreach(HMMs.sum_to_one!, eachrow(@view dchmm.nuc_trans[2, :, :])) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(dchmm.cod_trans)) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[1, :, :])) + foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[2, :, :])) return nothing end -@test check_interface_implemented(AbstractHMM, DNACodingHMM) - dchmm = DNACodingHMM(; cod_init=rand_prob_vec(2), nuc_init=rand_prob_vec(4), @@ -136,14 +137,10 @@ dchmm = DNACodingHMM(; most_likely_coding_seq = get_coding.(viterbi(dchmm, obs_seq)); -mc = MarkovChain(ones(4) ./ 4, rand_trans_mat(4)) -fit!(mc, obs_seq) - dchmm_init = DNACodingHMM(; cod_init=rand(2), nuc_init=rand(4), cod_trans=rand_trans_mat(2), - # using transition_matrix(mc) as initialization below seems to be worse nuc_trans=permutedims(cat(rand_trans_mat(4), rand_trans_mat(4); dims=3), (3, 1, 2)), ); diff --git a/test/doctests.jl b/test/doctests.jl deleted file mode 100644 index 15594db4..00000000 --- a/test/doctests.jl +++ /dev/null @@ -1,4 +0,0 @@ -using Documenter -using HiddenMarkovModels - -doctest(HiddenMarkovModels) diff --git a/test/formatting.jl b/test/formatting.jl deleted file mode 100644 index b3fe015d..00000000 --- a/test/formatting.jl +++ /dev/null @@ -1,5 +0,0 @@ -using HiddenMarkovModels -using JuliaFormatter: format -using Test - -@test format(HiddenMarkovModels; verbose=false, overwrite=false) diff --git a/test/interface.jl b/test/interface.jl deleted file mode 100644 index 0681ca94..00000000 --- a/test/interface.jl +++ /dev/null @@ -1,17 +0,0 @@ -using HiddenMarkovModels -using RequiredInterfaces: check_interface_implemented -using Suppressor -using Test - -struct Empty end - -@test check_interface_implemented(AbstractMC, MarkovChain) -@test 
check_interface_implemented(AbstractHMM, HMM) -@test check_interface_implemented(AbstractMC, HMM) -@test check_interface_implemented(AbstractHMM, PermutedHMM) - -@suppress begin - @test check_interface_implemented(AbstractMC, Empty) != true - @test check_interface_implemented(AbstractHMM, Empty) != true - @test check_interface_implemented(AbstractHMM, MarkovChain) != true -end diff --git a/test/linting.jl b/test/linting.jl deleted file mode 100644 index aaa7986a..00000000 --- a/test/linting.jl +++ /dev/null @@ -1,4 +0,0 @@ -using HiddenMarkovModels -using JET - -JET.test_package(HiddenMarkovModels; target_defined_modules=true) diff --git a/test/logarithmic.jl b/test/logarithmic.jl deleted file mode 100644 index 4129e355..00000000 --- a/test/logarithmic.jl +++ /dev/null @@ -1,31 +0,0 @@ -using Distributions -using HiddenMarkovModels -using LinearAlgebra -using LogarithmicNumbers -using SimpleUnPack -using Test - -N = 3 -D = 2 -T = 100 - -p = ones(N) / N -A = rand_trans_mat(N) -dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; -dists_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; -dists_init_log = [LightDiagNormal(randn(D), LogFloat64.(ones(D))) for i in 1:N]; - -hmm = HMM(p, A, dists); -@unpack state_seq, obs_seq = rand(hmm, T); - -hmm_init = HMM(LogFloat64.(p), A, dists_init); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) - -hmm_init = HMM(p, LogFloat64.(A), dists_init); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) - -hmm_init = HMM(p, A, dists_init_log); -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq); -@test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/mc.jl b/test/mc.jl deleted file mode 100644 index 2773e087..00000000 --- a/test/mc.jl +++ /dev/null @@ -1,22 +0,0 @@ -using HiddenMarkovModels -using Test - -N = 5 -T = 100 - -p = rand_prob_vec(N) -p_rand = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_rand = rand_trans_mat(N) - -mc = MC(p, A) -mc_rand = MC(p_rand, A_rand) - -state_seq = rand(mc, T) - -mc_est = fit(mc_rand, state_seq) - -@test logdensityof(mc_est, state_seq) > - logdensityof(mc, state_seq) > - logdensityof(mc_rand, state_seq) diff --git a/test/misc.jl b/test/misc.jl new file mode 100644 index 00000000..e035dc46 --- /dev/null +++ b/test/misc.jl @@ -0,0 +1,29 @@ +using HiddenMarkovModels +using HiddenMarkovModels.HMMTest +using HiddenMarkovModels: PermutedHMM +using Distributions +using Test + +## Permuted + +perm = [3, 1, 2] +hmm = rand_gaussian_hmm_1d(3) +hmm_perm = PermutedHMM(hmm, perm) + +p = initialization(hmm) +A = transition_matrix(hmm) +d = obs_distributions(hmm) + +p_perm = initialization(hmm_perm) +A_perm = transition_matrix(hmm_perm) +d_perm = obs_distributions(hmm_perm) + +@testset "PermutedHMM" begin + for i in 1:3 + @test p_perm[i] ≈ p[perm[i]] + @test d_perm[i] ≈ d[perm[i]] + end + for i in 1:3, j in 1:3 + @test A_perm[i, j] ≈ A[perm[i], perm[j]] + end +end diff --git a/test/numbers.jl b/test/numbers.jl new file mode 100644 index 00000000..35e906cc --- /dev/null +++ b/test/numbers.jl @@ -0,0 +1,34 @@ +using HiddenMarkovModels +using HiddenMarkovModels: LightDiagNormal +using LinearAlgebra +using LogarithmicNumbers +using SimpleUnPack +using Test + +N, D, T = 3, 2, 1000 + +## LogarithmicNumbers + +p = ones(N) / N; +A = rand_trans_mat(N); +d = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; +d_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]; +d_init_log = 
[LightDiagNormal(randn(D), LogFloat64.(ones(D))) for i in 1:N]; + +hmm = HMM(p, A, d); +obs_seq = rand(hmm, T).obs_seq; + +hmm_init1 = HMM(LogFloat64.(p), A, d_init); +hmm_est1, logL_evolution1 = @inferred baum_welch(hmm_init1, obs_seq); + +hmm_init2 = HMM(p, LogFloat64.(A), d_init); +hmm_est2, logL_evolution2 = @inferred baum_welch(hmm_init2, obs_seq); + +hmm_init3 = HMM(p, A, d_init_log); +hmm_est3, logL_evolution3 = @inferred baum_welch(hmm_init3, obs_seq); + +@testset "Logarithmic" begin + @test typeof(hmm_est1) == typeof(hmm_init1) + @test typeof(hmm_est2) == typeof(hmm_init2) + @test typeof(hmm_est3) == typeof(hmm_init3) +end diff --git a/test/permuted.jl b/test/permuted.jl deleted file mode 100644 index b8056704..00000000 --- a/test/permuted.jl +++ /dev/null @@ -1,24 +0,0 @@ -using HiddenMarkovModels -using Distributions -using Test - -p = rand_prob_vec(3) -A = rand_trans_mat(3) -dists = [Normal(i) for i in 1:3] - -hmm = HMM(p, A, dists) -perm = [3, 1, 2] - -hmm_perm = PermutedHMM(hmm, perm) -p_perm = initial_distribution(hmm_perm) -A_perm = transition_matrix(hmm_perm) -dists_perm = [obs_distribution(hmm_perm, i) for i in 1:3] - -for i in 1:3 - @test p_perm[i] ≈ p[perm[i]] - @test dists_perm[i] ≈ dists[perm[i]] -end - -for i in 1:3, j in 1:3 - @test A_perm[i, j] ≈ A[perm[i], perm[j]] -end diff --git a/test/quality.jl b/test/quality.jl deleted file mode 100644 index 0a6a547a..00000000 --- a/test/quality.jl +++ /dev/null @@ -1,4 +0,0 @@ -using Aqua: test_all -using HiddenMarkovModels - -test_all(HiddenMarkovModels; ambiguities=false) diff --git a/test/runtests.jl b/test/runtests.jl index 31c4c16a..a49bbe92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,65 +1,58 @@ +using Aqua: Aqua +using Documenter: Documenter +using HiddenMarkovModels +using JuliaFormatter: JuliaFormatter +using JET: JET using Test @testset verbose = true "HiddenMarkovModels.jl" begin @testset "Code formatting" begin - include("formatting.jl") + @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) end if VERSION >= v"1.9" @testset "Code quality" begin - include("quality.jl") + Aqua.test_all(HiddenMarkovModels; ambiguities=false) end @testset "Code linting" begin - include("linting.jl") + JET.test_package(HiddenMarkovModels; target_defined_modules=true) end - @testset verbose = true "Type stability" begin + @testset "Type stability" begin include("type_stability.jl") end - @testset verbose = true "Allocations" begin + @testset "Allocations" begin include("allocations.jl") end end - @testset "Interface" begin - include("interface.jl") - end - - @testset "Markov chain" begin - include("mc.jl") + @testset "Doctests" begin + Documenter.doctest(HiddenMarkovModels) end - @testset verbose = true "Correctness" begin + @testset "Correctness" begin include("correctness.jl") end - @testset verbose = true "Sparse" begin - include("sparse.jl") - end - - @testset verbose = true "Static" begin - include("static.jl") + @testset "Array types" begin + include("arrays.jl") end - @testset verbose = true "Logarithmic" begin - include("logarithmic.jl") + @testset "Number types" begin + include("numbers.jl") end - @testset verbose = true "Autodiff" begin + @testset "Autodiff" begin include("autodiff.jl") end - @testset verbose = true "DNA" begin + @testset "DNA" begin include("dna.jl") end - @testset verbose = true "Permuted" begin - include("permuted.jl") - end - - @testset "Doctests" begin - include("doctests.jl") + @testset "Misc" begin + include("misc.jl") end end diff --git a/test/sparse.jl 
b/test/sparse.jl deleted file mode 100644 index da6e14df..00000000 --- a/test/sparse.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Distributions -using HiddenMarkovModels -using HiddenMarkovModels: sum_to_one! -using LinearAlgebra -using SparseArrays -using SimpleUnPack -using Test - -N = 4 -T = 2000 - -p = ones(N) / N -A = SparseMatrixCSC(SymTridiagonal(ones(N), ones(N - 1))) -foreach(sum_to_one!, eachrow(A)) -dists = [Normal(i + randn(), 1) for i in 1:N] -dists_init = [Normal(i + randn(), 1) for i in 1:N] - -hmm = HMM(p, A, dists) -hmm_init = HMM(p, A, dists_init) - -@unpack state_seq, obs_seq = rand(hmm, T) -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) - -@test typeof(hmm_est) == typeof(hmm_init) -@test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm)) diff --git a/test/static.jl b/test/static.jl deleted file mode 100644 index 0a30d67c..00000000 --- a/test/static.jl +++ /dev/null @@ -1,22 +0,0 @@ -using Distributions -using HiddenMarkovModels -using LinearAlgebra -using StaticArrays -using SimpleUnPack -using Test - -N = 5 -T = 1000 - -p = MVector{N}(ones(N) / N) -A = MMatrix{N,N}(rand_trans_mat(N)) -dists = MVector{N}([Normal(randn(), 1.0) for i in 1:N]) -dists_init = MVector{N}([Normal(randn(), 1.0) for i in 1:N]) - -hmm = HMM(p, A, dists) -hmm_init = HMM(p, A, dists_init) - -@unpack state_seq, obs_seq = rand(hmm, T) -hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq) - -@test typeof(hmm_est) == typeof(hmm_init) diff --git a/test/type_stability.jl b/test/type_stability.jl index ed730606..642ae0ef 100644 --- a/test/type_stability.jl +++ b/test/type_stability.jl @@ -1,88 +1,68 @@ -using Distributions -using Distributions: PDiagMat using HiddenMarkovModels -using HiddenMarkovModels: LightDiagNormal +using HiddenMarkovModels.HMMTest +import HiddenMarkovModels as HMMs using JET using SimpleUnPack using Test -function test_type_stability(hmm, hmm_init; T, K) - obs_seqs = [rand(hmm, T).obs_seq for k in 1:K] +function test_type_stability(hmm, hmm_init; T) + obs_seq = rand(hmm, T).obs_seq @testset "Logdensity" begin - @inferred logdensityof(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) logdensityof(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) logdensityof(hmm, obs_seqs, K) + @inferred logdensityof(hmm, obs_seq) + @test_opt target_modules = (HMMs,) logdensityof(hmm, obs_seq) + @test_call target_modules = (HMMs,) logdensityof(hmm, obs_seq) end @testset "Forward" begin - @inferred forward(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) forward(hmm, obs_seqs, K) + @inferred forward(hmm, obs_seq) + @test_opt target_modules = (HMMs,) forward(hmm, obs_seq) + @test_call target_modules = (HMMs,) forward(hmm, obs_seq) end @testset "Viterbi" begin - @inferred viterbi(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) viterbi(hmm, obs_seqs, K) + @inferred viterbi(hmm, obs_seq) + @test_opt target_modules = (HMMs,) viterbi(hmm, obs_seq) + @test_call target_modules = (HMMs,) viterbi(hmm, obs_seq) end @testset "Forward-backward" begin - @inferred forward_backward(hmm, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) forward_backward(hmm, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) forward_backward(hmm, obs_seqs, K) + @inferred forward_backward(hmm, obs_seq) + @test_opt target_modules = 
(HMMs,) forward_backward(hmm, obs_seq) + @test_call target_modules = (HMMs,) forward_backward(hmm, obs_seq) end @testset "Baum-Welch" begin - @inferred baum_welch(hmm_init, obs_seqs, K) - @test_opt target_modules = (HiddenMarkovModels,) baum_welch(hmm_init, obs_seqs, K) - @test_call target_modules = (HiddenMarkovModels,) baum_welch(hmm_init, obs_seqs, K) + @inferred baum_welch(hmm_init, obs_seq) + @test_opt target_modules = (HMMs,) baum_welch(hmm_init, obs_seq; max_iterations=2) + @test_call target_modules = (HMMs,) baum_welch(hmm_init, obs_seq; max_iterations=2) end end -N = 5 -D = 3 -T = 100 -K = 4 +N, D, T = 3, 2, 100 -p = rand_prob_vec(N) -p_init = rand_prob_vec(N) - -A = rand_trans_mat(N) -A_init = rand_trans_mat(N) - -# Normal - -dists_norm = [Normal(randn(), 1.0) for i in 1:N] -dists_norm_init = [Normal(randn(), 1) for i in 1:N] - -hmm_norm = HMM(p, A, dists_norm) -hmm_norm_init = HMM(p_init, A_init, dists_norm_init) +@testset "Categorical" begin + test_type_stability(rand_categorical_hmm(N, D), rand_categorical_hmm(N, D); T) +end @testset "Normal" begin - test_type_stability(hmm_norm, hmm_norm_init; T, K) + test_type_stability(rand_gaussian_hmm_1d(N), rand_gaussian_hmm_1d(N); T) end -# DiagNormal - -dists_diagnorm = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] -dists_diagnorm_init = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] - -hmm_diagnorm = HMM(p, A, dists_diagnorm) -hmm_diagnorm_init = HMM(p, A, dists_diagnorm_init) +@testset "Normal sparse" begin + test_type_stability( + rand_gaussian_hmm_1d(N; sparse_trans=true), + rand_gaussian_hmm_1d(N; sparse_trans=true); + T, + ) +end @testset "DiagNormal" begin - test_type_stability(hmm_diagnorm, hmm_diagnorm_init; T, K) + test_type_stability(rand_gaussian_hmm_2d(N, D), rand_gaussian_hmm_2d(N, D); T) end -## LightDiagNormal - -dists_lightdiagnorm = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] -dists_lightdiagnorm_init = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] - -hmm_lightdiagnorm = HMM(p, A, dists_lightdiagnorm) -hmm_lightdiagnorm_init = HMM(p, A, dists_lightdiagnorm_init) - @testset "LightDiagNormal" begin - test_type_stability(hmm_lightdiagnorm, hmm_lightdiagnorm_init; T, K) + test_type_stability( + rand_gaussian_hmm_2d_light(N, D), rand_gaussian_hmm_2d_light(N, D); T + ) end