Skip to content

Commit

Permalink
Allow different type for elementwise log
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 27, 2024
1 parent 6bfb23a commit af80816
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,5 +259,6 @@ hcat(initialization(hmm_est_concat), initialization(hmm))

control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_identical_hmmbase(rng, transpose_hmm(hmm), 100; transpose_hmm(hmm_guess)) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
2 changes: 2 additions & 0 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

export transpose_hmm
export test_equal_hmms, test_coherent_algorithms
export test_identical_hmmbase
export test_allocations
export test_type_stability

include("utils.jl")
include("coherence.jl")
include("allocations.jl")
include("hmmbase.jl")
Expand Down
8 changes: 8 additions & 0 deletions libs/HMMTest/src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function transpose_hmm(hmm::HMM)
init = initial_distribution(hmm)
trans = transition_matrix(hmm)
dists = obs_distributions(hmm)
trans_transpose = transpose(convert(typeof(trans), transpose(trans)))
@assert trans_transpose == trans
return HMM(init, trans_transpose, dists)
end
20 changes: 15 additions & 5 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@ Basic implementation of an HMM.
$(TYPEDFIELDS)
"""
struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM
struct HMM{
V<:AbstractVector,
M<:AbstractMatrix,
VD<:AbstractVector,
Vl<:AbstractVector,
Ml<:AbstractMatrix,
} <: AbstractHMM
"initial state probabilities"
init::V
"state transition probabilities"
trans::M
"observation distributions"
dists::VD
"logarithms of initial state probabilities"
loginit::V
loginit::Vl
"logarithms of state transition probabilities"
logtrans::M
logtrans::Ml

function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector)
hmm = new{typeof(init),typeof(trans),typeof(dists)}(
init, trans, dists, elementwise_log(init), elementwise_log(trans)
log_init = elementwise_log(init)
log_trans = elementwise_log(trans)
hmm = new{
typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans)
}(
init, trans, dists, log_init, log_trans
)
@argcheck valid_hmm(hmm)
return hmm
Expand Down

0 comments on commit af80816

Please sign in to comment.