diff --git a/.gitignore b/.gitignore index 52bc2a27..8f82be1f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,6 @@ docs/**/*.svg *.csv *.txt /docs/src/index.md +/docs/src/examples/*.md .vscode/ .benchmarkci/ \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index a3d01c80..7843672c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" @@ -6,6 +7,7 @@ HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" diff --git a/docs/make.jl b/docs/make.jl index c96fae19..ede90e3f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,15 +23,42 @@ open(joinpath(joinpath(@__DIR__, "src"), "index.md"), "w") do io end end +examples_jl_path = joinpath(dirname(@__DIR__), "examples") +examples_md_path = joinpath(@__DIR__, "src", "examples") + +for file in readdir(examples_md_path) + if endswith(file, ".md") + rm(joinpath(examples_md_path, file)) + end +end + +for file in readdir(examples_jl_path) + Literate.markdown(joinpath(examples_jl_path, file), examples_md_path) +end + +function literate_title(path) + l = first(readlines(path)) + return l[3:end] +end + pages = [ - "Home" => "index.md", "Essentials" => [ + "Home" => "index.md", "Background" => "background.md", - "API reference" => "api.md", "Alternatives" => "alternatives.md", ], - "Tutorials" => ["Debugging" => "debugging.md"], - "Advanced" => ["Formulas" => "formulas.md", "Roadmap" => "roadmap.md"], + "Tutorials" => [ + "Basics" => joinpath("examples", "basics.md"), + "Distributions" => joinpath("examples", "distributions.md"), + "Controlled" => joinpath("examples", "controlled.md"), + "Periodic" => joinpath("examples", "periodic.md"), + ], + "API reference" => "api.md", + "Advanced" => [ + "Debugging" => "debugging.md", + "Formulas" => "formulas.md", + "Roadmap" => "roadmap.md", + ], ] fmt = Documenter.HTML(; diff --git a/docs/src/examples/.gitkeep b/docs/src/examples/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/hmm.jl b/examples/basics.jl similarity index 70% rename from examples/hmm.jl rename to examples/basics.jl index f662f13e..5ca6c152 100644 --- a/examples/hmm.jl +++ b/examples/basics.jl @@ -1,4 +1,4 @@ -# # Built-in HMM +# # Basics # ## Setup @@ -6,7 +6,9 @@ using Distributions using HiddenMarkovModels +#md using Plots using Random +using Test #src # Random seed @@ -18,7 +20,7 @@ Random.seed!(rng, 63) N = 2 init = rand_prob_vec(N) trans = rand_trans_mat(N) -dists = [Normal(i, 0.5) for i in 1:N] +dists = [Normal(i, 1) for i in 1:N] hmm = HMM(init, trans, dists) # ## Simulation @@ -44,20 +46,26 @@ forward_backward(hmm, obs_seq) # ## Learning from several sequences -K = 3 -obs_seqs = [rand(rng, hmm, k * T).obs_seq for k in 1:K] +nb_seqs = 3 +obs_seqs = [rand(rng, hmm, k * T).obs_seq for k in 1:nb_seqs] # Baum-Welch needs an initial guess init_guess = ones(N) / N trans_guess = ones(N, N) / N -dists_guess = [Normal(i, 1) for i in 1:N] +dists_guess = [Normal(i + randn() / 10, 1) for i in 1:N] hmm_guess = HMM(init_guess, trans_guess, dists_guess) #- -hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seqs, length(obs_seqs)) +hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seqs, nb_seqs) + +#md plot(logL_evolution) + +#- first(logL_evolution), last(logL_evolution) +#- + cat(hmm_est.trans, hmm.trans; dims=3) diff --git a/examples/controlled.jl b/examples/controlled.jl index 3b4ed074..2485a589 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -1,6 +1 @@ -struct ControlledHMM{R,C} - init::Vector{R} - trans::Matrix{R} - means::Vector{R} - control_seq::Vector{C} -end +# # Controlled HMM diff --git a/examples/distributions.jl b/examples/distributions.jl new file mode 100644 index 00000000..e6751d6b --- /dev/null +++ b/examples/distributions.jl @@ -0,0 +1 @@ +# # Distributions diff --git a/examples/dna.jl b/examples/dna.jl deleted file mode 100644 index 6279d122..00000000 --- a/examples/dna.jl +++ /dev/null @@ -1,155 +0,0 @@ -using DensityInterface -using HiddenMarkovModels -using Random: AbstractRNG -using SimpleUnPack -using StatsAPI -using Test - -struct Dirac{T} - val::T -end - -Base.rand(::AbstractRNG, d::Dirac) = d.val -DensityInterface.DensityKind(::Dirac) = HasDensity() -DensityInterface.logdensityof(d::Dirac, x) = x == d.val ? 0.0 : -Inf - -""" - DNACodingHMM <: AbstractHMM - -Custom implementation of an autoregressive HMM based on a standard HMM. - -This describes the behavior of DNA as it moves from coding to noncoding segments. -In theory, the state is a character `coding` and the observation is a character `nucleotide`. -In practice, the state is a character couple `(coding, nucleotide)` and the observation is the exact same `nucleotide`. - -# Notations - -Coding: - -| 1 | 2 | -|----|----| -| C | N | - -Emissions: - -| 1 | 2 | 3 | 4 | -|---|---|---|---| -| A | T | G | C | - -States: - -| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -|-------|-------|-------|-------|-------|-------|-------|-------| -| (C,A) | (C,T) | (C,G) | (C,C) | (N,A) | (N,T) | (N,G) | (N,C) | - -# Fields - -- `cod_init::Vector{Float64}`: initial coding distribution -- `nuc_init::Vector{Float64}`: initial nucleotide distribution -- `cod_trans::Matrix{Float64}`: transition matrix between coding and noncoding -- `nuc_trans::Array{Float64,2}`: pair of transition matrices between nucleotides in coding and noncoding state -""" -struct DNACodingHMM <: AbstractHMM - cod_init::Vector{Float64} - nuc_init::Vector{Float64} - cod_trans::Matrix{Float64} - nuc_trans::Array{Float64,3} - function DNACodingHMM(; cod_init, nuc_init, cod_trans, nuc_trans) - @assert length(cod_init) == 2 - @assert length(nuc_init) == 4 - @assert size(cod_trans) == (2, 2) - @assert size(nuc_trans) == (2, 4, 4) - return new(cod_init, nuc_init, cod_trans, nuc_trans) - end -end - -get_coding(state) = 1 + (state - 1) ÷ 4 -get_nucleotide(state) = 1 + (state - 1) % 4 -get_state(coding, nucleotide) = 4(coding - 1) + nucleotide - -@test get_coding.(1:8) == repeat(1:2; inner=4) -@test get_nucleotide.(1:8) == repeat(1:4; outer=2) -@test get_state.(get_coding.(1:8), get_nucleotide.(1:8)) == collect(1:8) - -Base.length(dchmm::DNACodingHMM) = 8 - -function HiddenMarkovModels.initialization(dchmm::DNACodingHMM) - return repeat(dchmm.cod_init; inner=4) .* repeat(dchmm.nuc_init; outer=2) -end - -function HiddenMarkovModels.transition_matrix(dchmm::DNACodingHMM) - @unpack cod_trans, nuc_trans = dchmm - A = Matrix{Float64}(undef, 8, 8) - for c1 in 1:2, n1 in 1:4, c2 in 1:2, n2 in 1:4 - s1, s2 = get_state(c1, n1), get_state(c2, n2) - A[s1, s2] = cod_trans[c1, c2] * nuc_trans[c1, n1, n2] - end - return A -end - -function HiddenMarkovModels.obs_distributions(hmm::DNACodingHMM) - return [Dirac(get_nucleotide(s)) for s in 1:length(hmm)] -end - -function StatsAPI.fit!( - dchmm::DNACodingHMM, - init_count::Vector, - trans_count::Matrix, - obs_seq::Vector, - state_marginals::Matrix, -) - # Initializations - for c in 1:2 - dchmm.cod_init[c] = sum(init_count[get_state(c, n)] for n in 1:4) - end - for n in 1:4 - dchmm.nuc_init[n] = sum(init_count[get_state(c, n)] for c in 1:2) - end - HiddenMarkovModels.sum_to_one!(dchmm.cod_init) - HiddenMarkovModels.sum_to_one!(dchmm.nuc_init) - - # Transitions - for c1 in 1:2, c2 in 1:2 - dchmm.cod_trans[c1, c2] = sum( - trans_count[get_state(c1, n1), get_state(c2, n2)] for n1 in 1:4, n2 in 1:4 - ) - end - for c1 in 1:2, n1 in 1:4, n2 in 1:4 - dchmm.nuc_trans[c1, n1, n2] = sum( - trans_count[get_state(c1, n1), get_state(c2, n2)] for c2 in 1:2 - ) - end - foreach(HiddenMarkovModels.sum_to_one!, eachrow(dchmm.cod_trans)) - foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[1, :, :])) - foreach(HiddenMarkovModels.sum_to_one!, eachrow(@view dchmm.nuc_trans[2, :, :])) - - return nothing -end - -dchmm = DNACodingHMM(; - cod_init=rand_prob_vec(2), - nuc_init=rand_prob_vec(4), - cod_trans=rand_trans_mat(2), - nuc_trans=permutedims(cat(rand_trans_mat(4), rand_trans_mat(4); dims=3), (3, 1, 2)), -); - -@unpack state_seq, obs_seq = rand(dchmm, 10_000); - -most_likely_coding_seq = get_coding.(viterbi(dchmm, obs_seq)); - -dchmm_init = DNACodingHMM(; - cod_init=rand(2), - nuc_init=rand(4), - cod_trans=rand_trans_mat(2), - nuc_trans=permutedims(cat(rand_trans_mat(4), rand_trans_mat(4); dims=3), (3, 1, 2)), -); - -dchmm_est, logL_evolution = baum_welch(dchmm_init, obs_seq; atol=1e-7, max_iterations=100); - -logL_evolution - -sum(abs, dchmm_init.cod_trans - dchmm.cod_trans) / (2 * 2) -sum(abs, dchmm_est.cod_trans - dchmm.cod_trans) / (2 * 2) - -sum(abs, dchmm_init.nuc_trans - dchmm.nuc_trans) / (2 * 4 * 4) -sum(abs, dchmm_est.nuc_trans - dchmm.nuc_trans) / (2 * 4 * 4) diff --git a/examples/periodic.jl b/examples/periodic.jl index b7e78789..d9c6c449 100644 --- a/examples/periodic.jl +++ b/examples/periodic.jl @@ -3,23 +3,20 @@ using Distributions using HiddenMarkovModels import HiddenMarkovModels as HMMs -using Plots +#md using Plots using SimpleUnPack using StatsAPI # ## Structure -""" - PeriodicHMM{L} - -Basic implementation of a periodic HMM with time-dependent transition matrices and observation distributions, repeating every `L` time steps. -""" struct PeriodicHMM{L,V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM init::V trans_periodic::NTuple{L,M} dists_periodic::NTuple{L,VD} end +#- + period(::PeriodicHMM{L}) where {L} = L Base.length(phmm::PeriodicHMM) = length(phmm.init) @@ -36,18 +33,19 @@ end ## Fitting struct BaumWelchStoragePeriodicHMM <: HMMs.AbstractBaumWelchStorage end + function HMMs.initialize_baum_welch(::PeriodicHMM, fb_storages, obs_seqs) return BaumWelchStoragePeriodicHMM() end +#- + function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwardStorage}) L = period(hmm) - # Reset hmm.init .= 0 for l in 1:L hmm.trans_periodic[l] .= 0 end - # Accumulate sufficient stats for k in eachindex(fb_storages) @unpack γ, ξ = fb_storages[k] hmm.init .+= view(γ, :, 1) @@ -56,7 +54,6 @@ function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwar hmm.trans_periodic[l] .+= ξ[t] end end - # Normalize hmm.init ./= sum(hmm.init) for l in 1:L hmm.trans_periodic[l] ./= sum(hmm.trans_periodic[l]; dims=2) @@ -64,26 +61,31 @@ function fit_states!(hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwar return nothing end +#- + function fit_observations!( hmm::PeriodicHMM, fb_storages::Vector{<:HMMs.ForwardBackwardStorage}, obs_seqs::Vector{<:Vector}, ) + L = period(hmm) for l in 1:L - obs_seq_periodic = reduce(vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs)) - state_marginals_periodic = reduce( - hcat, fb_storages[k].γ[:, l:L:end] for k in eachindex(fb_storages) - ) for i in 1:length(hmm) + obs_seq_periodic = reduce( + vcat, obs_seqs[k][l:L:end] for k in eachindex(obs_seqs) + ) + state_marginals_periodic = reduce( + vcat, fb_storages[k].γ[i, l:L:end] for k in eachindex(fb_storages) + ) D = typeof(hmm.dists_periodic[l][i]) - x = obs_seq_periodic - w = view(state_marginals_periodic, i, :) - hmm.dists_periodic[l][i] = fit(D, x, w) + hmm.dists_periodic[l][i] = fit(D, obs_seq_periodic, state_marginals_periodic) end end return nothing end +#- + function StatsAPI.fit!( hmm::PeriodicHMM, ::BaumWelchStoragePeriodicHMM, @@ -97,76 +99,42 @@ end # ## Example -N = 2 # Number of hidden states -L = 10 # Period of the HMM -T = 50_000 # Number of observation - -function make_trans(l, L) - A = Matrix{Float64}(undef, 2, 2) - A[1, 1] = 0.25 + 0.1 + 0.5cos(2π / L * l + 1)^2 - A[1, 2] = 0.25 - 0.1 + 0.5sin(2π / L * l + 1)^2 - A[2, 2] = 0.25 + 0.2 + 0.5cos(2π / L * (l - L / 3))^2 - A[2, 1] = 0.25 - 0.2 + 0.5sin(2π / L * (l - L / 3))^2 - return A -end - -function make_dists(l, L, N) - dists = [Normal(2i * cos(2π * l / L), i + cos(2π / L * (l - i / 2 + 1))^2) for i in 1:N] - return dists -end +N = 2 +T = 1000 init = ones(N) / N; -trans_periodic = ntuple(l -> make_trans(l, L), L); -dists_periodic = ntuple(l -> make_dists(l, L, N), L); +trans_periodic = ( + [0.9 0.1; 0.1 0.9], # + [0.8 0.2; 0.2 0.8], # + [0.7 0.3; 0.3 0.7], +); +dists_periodic = ( + [Normal(0), Normal(4)], # + [Normal(2), Normal(6)], # + [Normal(4), Normal(8)], +); hmm = PeriodicHMM(init, trans_periodic, dists_periodic); +#- + state_seq, obs_seq = rand(hmm, T); +hmm_est, logL_evolution = baum_welch(hmm, obs_seq); -hmm_est, logL_evolution = baum_welch(hmm, obs_seq; max_iterations=100); -length(logL_evolution) - -## Plotting - -p = [plot(; xlabel="l", title="transitions from state $i") for i in 1:N] -for i in 1:N, j in 1:N - plot!( - p[i], - 1:L, - [transition_matrix(hmm, l)[i, j] for l in 1:L]; - label="p$((i,j)) - true", - c=j, - ) - plot!( - p[i], - 1:L, - [transition_matrix(hmm_est, l)[i, j] for l in 1:L]; - label="p$((i,j)) - est", - c=j, - s=:dash, - ) -end -plot(p...; size=(1000, 500)) - -p = [plot(; xlabel="l", title="emissions from state $i") for i in 1:N] -for i in 1:N - plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].μ for l in 1:L]; label="μ - true", c=1) - plot!( - p[i], - 1:L, - [obs_distributions(hmm_est, l)[i].μ for l in 1:L]; - label="μ - est", - c=1, - s=:dash, - ) - plot!(p[i], 1:L, [obs_distributions(hmm, l)[i].σ for l in 1:L]; label="σ - true", c=2) - plot!( - p[i], - 1:L, - [obs_distributions(hmm_est, l)[i].σ for l in 1:L]; - label="σ - est", - c=2, - s=:dash, - ) -end -plot(p...; size=(1000, 500)) +#md plot(logL_evolution) + +#- + +cat(hmm_est.init, hmm.init; dims=3) + +#- + +cat(hmm_est.trans_periodic[1], hmm.trans_periodic[1]; dims=3) +cat(hmm_est.trans_periodic[2], hmm.trans_periodic[2]; dims=3) +cat(hmm_est.trans_periodic[3], hmm.trans_periodic[3]; dims=3) + +#- + +cat(hmm_est.dists_periodic[1], hmm.dists_periodic[1]; dims=3) +cat(hmm_est.dists_periodic[2], hmm.dists_periodic[2]; dims=3) +cat(hmm_est.dists_periodic[3], hmm.dists_periodic[3]; dims=3) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 77a88c1b..74c004df 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -18,7 +18,7 @@ function baum_welch_has_converged( if length(logL_evolution) >= 2 logL, logL_prev = logL_evolution[end], logL_evolution[end - 1] progress = logL - logL_prev - if loglikelihood_increasing && progress < 0 + if loglikelihood_increasing && progress < min(0, -atol) error("Loglikelihood decreased in Baum-Welch") elseif abs(progress) < atol return true diff --git a/src/precompile.jl b/src/precompile.jl index 2f8beafb..ffe4da05 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,14 +1,15 @@ @compile_workload begin N, D, T = 3, 2, 100 - p = rand_prob_vec(N) - A = rand_trans_mat(N) + init = rand_prob_vec(N) + trans = rand_trans_mat(N) dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] - hmm = HMM(p, A, dists) - obs_seq = rand(hmm, T).obs_seq + hmm = HMM(init, trans, dists) + state_seq, obs_seq = rand(hmm, T) + logdensityof(hmm, obs_seq, state_seq) logdensityof(hmm, obs_seq) forward(hmm, obs_seq) viterbi(hmm, obs_seq) forward_backward(hmm, obs_seq) - baum_welch(hmm, obs_seq; max_iterations=2, atol=-Inf) + baum_welch(hmm, obs_seq; max_iterations=1) end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 00ee2acd..3f4a7b4d 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -80,7 +80,8 @@ function obs_logdensities!(logb::AbstractVector, hmm::AbstractHMM, t::Integer, o @inbounds for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end - return check_right_finite(logb) + check_right_finite(logb) + return nothing end """ @@ -105,17 +106,20 @@ function Base.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) state_seq = Vector{typeof(state1)}(undef, T) state_seq[1] = state1 - dists = obs_distributions(hmm, 1) - obs1 = rand(rng, dists[state1]) + @views for t in 1:(T - 1) + trans = transition_matrix(hmm, t) + state_seq[t + 1] = rand( + rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) + ) + end + + dists1 = obs_distributions(hmm, 1) + obs1 = rand(rng, dists1[state1]) obs_seq = Vector{typeof(obs1)}(undef, T) obs_seq[1] = obs1 - @views for t in 2:T - trans = transition_matrix(hmm, t) + for t in 2:T dists = obs_distributions(hmm, t) - state_seq[t] = rand( - rng, LightCategorical(trans[state_seq[t - 1], :], dummy_log_probas) - ) obs_seq[t] = rand(rng, dists[state_seq[t]]) end return (; state_seq=state_seq, obs_seq=obs_seq) diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index f2391f06..82ea7a05 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -34,7 +34,7 @@ function Base.rand(rng::AbstractRNG, dist::LightCategorical{T1}) where {T1} s = zero(T1) for k in eachindex(dist.p) s += dist.p[k] - if u < s + if u <= s return k end end diff --git a/test/correctness.jl b/test/correctness.jl index 18881868..e2c49326 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -26,14 +26,14 @@ function Base.isapprox(hmm::HMM, hmm_base::HMMBase.HMM) end function Base.isapprox(hmm1::AbstractHMM, hmm2::AbstractHMM; atol) - #= init1 = initialization(hmm1) init2 = initialization(hmm2) maximum(abs, init1 - init2) < atol || return false - =# + trans1 = transition_matrix(hmm1, 1) trans2 = transition_matrix(hmm2, 1) maximum(abs, trans1 - trans2) < atol || return false + dists1 = obs_distributions(hmm1, 1) dists2 = obs_distributions(hmm2, 1) for (dist1, dist2) in zip(dists1, dists2) @@ -105,6 +105,7 @@ function test_correctness_baum_welch( permuted_hmm_est = PermutedHMM(hmm_est, perm) push!(success_by_perm, isapprox(permuted_hmm_est, hmm; atol)) end + @test last(logL_evolution) > first(logL_evolution) @test sum(success_by_perm) == 1 end @@ -113,8 +114,8 @@ end @testset "Categorical" begin N = 2 - init = rand_prob_vec(N) - trans = rand_trans_mat(N) + init = [0.3, 0.7] + trans = [0.8 0.2; 0.2 0.8] dists = [Categorical([0.2, 0.8]), Categorical([0.8, 0.2])] hmm = HMM(init, trans, dists) @@ -123,95 +124,101 @@ end dists_guess = [Categorical([0.3, 0.7]), Categorical([0.7, 0.3])] hmm_guess = HMM(init_guess, trans_guess, dists_guess) + test_correctness_baum_welch(hmm, hmm_guess; T=200, nb_seqs=100, atol=0.1) test_comparison_hmmbase(hmm, hmm_guess; T=100) - test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=10, atol=0.2) end @testset "Normal" begin N = 2 - init = rand_prob_vec(N) - trans = rand_trans_mat(N) - dists = [Normal(i, 1) for i in 1:N] + init = [0.3, 0.7] + trans = [0.8 0.2; 0.2 0.8] + dists = [Normal(randn(), 1) for i in 1:N] hmm = HMM(init, trans, dists) init_guess = ones(N) / N trans_guess = ones(N, N) / N - dists_guess = [Normal(i + 0.3, 1) for i in 1:N] + dists_guess = [Normal(randn(), 1) for i in 1:N] hmm_guess = HMM(init_guess, trans_guess, dists_guess) + test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=100, atol=0.1) test_comparison_hmmbase(hmm, hmm_guess; T=100) - test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=10, atol=0.2) end @testset "DiagNormal" begin N, D = 2, 2 - init = rand_prob_vec(N) - trans = rand_trans_mat(N) - dists = [DiagNormal(i .* ones(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] + init = [0.3, 0.7] + trans = [0.8 0.2; 0.2 0.8] + dists = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] hmm = HMM(init, trans, dists) init_guess = ones(N) / N trans_guess = ones(N, N) / N - dists_guess = [DiagNormal((i + 0.3) .* ones(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] + dists_guess = [DiagNormal(randn(D), PDiagMat(ones(D) .^ 2)) for i in 1:N] hmm_guess = HMM(init_guess, trans_guess, dists_guess) + test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=100, atol=0.1) test_comparison_hmmbase(hmm, hmm_guess; T=100) - test_correctness_baum_welch(hmm, hmm_guess; T=500, nb_seqs=10, atol=0.2) end -## Sparse arrays +## Light distributions -@testset "Normal sparse" begin +@testset "LightCategorical" begin N = 2 - init = rand_prob_vec(N) - trans = SparseMatrixCSC(SymTridiagonal(rand(N), rand(N - 1))) - foreach(HiddenMarkovModels.sum_to_one!, eachrow(trans)) - dists = [Normal(i, 1) for i in 1:N] + init = [0.3, 0.7] + trans = [0.8 0.2; 0.2 0.8] + dists = [LightCategorical([0.2, 0.8]), LightCategorical([0.8, 0.2])] hmm = HMM(init, trans, dists) init_guess = ones(N) / N - trans_guess = SparseMatrixCSC(SymTridiagonal(ones(N), ones(N - 1))) - foreach(HiddenMarkovModels.sum_to_one!, eachrow(trans_guess)) - dists_guess = [Normal(i + 0.3, 1) for i in 1:N] + trans_guess = ones(N, N) / N + dists_guess = [LightCategorical([0.3, 0.7]), LightCategorical([0.7, 0.3])] hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_comparison_hmmbase(hmm, hmm_guess; T=100) - test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=10, atol=0.2) + test_correctness_baum_welch(hmm, hmm_guess; T=200, nb_seqs=100, atol=0.1) end -## Light distributions - -@testset "LightCategorical" begin - N = 2 +@testset "LightDiagNormal" begin + N, D = 2, 2 - init = rand_prob_vec(N) - trans = rand_trans_mat(N) - dists = [LightCategorical([0.2, 0.8]), LightCategorical([0.8, 0.2])] + init = [0.3, 0.7] + trans = [0.8 0.2; 0.2 0.8] + dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] hmm = HMM(init, trans, dists) init_guess = ones(N) / N trans_guess = ones(N, N) / N - dists_guess = [LightCategorical([0.3, 0.7]), LightCategorical([0.7, 0.3])] + dists_guess = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=10, atol=0.2) + test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=100, atol=0.1) end -@testset "LightDiagNormal" begin - N, D = 2, 2 +## Weird arrays - init = rand_prob_vec(N) - trans = rand_trans_mat(N) - dists = [LightDiagNormal(i .* ones(D), ones(D)) for i in 1:N] +@testset "Normal sparse" begin + N = 3 + + init = [0.3, 0.5, 0.2] + trans = sparse([ + 0.8 0.2 0.0 + 0.0 0.8 0.2 + 0.2 0.0 0.8 + ]) + dists = [Normal(i, 1) for i in 1:N] hmm = HMM(init, trans, dists) init_guess = ones(N) / N - trans_guess = ones(N, N) / N - dists_guess = [LightDiagNormal((i + 0.3) .* ones(D), ones(D)) for i in 1:N] + trans_guess = sparse([ + 0.5 0.5 0.0 + 0.0 0.5 0.5 + 0.5 0.0 0.5 + ]) + dists_guess = [Normal(i + 0.3, 1) for i in 1:N] hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_correctness_baum_welch(hmm, hmm_guess; T=500, nb_seqs=10, atol=0.2) + test_correctness_baum_welch(hmm, hmm_guess; T=100, nb_seqs=100, atol=0.1) + test_comparison_hmmbase(hmm, hmm_guess; T=100) end diff --git a/test/distributions.jl b/test/distributions.jl index 434c0930..44973773 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -2,10 +2,16 @@ using HiddenMarkovModels: LightCategorical, LightDiagNormal using Statistics using Test +function test_fit_allocs(dist, x, w) + dist_copy = deepcopy(dist) + allocs = @allocated fit!(dist_copy, x, w) + @test allocs == 0 +end + @testset "LightCategorical" begin p = rand_prob_vec(10) dist = LightCategorical(p) - x = [rand(dist) for _ in 1:1_000_000] + x = [rand(dist) for _ in 1:100_000] # Simulation val_count = zeros(Int, length(p)) for k in x @@ -15,24 +21,24 @@ using Test # Fitting dist_est = deepcopy(dist) w = ones(length(x)) - allocs = @allocated fit!(dist_est, x, w) - @test_broken allocs == 0 + fit!(dist_est, x, w) @test dist_est.p ≈ p atol = 1e-2 + test_fit_allocs(dist, x, w) end @testset "LightDiagNormal" begin μ = randn(10) σ = rand(10) dist = LightDiagNormal(μ, σ) - x = [rand(dist) for _ in 1:1_000_000] + x = [rand(dist) for _ in 1:100_000] # Simulation @test mean(x) ≈ μ atol = 1e-2 @test std(x) ≈ σ atol = 1e-2 # Fitting dist_est = deepcopy(dist) w = ones(length(x)) - allocs = @allocated fit!(dist_est, x, w) - @test_broken allocs == 0 + fit!(dist_est, x, w) @test dist_est.μ ≈ μ atol = 1e-2 @test dist_est.σ ≈ σ atol = 1e-2 + test_fit_allocs(dist, x, w) end diff --git a/test/runtests.jl b/test/runtests.jl index 7b5e95ee..3a1588b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,12 +5,14 @@ using JuliaFormatter: JuliaFormatter using JET: JET using Test -@testset verbose = true "HiddenMarkovModels.jl" begin - @testset "Code formatting" begin - @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) - end +examples_path = joinpath(dirname(@__DIR__), "examples") +@testset verbose = true "HiddenMarkovModels.jl" begin if VERSION >= v"1.9" + @testset "Code formatting" begin + @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) + end + @testset "Code quality" begin Aqua.test_all(HiddenMarkovModels; deps_compat=(check_extras=false,)) end @@ -36,6 +38,12 @@ using Test include("autodiff.jl") end + for file in readdir(examples_path) + @testset "Example - $file" begin + include(joinpath(examples_path, file)) + end + end + @testset "Doctests" begin Documenter.doctest(HiddenMarkovModels) end diff --git a/test/types_allocations.jl b/test/types_allocations.jl index 2ccaffb1..74f67585 100644 --- a/test/types_allocations.jl +++ b/test/types_allocations.jl @@ -38,8 +38,8 @@ function test_type_stability(hmm::AbstractHMM; T::Integer) end @testset "Baum-Welch" begin - @test_opt target_modules = (HMMs,) baum_welch(hmm, obs_seq) - @test_call target_modules = (HMMs,) baum_welch(hmm, obs_seq) + @test_opt target_modules = (HMMs,) baum_welch(hmm, obs_seq; max_iterations=1) + @test_call target_modules = (HMMs,) baum_welch(hmm, obs_seq; max_iterations=1) end end @@ -70,7 +70,7 @@ function test_allocations(hmm::AbstractHMM; T::Integer, nb_seqs::Integer) HMMs.initialize_forward_backward(hmm, obs_seqs[k]) for k in eachindex(obs_seqs) ] bw_storage = HMMs.initialize_baum_welch(hmm, fb_storages, obs_seqs) - logL_evolution = HMMs.initialize_logL_evolution(hmm, obs_seqs; max_iterations=2) + logL_evolution = HMMs.initialize_logL_evolution(hmm, obs_seqs; max_iterations=1) allocs = @allocated HMMs.baum_welch!( hmm, fb_storages, @@ -78,14 +78,13 @@ function test_allocations(hmm::AbstractHMM; T::Integer, nb_seqs::Integer) logL_evolution, obs_seqs; atol=-Inf, - max_iterations=2, + max_iterations=1, loglikelihood_increasing=false, ) @test allocs == 0 end -N, D, T, nb_seqs = 2, 2, 100, 3 -R = Float32 +N, D, T, nb_seqs, R = 3, 2, 100, 5, Float32 ## Distributions