Skip to content

Commit

Permalink
Fix FB output
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Nov 9, 2023
1 parent 0142f6d commit 0681ec2
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ hmm_init = HMM(p, A, d_init);

obs_seq = rand(hmm, T).obs_seq;

γ, ξ, logL = forward_backward(hmm, obs_seq);
γ, logL = forward_backward(hmm, obs_seq);
hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq);

@testset "Sparse" begin
@test eltype(ξ) <: AbstractSparseArray
@test typeof(hmm_est) == typeof(hmm_init)
@test nnz(transition_matrix(hmm_est)) <= nnz(transition_matrix(hmm))
end
Expand All @@ -42,10 +41,9 @@ hmm = HMM(p, A, d);
hmm_init = HMM(p, A, d_init);
obs_seq = rand(hmm, T).obs_seq;

γ, ξ, logL = forward_backward(hmm, obs_seq);
γ, logL = forward_backward(hmm, obs_seq);
hmm_est, logL_evolution = @inferred baum_welch(hmm_init, obs_seq);

@testset "Static" begin
@test eltype(ξ) <: StaticArray
@test typeof(hmm_est) == typeof(hmm_init)
end

0 comments on commit 0681ec2

Please sign in to comment.