From af80816d9c9cd5fcb5c7a0f68e49389f59bcd4e7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 27 May 2024 08:55:22 +0200 Subject: [PATCH] Allow different type for elementwise log --- examples/basics.jl | 1 + libs/HMMTest/src/HMMTest.jl | 2 ++ libs/HMMTest/src/utils.jl | 8 ++++++++ src/types/hmm.jl | 20 +++++++++++++++----- 4 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 libs/HMMTest/src/utils.jl diff --git a/examples/basics.jl b/examples/basics.jl index 13f1b45f..4e9a8768 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -259,5 +259,6 @@ hcat(initialization(hmm_est_concat), initialization(hmm)) control_seq = fill(nothing, last(seq_ends)); #src test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src +test_identical_hmmbase(rng, transpose_hmm(hmm), 100; transpose_hmm(hmm_guess)) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/libs/HMMTest/src/HMMTest.jl b/libs/HMMTest/src/HMMTest.jl index fb1ebdb1..26951ceb 100644 --- a/libs/HMMTest/src/HMMTest.jl +++ b/libs/HMMTest/src/HMMTest.jl @@ -9,11 +9,13 @@ using Random: AbstractRNG using Statistics: mean using Test: @test, @testset, @test_broken +export transpose_hmm export test_equal_hmms, test_coherent_algorithms export test_identical_hmmbase export test_allocations export test_type_stability +include("utils.jl") include("coherence.jl") include("allocations.jl") include("hmmbase.jl") diff --git a/libs/HMMTest/src/utils.jl b/libs/HMMTest/src/utils.jl new file mode 100644 index 00000000..bfe7a5f6 --- /dev/null +++ b/libs/HMMTest/src/utils.jl @@ -0,0 +1,8 @@ +function transpose_hmm(hmm::HMM) + init = initial_distribution(hmm) + trans = transition_matrix(hmm) + dists = obs_distributions(hmm) + trans_transpose = transpose(convert(typeof(trans), transpose(trans))) + @assert trans_transpose == trans + return HMM(init, trans_transpose, dists) +end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index a9310798..9de3cf2e 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -7,7 +7,13 @@ Basic implementation of an HMM. $(TYPEDFIELDS) """ -struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM +struct HMM{ + V<:AbstractVector, + M<:AbstractMatrix, + VD<:AbstractVector, + Vl<:AbstractVector, + Ml<:AbstractMatrix, +} <: AbstractHMM "initial state probabilities" init::V "state transition probabilities" @@ -15,13 +21,17 @@ struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHM "observation distributions" dists::VD "logarithms of initial state probabilities" - loginit::V + loginit::Vl "logarithms of state transition probabilities" - logtrans::M + logtrans::Ml function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector) - hmm = new{typeof(init),typeof(trans),typeof(dists)}( - init, trans, dists, elementwise_log(init), elementwise_log(trans) + log_init = elementwise_log(init) + log_trans = elementwise_log(trans) + hmm = new{ + typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans) + }( + init, trans, dists, log_init, log_trans ) @argcheck valid_hmm(hmm) return hmm