Skip to content

Commit

Permalink
Split better
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 29, 2024
1 parent b9d79a4 commit 27b6f3e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri
Let us build several sequences of variable lengths.
=#

control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:300];
control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];

obs_seq = reduce(vcat, obs_seqs)
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ obs_seq'
We now generate several sequences of variable lengths, for inference and learning tasks.
=#

control_seqs = [1:rand(rng, 300:500) for k in 1:1000]
control_seqs = [1:rand(rng, 100:200) for k in 1:1000]
obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)];

obs_seq = reduce(vcat, obs_seqs)
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function test_coherent_algorithms(

if !isnothing(hmm_guess)
hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
@test all(>=(-0.5e-5), diff(logL_evolution))
@test all(>=(0), diff(logL_evolution))
test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true)
test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init)
end
Expand Down
65 changes: 42 additions & 23 deletions test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ seq_ends = cumsum(length.(control_seqs));
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "DiagNormal" begin
Expand All @@ -57,9 +60,12 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "LightCategorical" begin
Expand All @@ -69,9 +75,11 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "LightDiagNormal" begin
Expand All @@ -81,9 +89,11 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal (sparse)" begin
Expand All @@ -93,10 +103,13 @@ end
hmm = HMM(init, sparse(trans), dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal transposed" begin # issue 99
Expand All @@ -106,10 +119,13 @@ end
hmm = transpose_hmm(HMM(init, trans, dists))
hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess))

TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal and Exponential" begin # issue 101
Expand All @@ -119,6 +135,9 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
end
end

0 comments on commit 27b6f3e

Please sign in to comment.