From 4f10b10f76e8325e8cfa9028defbb0ce64cfe193 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 11 Dec 2023 20:44:37 +0100 Subject: [PATCH] Fix tests 1.6 --- Project.toml | 22 +++++++++------------- examples/basics.jl | 2 +- ext/HiddenMarkovModelsHMMBaseExt.jl | 20 -------------------- ext/HiddenMarkovModelsSparseArraysExt.jl | 2 +- libs/HMMTest/src/hmmbase.jl | 16 +++++++++++++--- src/HiddenMarkovModels.jl | 18 ++++-------------- test/correctness.jl | 6 ++++-- test/distributions.jl | 3 ++- 8 files changed, 34 insertions(+), 55 deletions(-) delete mode 100644 ext/HiddenMarkovModelsHMMBaseExt.jl diff --git a/Project.toml b/Project.toml index 85591c30..ddebecce 100644 --- a/Project.toml +++ b/Project.toml @@ -6,24 +6,15 @@ version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" -StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" - -[weakdeps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[extensions] -HiddenMarkovModelsDistributionsExt = "Distributions" -HiddenMarkovModelsHMMBaseExt = "HMMBase" -HiddenMarkovModelsSparseArraysExt = "SparseArrays" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" [compat] ChainRulesCore = "1.16" @@ -31,13 +22,18 @@ DensityInterface = "0.4" Distributions = "0.25" DocStringExtensions = "0.9" FillArrays = "1" -HMMBase = "1" LinearAlgebra = "1" PrecompileTools = "1.1" Random = "1" -Requires = "1.3" SimpleUnPack = "1.1" SparseArrays = "1" StatsAPI = "1.6" julia = "1.6" +[extensions] +HiddenMarkovModelsDistributionsExt = "Distributions" +HiddenMarkovModelsSparseArraysExt = "SparseArrays" + +[weakdeps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/examples/basics.jl b/examples/basics.jl index a68d6b33..112a1d86 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -215,7 +215,7 @@ map(dist -> dist.μ, hcat(obs_distributions(hmm_est_concat), obs_distributions(h # ## Tests #src -control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:200]; #src +control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:500]; #src control_seq = reduce(vcat, control_seqs); #src seq_ends = cumsum(length.(control_seqs)); #src diff --git a/ext/HiddenMarkovModelsHMMBaseExt.jl b/ext/HiddenMarkovModelsHMMBaseExt.jl deleted file mode 100644 index fe9e66c8..00000000 --- a/ext/HiddenMarkovModelsHMMBaseExt.jl +++ /dev/null @@ -1,20 +0,0 @@ -module HiddenMarkovModelsHMMBaseExt - -using HiddenMarkovModels: HiddenMarkovModels -using HMMBase: HMMBase - -function HiddenMarkovModels.HMM(hmm_base::HMMBase.HMM) - init = deepcopy(hmm_base.a) - trans = deepcopy(hmm_base.A) - dists = deepcopy(hmm_base.B) - return HiddenMarkovModels.HMM(init, trans, dists) -end - -function HMMBase.HMM(hmm::HiddenMarkovModels.HMM) - a = deepcopy(hmm.init) - A = deepcopy(hmm.trans) - B = deepcopy(hmm.dists) - return HMMBase.HMM(a, A, B) -end - -end diff --git a/ext/HiddenMarkovModelsSparseArraysExt.jl b/ext/HiddenMarkovModelsSparseArraysExt.jl index 58f3fc65..0c9890d3 100644 --- a/ext/HiddenMarkovModelsSparseArraysExt.jl +++ b/ext/HiddenMarkovModelsSparseArraysExt.jl @@ -1,6 +1,6 @@ module HiddenMarkovModelsSparseArraysExt -using HiddenMarkovModels +using HiddenMarkovModels: HiddenMarkovModels using SparseArrays HiddenMarkovModels.mynonzeros(x::AbstractSparseArray) = nonzeros(x) diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index bebc7d53..76ba0731 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -13,8 +13,7 @@ function test_identical_hmmbase( obs_seq = vcat(sim.obs_seq, sim.obs_seq) seq_ends = [length(sim.obs_seq), 2 * length(sim.obs_seq)] - hmm_base = HMMBase.HMM(hmm) - hmm_guess_base = HMMBase.HMM(hmm_guess) + hmm_base = HMMBase.HMM(deepcopy(hmm.init), deepcopy(hmm.trans), deepcopy(hmm.dists)) logL_base = HMMBase.forward(hmm_base, obs_mat)[2] logL = logdensityof(hmm, obs_seq; seq_ends) @@ -35,6 +34,12 @@ function test_identical_hmmbase( @test isapprox(γ[:, 1:T], γ_base') && isapprox(γ[:, (T + 1):(2T)], γ_base') if !isnothing(hmm_guess) + hmm_guess_base = HMMBase.HMM( + deepcopy(hmm_guess.init), + deepcopy(hmm_guess.trans), + deepcopy(hmm_guess.dists), + ) + hmm_est_base, hist_base = HMMBase.fit_mle( hmm_guess_base, obs_mat; maxiter=10, tol=-Inf ) @@ -45,7 +50,12 @@ function test_identical_hmmbase( @test isapprox( logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) - test_equal_hmms(hmm_est, HMM(hmm_est_base); atol, init=true) + test_equal_hmms( + hmm_est, + HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B); + atol, + init=true, + ) end end end diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index ae49aa3d..0e481731 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -14,7 +14,6 @@ using FillArrays: Fill using LinearAlgebra: dot, ldiv!, lmul!, mul! using PrecompileTools: @compile_workload using Random: Random, AbstractRNG, default_rng -using Requires: @require using SimpleUnPack: @unpack using StatsAPI: StatsAPI, fit, fit! @@ -48,20 +47,11 @@ include("inference/chainrules.jl") include("types/hmm.jl") +include("precompile.jl") + if !isdefined(Base, :get_extension) - function __init__() - @require Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" begin - include("../ext/HiddenMarkovModelsDistributionsExt.jl") - end - @require HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" begin - include("../ext/HiddenMarkovModelsHMMBaseExt.jl") - end - @require SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" begin - include("../ext/HiddenMarkovModelsSparseArraysExt.jl") - end - end + include("../ext/HiddenMarkovModelsDistributionsExt.jl") + include("../ext/HiddenMarkovModelsSparseArraysExt.jl") end -include("precompile.jl") - end diff --git a/test/correctness.jl b/test/correctness.jl index bed70147..8860d698 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -52,8 +52,10 @@ seq_ends = cumsum(length.(control_seqs)) end @testset "DiagNormal" begin - dists = [MvNormal(μ[1], I), MvNormal(μ[2], I)] - dists_guess = [MvNormal(μ_guess[1], I), MvNormal(μ_guess[2], I)] + dists = [MvNormal(μ[1], Diagonal(abs2.(σ))), MvNormal(μ[2], Diagonal(abs2.(σ)))] + dists_guess = [ + MvNormal(μ_guess[1], Diagonal(abs2.(σ))), MvNormal(μ_guess[2], Diagonal(abs2.(σ))) + ] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) diff --git a/test/distributions.jl b/test/distributions.jl index 56a34fe1..dd161218 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,5 +1,6 @@ using Distributions using HiddenMarkovModels: LightCategorical, LightDiagNormal, logdensityof, rand_prob_vec +using LinearAlgebra using Statistics using Test @@ -46,5 +47,5 @@ end test_fit_allocs(dist, x, w) # Logdensity @test logdensityof(dist, x[1]) ≈ - logdensityof(MvNormal(μ, σ), x[1]) + length(x[1]) * log(sqrt(2π)) + logdensityof(MvNormal(μ, Diagonal(abs2.(σ))), x[1]) + length(x[1]) * log(sqrt(2π)) end