From d68fcf5cb6c1d504b3f1b58baf7505f55aad7898 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 30 Sep 2024 09:58:45 +0200 Subject: [PATCH 1/2] Improve code coverage --- examples/basics.jl | 2 ++ examples/temporal.jl | 1 + ext/HiddenMarkovModelsDistributionsExt.jl | 5 ++++ src/inference/forward.jl | 1 - src/inference/viterbi.jl | 2 -- src/types/hmm.jl | 4 ---- test/distributions.jl | 28 ++++++++++++++++++++++- 7 files changed, 35 insertions(+), 8 deletions(-) diff --git a/examples/basics.jl b/examples/basics.jl index fbb48bcd..d9d4357b 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -257,6 +257,8 @@ hcat(initialization(hmm_est_concat), initialization(hmm)) # ## Tests #src +@test startswith(string(hmm), "Hidden") #src +@test length.(values(rand(hmm, T))) == (T, T); #src control_seq = fill(nothing, last(seq_ends)); #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/temporal.jl b/examples/temporal.jl index 1cad38f6..3343ac80 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -184,5 +184,6 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src @test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src +@test length.(values(rand(hmm, control_seq))) == (10, 10); #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/ext/HiddenMarkovModelsDistributionsExt.jl b/ext/HiddenMarkovModelsDistributionsExt.jl index 7bb16d01..8581ce84 100644 --- a/ext/HiddenMarkovModelsDistributionsExt.jl +++ b/ext/HiddenMarkovModelsDistributionsExt.jl @@ -27,6 +27,10 @@ function HiddenMarkovModels.fit_in_sequence!( return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w) end +#= + +# Matrix distribution fitting not supported by Distributions.jl at the moment + function HiddenMarkovModels.fit_in_sequence!( dists::AbstractVector{<:MatrixDistribution}, i::Integer, @@ -37,5 +41,6 @@ function HiddenMarkovModels.fit_in_sequence!( end dcat(M1, M2) = cat(M1, M2; dims=3) +=# end diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 290248fa..9b2ee492 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -39,7 +39,6 @@ struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} Bβ::Matrix{R} end -Base.eltype(::ForwardStorage{R}) where {R} = R Base.eltype(::ForwardBackwardStorage{R}) where {R} = R const ForwardOrForwardBackwardStorage{R} = Union{ diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index c212c9bb..f1758097 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -17,8 +17,6 @@ struct ViterbiStorage{R} ψ::Matrix{Int} end -Base.eltype(::ViterbiStorage{R}) where {R} = R - """ $(SIGNATURES) """ diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 6a71135a..60d7bf12 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -38,10 +38,6 @@ struct HMM{ end end -function Base.copy(hmm::HMM) - return HMM(copy(hmm.init), copy(hmm.trans), copy(hmm.dists)) -end - function Base.show(io::IO, hmm::HMM) return print( io, diff --git a/test/distributions.jl b/test/distributions.jl index dfbc8822..3c660c0d 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,5 +1,6 @@ using Distributions -using HiddenMarkovModels: LightCategorical, LightDiagNormal, logdensityof, rand_prob_vec +using HiddenMarkovModels: + LightCategorical, LightDiagNormal, logdensityof, rand_prob_vec, rand_trans_mat using LinearAlgebra using Statistics using StatsAPI: fit! @@ -8,6 +9,29 @@ using Test rng = StableRNG(63) +function test_randprobvec(p) + @test all(>=(0), p) + @test sum(p) ≈ 1 +end + +function test_randtransmat(A) + foreach(eachrow(A)) do p + test_randprobvec(p) + end +end + +@testset "Rand prob" begin + n = 10 + test_randprobvec(rand_prob_vec(n)) + test_randprobvec(rand_prob_vec(rng, n)) + test_randprobvec(rand_prob_vec(Float32, n)) + test_randprobvec(rand_prob_vec(rng, Float32, n)) + test_randtransmat(rand_trans_mat(n)) + test_randtransmat(rand_trans_mat(rng, n)) + test_randtransmat(rand_trans_mat(Float32, n)) + test_randtransmat(rand_trans_mat(rng, Float32, n)) +end + function test_fit_allocs(dist, x, w) dist_copy = deepcopy(dist) allocs = @allocated fit!(dist_copy, x, w) @@ -17,6 +41,7 @@ end @testset "LightCategorical" begin p = rand_prob_vec(rng, 10) dist = LightCategorical(p) + @test startswith(string(dist), "LightCategorical") x = [(@inferred rand(rng, dist)) for _ in 1:100_000] # Simulation val_count = zeros(Int, length(p)) @@ -38,6 +63,7 @@ end μ = randn(rng, 10) σ = rand(rng, 10) dist = LightDiagNormal(μ, σ) + @test startswith(string(dist), "LightDiagNormal") x = [(@inferred rand(rng, dist)) for _ in 1:100_000] # Simulation @test mean(x) ≈ μ atol = 2e-2 From 1cfbfadaad42dfec506a2afbd28892200ba47e1f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:08:58 +0200 Subject: [PATCH 2/2] Fix --- examples/temporal.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/temporal.jl b/examples/temporal.jl index 3343ac80..1cad38f6 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -184,6 +184,5 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src @test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src -@test length.(values(rand(hmm, control_seq))) == (10, 10); #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src