Skip to content

Commit

Permalink
Allow different types for elementwise log (gdalle#100)
Browse files Browse the repository at this point in the history
* Allow different type for elementwise log

* Bump

* Fix kwarg

* Fix init

* Fix test and CI
  • Loading branch information
gdalle authored May 27, 2024
1 parent 6bfb23a commit 67934b1
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 15 deletions.
15 changes: 7 additions & 8 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,23 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
name: Julia ${{ matrix.version }} - ${{ github.event_name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
arch: x64
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.1"
version = "0.5.2"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
3 changes: 2 additions & 1 deletion examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ This is important to keep in mind when testing new models.
In many applications, we have access to various observation sequences of different lengths.
=#

nb_seqs = 100
nb_seqs = 300
long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs];
typeof(long_obs_seqs)

Expand Down Expand Up @@ -261,3 +261,4 @@ control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_identical_hmmbase(rng, transpose_hmm(hmm), 100; hmm_guess=transpose_hmm(hmm_guess)) #src
2 changes: 2 additions & 0 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

export transpose_hmm
export test_equal_hmms, test_coherent_algorithms
export test_identical_hmmbase
export test_allocations
export test_type_stability

include("utils.jl")
include("coherence.jl")
include("allocations.jl")
include("hmmbase.jl")
Expand Down
8 changes: 8 additions & 0 deletions libs/HMMTest/src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function transpose_hmm(hmm::HMM)
init = initialization(hmm)
trans = transition_matrix(hmm)
dists = obs_distributions(hmm)
trans_transpose = transpose(convert(typeof(trans), transpose(trans)))
@assert trans_transpose == trans
return HMM(init, trans_transpose, dists)
end
20 changes: 15 additions & 5 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@ Basic implementation of an HMM.
$(TYPEDFIELDS)
"""
struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM
struct HMM{
V<:AbstractVector,
M<:AbstractMatrix,
VD<:AbstractVector,
Vl<:AbstractVector,
Ml<:AbstractMatrix,
} <: AbstractHMM
"initial state probabilities"
init::V
"state transition probabilities"
trans::M
"observation distributions"
dists::VD
"logarithms of initial state probabilities"
loginit::V
loginit::Vl
"logarithms of state transition probabilities"
logtrans::M
logtrans::Ml

function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector)
hmm = new{typeof(init),typeof(trans),typeof(dists)}(
init, trans, dists, elementwise_log(init), elementwise_log(trans)
log_init = elementwise_log(init)
log_trans = elementwise_log(trans)
hmm = new{
typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans)
}(
init, trans, dists, log_init, log_trans
)
@argcheck valid_hmm(hmm)
return hmm
Expand Down

0 comments on commit 67934b1

Please sign in to comment.