Skip to content

Commit

Permalink
Allow heterogeneous distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 29, 2024
1 parent 67934b1 commit b0603d1
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.2"
version = "0.5.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
1 change: 0 additions & 1 deletion examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,3 @@ control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_identical_hmmbase(rng, transpose_hmm(hmm), 100; hmm_guess=transpose_hmm(hmm_guess)) #src
21 changes: 12 additions & 9 deletions ext/HiddenMarkovModelsDistributionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,30 @@ using Distributions:
fit

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D}, i::Integer, x_nums::AbstractVector, w::AbstractVector
) where {D<:UnivariateDistribution}
return dists[i] = fit(D, x_nums, w)
dists::AbstractVector{<:UnivariateDistribution},
i::Integer,
x_nums::AbstractVector,
w::AbstractVector,
)
return dists[i] = fit(typeof(dists[i]), x_nums, w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D},
dists::AbstractVector{<:MultivariateDistribution},
i::Integer,
x_vecs::AbstractVector{<:AbstractVector},
w::AbstractVector,
) where {D<:MultivariateDistribution}
return dists[i] = fit(D, reduce(hcat, x_vecs), w)
)
return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D},
dists::AbstractVector{<:MatrixDistribution},
i::Integer,
x_mats::AbstractVector{<:AbstractMatrix},
w::AbstractVector,
) where {D<:MatrixDistribution}
return dists[i] = fit(D, reduce(dcat, x_mats), w)
)
return dists[i] = fit(typeof(dists[i]), reduce(dcat, x_mats), w)
end

dcat(M1, M2) = cat(M1, M2; dims=3)
Expand Down
26 changes: 26 additions & 0 deletions test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,29 @@ end
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end

@testset "Normal transposed" begin # issue 99
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

hmm = transpose_hmm(HMM(init, trans, dists))
hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess))

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end

@testset "Normal and Laplace" begin # issue 101
dists = [Normal(μ[1][1]), Laplace(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Laplace(μ_guess[2][1])]

hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end

0 comments on commit b0603d1

Please sign in to comment.