From b0603d1acc492e64bcb1304a65ee16a32e195c9a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 29 May 2024 07:10:53 +0200 Subject: [PATCH] Allow heterogeneous distributions --- Project.toml | 2 +- examples/basics.jl | 1 - ext/HiddenMarkovModelsDistributionsExt.jl | 21 ++++++++++-------- test/correctness.jl | 26 +++++++++++++++++++++++ 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 8c7d3e68..525b5d01 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.5.2" +version = "0.5.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/examples/basics.jl b/examples/basics.jl index 17f1eb18..3594bdab 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -261,4 +261,3 @@ control_seq = fill(nothing, last(seq_ends)); #src test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src -test_identical_hmmbase(rng, transpose_hmm(hmm), 100; hmm_guess=transpose_hmm(hmm_guess)) #src diff --git a/ext/HiddenMarkovModelsDistributionsExt.jl b/ext/HiddenMarkovModelsDistributionsExt.jl index d0ebe78e..7bb16d01 100644 --- a/ext/HiddenMarkovModelsDistributionsExt.jl +++ b/ext/HiddenMarkovModelsDistributionsExt.jl @@ -10,27 +10,30 @@ using Distributions: fit function HiddenMarkovModels.fit_in_sequence!( - dists::AbstractVector{D}, i::Integer, x_nums::AbstractVector, w::AbstractVector -) where {D<:UnivariateDistribution} - return dists[i] = fit(D, x_nums, w) + dists::AbstractVector{<:UnivariateDistribution}, + i::Integer, + x_nums::AbstractVector, + w::AbstractVector, +) + return dists[i] = fit(typeof(dists[i]), x_nums, w) end function HiddenMarkovModels.fit_in_sequence!( - dists::AbstractVector{D}, + dists::AbstractVector{<:MultivariateDistribution}, i::Integer, x_vecs::AbstractVector{<:AbstractVector}, w::AbstractVector, -) where {D<:MultivariateDistribution} - return dists[i] = fit(D, reduce(hcat, x_vecs), w) +) + return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w) end function HiddenMarkovModels.fit_in_sequence!( - dists::AbstractVector{D}, + dists::AbstractVector{<:MatrixDistribution}, i::Integer, x_mats::AbstractVector{<:AbstractMatrix}, w::AbstractVector, -) where {D<:MatrixDistribution} - return dists[i] = fit(D, reduce(dcat, x_mats), w) +) + return dists[i] = fit(typeof(dists[i]), reduce(dcat, x_mats), w) end dcat(M1, M2) = cat(M1, M2; dims=3) diff --git a/test/correctness.jl b/test/correctness.jl index 3ebe2b5d..5e8b9901 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -98,3 +98,29 @@ end test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end + +@testset "Normal transposed" begin # issue 99 + dists = [Normal(μ[1][1]), Normal(μ[2][1])] + dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] + + hmm = transpose_hmm(HMM(init, trans, dists)) + hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess)) + + test_identical_hmmbase(rng, hmm, T; hmm_guess) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) +end + +@testset "Normal and Laplace" begin # issue 101 + dists = [Normal(μ[1][1]), Laplace(μ[2][1])] + dists_guess = [Normal(μ_guess[1][1]), Laplace(μ_guess[2][1])] + + hmm = HMM(init, trans, dists) + hmm_guess = HMM(init_guess, trans_guess, dists_guess) + + test_identical_hmmbase(rng, hmm, T; hmm_guess) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) +end