From 1dc2dfbd7c2a0763ec5a790ed6af17de6e14de66 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 09:19:45 +0200 Subject: [PATCH 01/14] Put HMMBase in extension of HMMTest --- .github/workflows/test.yml | 15 +++++++++------ libs/HMMTest/Project.toml | 7 ++++++- .../{src/hmmbase.jl => ext/HMMTestHMMBaseExt.jl} | 14 +++++++++++++- libs/HMMTest/src/HMMTest.jl | 3 ++- 4 files changed, 30 insertions(+), 9 deletions(-) rename libs/HMMTest/{src/hmmbase.jl => ext/HMMTestHMMBaseExt.jl} (88%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e4240b49..ace1c011 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,7 +3,7 @@ on: push: branches: - main - tags: ['*'] + tags: ["*"] pull_request: concurrency: # Skip intermediate builds: always. @@ -18,10 +18,13 @@ jobs: fail-fast: false matrix: version: - - '1.9' - - '1' - os: - - ubuntu-latest + - "1.9" + - "1" + test_suite: + - "normal" + - "hmmbase" + env: + JULIA_HMM_TEST_SUITE: ${{ matrix.test_suite }} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -36,4 +39,4 @@ jobs: with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true \ No newline at end of file + fail_ci_if_error: true diff --git a/libs/HMMTest/Project.toml b/libs/HMMTest/Project.toml index 5e576d83..90312674 100644 --- a/libs/HMMTest/Project.toml +++ b/libs/HMMTest/Project.toml @@ -5,9 +5,14 @@ version = "0.1.0" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[weakdeps] +HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" + +[extensions] +HMMTestHMMBaseExt = "HMMBase" \ No newline at end of file diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl similarity index 88% rename from libs/HMMTest/src/hmmbase.jl rename to libs/HMMTest/ext/HMMTestHMMBaseExt.jl index 808e2e0c..e397cf8f 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl @@ -1,5 +1,14 @@ +module HMMTestHMMBaseExt -function test_identical_hmmbase( +using HiddenMarkovModels +import HiddenMarkovModels as HMMs +using HMMBase: HMMBase +using HMMTest: HMMTest +using Random: AbstractRNG +using Statistics: mean +using Test: @test, @testset, @test_broken + +function HMMTest.test_identical_hmmbase( rng::AbstractRNG, hmm::AbstractHMM, T::Integer; @@ -54,3 +63,6 @@ function test_identical_hmmbase( end end end + + +end diff --git a/libs/HMMTest/src/HMMTest.jl b/libs/HMMTest/src/HMMTest.jl index 26951ceb..7a9a07ab 100644 --- a/libs/HMMTest/src/HMMTest.jl +++ b/libs/HMMTest/src/HMMTest.jl @@ -9,6 +9,8 @@ using Random: AbstractRNG using Statistics: mean using Test: @test, @testset, @test_broken +function test_identical_hmmbase end # in extension + export transpose_hmm export test_equal_hmms, test_coherent_algorithms export test_identical_hmmbase @@ -18,7 +20,6 @@ export test_type_stability include("utils.jl") include("coherence.jl") include("allocations.jl") -include("hmmbase.jl") include("jet.jl") end From fcc0ad8662ade5a85b3c98fad160715e2d361fb1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 10:59:27 +0200 Subject: [PATCH 02/14] Split test suite --- .github/workflows/test.yml | 6 +-- libs/HMMTest/ext/HMMTestHMMBaseExt.jl | 1 - test/correctness.jl | 10 ++--- test/runtests.jl | 57 +++++++++++++++------------ 4 files changed, 40 insertions(+), 34 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ace1c011..975ccb9b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,7 +12,7 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Julia ${{ matrix.version }} - ${{ github.event_name }} + name: Julia ${{ matrix.version }} - ${{ matrix.test_suite }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -21,8 +21,8 @@ jobs: - "1.9" - "1" test_suite: - - "normal" - - "hmmbase" + - "Standard" + - "HMMBase" env: JULIA_HMM_TEST_SUITE: ${{ matrix.test_suite }} steps: diff --git a/libs/HMMTest/ext/HMMTestHMMBaseExt.jl b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl index e397cf8f..f3c1c379 100644 --- a/libs/HMMTest/ext/HMMTestHMMBaseExt.jl +++ b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl @@ -64,5 +64,4 @@ function HMMTest.test_identical_hmmbase( end end - end diff --git a/test/correctness.jl b/test/correctness.jl index cb139823..7b2706d1 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -42,7 +42,7 @@ seq_ends = cumsum(length.(control_seqs)); hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) + TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) @@ -57,7 +57,7 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) + TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) end @@ -93,7 +93,7 @@ end hmm = HMM(init, sparse(trans), dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) + TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) @@ -106,7 +106,7 @@ end hmm = transpose_hmm(HMM(init, trans, dists)) hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess)) - test_identical_hmmbase(rng, hmm, T; hmm_guess) + TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) @@ -119,6 +119,6 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) + TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) end diff --git a/test/runtests.jl b/test/runtests.jl index 96b2f0fa..662cfc9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,41 +6,48 @@ using JuliaFormatter: JuliaFormatter using Pkg using Test +TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") +if TEST_SUITE == "HMMBase" + Pkg.add("HMMBase") +end + Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) @testset verbose = true "HiddenMarkovModels.jl" begin - @testset "Code formatting" begin - @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) - end + if TEST_SUITE == "Standard" + @testset "Code formatting" begin + @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) + end - @testset "Code quality" begin - Aqua.test_all( - HiddenMarkovModels; ambiguities=false, deps_compat=(check_extras=false,) - ) - end + @testset "Code quality" begin + Aqua.test_all( + HiddenMarkovModels; ambiguities=false, deps_compat=(check_extras=false,) + ) + end - @testset "Code linting" begin - using Distributions - using Zygote - JET.test_package(HiddenMarkovModels; target_defined_modules=true) - end + @testset "Code linting" begin + using Distributions + using Zygote + JET.test_package(HiddenMarkovModels; target_defined_modules=true) + end - @testset "Distributions" begin - include("distributions.jl") - end + @testset "Distributions" begin + include("distributions.jl") + end - @testset "Correctness" begin - include("correctness.jl") - end + examples_path = joinpath(dirname(@__DIR__), "examples") + for file in readdir(examples_path) + @testset "Example - $file" begin + include(joinpath(examples_path, file)) + end + end - examples_path = joinpath(dirname(@__DIR__), "examples") - for file in readdir(examples_path) - @testset "Example - $file" begin - include(joinpath(examples_path, file)) + @testset "Doctests" begin + Documenter.doctest(HiddenMarkovModels) end end - @testset "Doctests" begin - Documenter.doctest(HiddenMarkovModels) + @testset "Correctness - $TEST_SUITE" begin + include("correctness.jl") end end From f1a9d1cdd3469db15980d642d09a3ca033693e6b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 11:01:46 +0200 Subject: [PATCH 03/14] Using --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index a0557db5..e41cf6ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using Test TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") if TEST_SUITE == "HMMBase" Pkg.add("HMMBase") + using HMMBase: HMMBase end Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) From 5c25d714f81d2690686f42c484cf46f61f31a027 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 11:07:17 +0200 Subject: [PATCH 04/14] Fix --- libs/HMMTest/src/HMMTest.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/HMMTest/src/HMMTest.jl b/libs/HMMTest/src/HMMTest.jl index 7a9a07ab..aaeedb40 100644 --- a/libs/HMMTest/src/HMMTest.jl +++ b/libs/HMMTest/src/HMMTest.jl @@ -3,7 +3,6 @@ module HMMTest using BenchmarkTools: @ballocated using HiddenMarkovModels import HiddenMarkovModels as HMMs -using HMMBase: HMMBase using JET: @test_opt, @test_call using Random: AbstractRNG using Statistics: mean From 5c1d5c6599a14dbde5688f3ac311373da093a837 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 11:34:21 +0200 Subject: [PATCH 05/14] Fixes --- examples/basics.jl | 3 +-- examples/types.jl | 3 +-- libs/HMMTest/ext/HMMTestHMMBaseExt.jl | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/basics.jl b/examples/basics.jl index 3594bdab..fbb48bcd 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -189,7 +189,7 @@ This is important to keep in mind when testing new models. In many applications, we have access to various observation sequences of different lengths. =# -nb_seqs = 300 +nb_seqs = 1000 long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs]; typeof(long_obs_seqs) @@ -258,6 +258,5 @@ hcat(initialization(hmm_est_concat), initialization(hmm)) # ## Tests #src control_seq = fill(nothing, last(seq_ends)); #src -test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/types.jl b/examples/types.jl index 0945be63..628155c7 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -156,9 +156,8 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St @test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src -seq_ends = cumsum(rand(rng, 100:200, 100)); #src +seq_ends = cumsum(rand(rng, 100:200, 1000)); #src control_seq = fill(nothing, last(seq_ends)); #src -test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src diff --git a/libs/HMMTest/ext/HMMTestHMMBaseExt.jl b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl index f3c1c379..b13e7f64 100644 --- a/libs/HMMTest/ext/HMMTestHMMBaseExt.jl +++ b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl @@ -3,7 +3,7 @@ module HMMTestHMMBaseExt using HiddenMarkovModels import HiddenMarkovModels as HMMs using HMMBase: HMMBase -using HMMTest: HMMTest +using HMMTest using Random: AbstractRNG using Statistics: mean using Test: @test, @testset, @test_broken From 40e3a758b85cf488bd25eec5d282536df49ff01c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 11:54:24 +0200 Subject: [PATCH 06/14] Increase test case sizes --- examples/controlled.jl | 4 ++-- examples/temporal.jl | 4 ++-- examples/types.jl | 2 +- test/correctness.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/controlled.jl b/examples/controlled.jl index 1f2451b9..764d3939 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri Let us build several sequences of variable lengths. =# -control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:100]; +control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:300]; obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs]; obs_seq = reduce(vcat, obs_seqs) @@ -151,5 +151,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2]) @test hmm_est.dist_coeffs[1] ≈ hmm.dist_coeffs[1] atol = 0.05 #src @test hmm_est.dist_coeffs[2] ≈ hmm.dist_coeffs[2] atol = 0.05 #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/temporal.jl b/examples/temporal.jl index 9c9549f4..2c59277b 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -76,7 +76,7 @@ obs_seq' We now generate several sequences of variable lengths, for inference and learning tasks. =# -control_seqs = [1:rand(rng, 100:200) for k in 1:1000] +control_seqs = [1:rand(rng, 300:500) for k in 1:1000] obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)]; obs_seq = reduce(vcat, obs_seqs) @@ -184,5 +184,5 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src @test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.09, init=false) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/types.jl b/examples/types.jl index 628155c7..32901c46 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -158,7 +158,7 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St seq_ends = cumsum(rand(rng, 100:200, 1000)); #src control_seq = fill(nothing, last(seq_ends)); #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/test/correctness.jl b/test/correctness.jl index 7b2706d1..d495f460 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -13,7 +13,7 @@ rng = StableRNG(63) ## Settings -T, K = 50, 200 +T, K = 100, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] From 290035fbf3df09a72330f1bfcccba28b5a4c406a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:08:02 +0200 Subject: [PATCH 07/14] Fixes --- examples/autodiff.jl | 1 + test/correctness.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index a1d96d10..fe0e6732 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -11,6 +11,7 @@ using Enzyme: Enzyme using ForwardDiff: ForwardDiff using HiddenMarkovModels import HiddenMarkovModels as HMMs +using HMMTest #src using LinearAlgebra using Random: Random, AbstractRNG using StableRNGs diff --git a/test/correctness.jl b/test/correctness.jl index d495f460..6195e5ee 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -13,7 +13,7 @@ rng = StableRNG(63) ## Settings -T, K = 100, 200 +T, K = 100, 500 init = [0.4, 0.6] init_guess = [0.5, 0.5] From b9d79a4f4dc6f41aca551643b2af22c9ea01eefe Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 18:24:06 +0200 Subject: [PATCH 08/14] Fix --- examples/temporal.jl | 2 +- libs/HMMTest/src/coherence.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/temporal.jl b/examples/temporal.jl index 2c59277b..5beec75e 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -183,6 +183,6 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src -@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src +@test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 2a39e336..0deb9fef 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -88,7 +88,7 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) - @test all(>=(0), diff(logL_evolution)) + @test all(>=(-0.5e-5), diff(logL_evolution)) test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) end From 27b6f3e47a8b71f23819a316e40cd94546d536e2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 18:54:22 +0200 Subject: [PATCH 09/14] Split better --- examples/controlled.jl | 2 +- examples/temporal.jl | 2 +- libs/HMMTest/src/coherence.jl | 2 +- test/correctness.jl | 65 ++++++++++++++++++++++------------- 4 files changed, 45 insertions(+), 26 deletions(-) diff --git a/examples/controlled.jl b/examples/controlled.jl index 764d3939..28ef8ebc 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri Let us build several sequences of variable lengths. =# -control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:300]; +control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000]; obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs]; obs_seq = reduce(vcat, obs_seqs) diff --git a/examples/temporal.jl b/examples/temporal.jl index 5beec75e..80216dda 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -76,7 +76,7 @@ obs_seq' We now generate several sequences of variable lengths, for inference and learning tasks. =# -control_seqs = [1:rand(rng, 300:500) for k in 1:1000] +control_seqs = [1:rand(rng, 100:200) for k in 1:1000] obs_seqs = [rand(rng, hmm, control_seqs[k]).obs_seq for k in eachindex(control_seqs)]; obs_seq = reduce(vcat, obs_seqs) diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 0deb9fef..2a39e336 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -88,7 +88,7 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) - @test all(>=(-0.5e-5), diff(logL_evolution)) + @test all(>=(0), diff(logL_evolution)) test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) end diff --git a/test/correctness.jl b/test/correctness.jl index 6195e5ee..f0bdc651 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -42,10 +42,13 @@ seq_ends = cumsum(length.(control_seqs)); hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "DiagNormal" begin @@ -57,9 +60,12 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "LightCategorical" begin @@ -69,9 +75,11 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE != "HMMBase" + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "LightDiagNormal" begin @@ -81,9 +89,11 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE != "HMMBase" + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "Normal (sparse)" begin @@ -93,10 +103,13 @@ end hmm = HMM(init, sparse(trans), dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "Normal transposed" begin # issue 99 @@ -106,10 +119,13 @@ end hmm = transpose_hmm(HMM(init, trans, dists)) hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess)) - TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end @testset "Normal and Exponential" begin # issue 101 @@ -119,6 +135,9 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - TEST_SUITE == "HMMBase" && test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + end end From 1c6ee9a2662938140e984995aa4e635ae8449e07 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 19:08:46 +0200 Subject: [PATCH 10/14] More fixes --- libs/HMMTest/src/coherence.jl | 1 + src/inference/baum_welch.jl | 2 +- test/correctness.jl | 14 +++++++------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 2a39e336..f0a7fae2 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -88,6 +88,7 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) + @show diff(logL_evolution) @test all(>=(0), diff(logL_evolution)) test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 1bc25665..3bec0ac2 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -73,7 +73,7 @@ function baum_welch( seq_ends, atol, max_iterations, - loglikelihood_increasing=false, + loglikelihood_increasing, ) return hmm, logL_evolution end diff --git a/test/correctness.jl b/test/correctness.jl index f0bdc651..8cbdb7b7 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -35,7 +35,7 @@ seq_ends = cumsum(length.(control_seqs)); ## Uncontrolled -@testset "Normal" begin +@testset verbose = true "Normal" begin dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] @@ -51,7 +51,7 @@ seq_ends = cumsum(length.(control_seqs)); end end -@testset "DiagNormal" begin +@testset verbose = true "DiagNormal" begin dists = [MvNormal(μ[1], Diagonal(abs2.(σ))), MvNormal(μ[2], Diagonal(abs2.(σ)))] dists_guess = [ MvNormal(μ_guess[1], Diagonal(abs2.(σ))), MvNormal(μ_guess[2], Diagonal(abs2.(σ))) @@ -68,7 +68,7 @@ end end end -@testset "LightCategorical" begin +@testset verbose = true "LightCategorical" begin dists = [LightCategorical(p[1]), LightCategorical(p[2])] dists_guess = [LightCategorical(p_guess[1]), LightCategorical(p_guess[2])] @@ -82,7 +82,7 @@ end end end -@testset "LightDiagNormal" begin +@testset verbose = true "LightDiagNormal" begin dists = [LightDiagNormal(μ[1], σ), LightDiagNormal(μ[2], σ)] dists_guess = [LightDiagNormal(μ_guess[1], σ), LightDiagNormal(μ_guess[2], σ)] @@ -96,7 +96,7 @@ end end end -@testset "Normal (sparse)" begin +@testset verbose = true "Normal (sparse)" begin dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] @@ -112,7 +112,7 @@ end end end -@testset "Normal transposed" begin # issue 99 +@testset verbose = true "Normal transposed" begin # issue 99 dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] @@ -128,7 +128,7 @@ end end end -@testset "Normal and Exponential" begin # issue 101 +@testset verbose = true "Normal and Exponential" begin # issue 101 dists = [Normal(μ[1][1]), Exponential(1.0)] dists_guess = [Normal(μ_guess[1][1]), Exponential(0.8)] From b646ee508f366475a932e475021a756a2dba105f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 19:41:03 +0200 Subject: [PATCH 11/14] Rm show --- libs/HMMTest/src/coherence.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index f0a7fae2..2a39e336 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -88,7 +88,6 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) - @show diff(logL_evolution) @test all(>=(0), diff(logL_evolution)) test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) From b7a54ae4ca6c53f7964271ed68d85fdb1732611b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 21:04:28 +0200 Subject: [PATCH 12/14] Fix --- src/inference/baum_welch.jl | 2 +- test/correctness.jl | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 3791071c..279e6fa3 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -5,7 +5,7 @@ function baum_welch_has_converged( logL, logL_prev = logL_evolution[end], logL_evolution[end - 1] progress = logL - logL_prev if loglikelihood_increasing && progress < min(0, -atol) - error("Loglikelihood decreased in Baum-Welch") + error("Loglikelihood decreased from $logL_prev to $logL in Baum-Welch") elseif progress < atol return true end diff --git a/test/correctness.jl b/test/correctness.jl index 8cbdb7b7..1e4e1d84 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -9,11 +9,13 @@ using SparseArrays using StableRNGs using Test +TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") + rng = StableRNG(63) ## Settings -T, K = 100, 500 +T, K = 50, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] @@ -91,8 +93,8 @@ end if TEST_SUITE != "HMMBase" test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + # test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + # test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end end From 26c099fe5b72aa6cbff4e9fb132b66665113bfa4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 30 Sep 2024 08:59:42 +0200 Subject: [PATCH 13/14] Subtle change --- test/correctness.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/correctness.jl b/test/correctness.jl index 1e4e1d84..1bf560d0 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -15,7 +15,7 @@ rng = StableRNG(63) ## Settings -T, K = 50, 200 +T, K = 100, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] @@ -93,8 +93,8 @@ end if TEST_SUITE != "HMMBase" test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - # test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - # test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end end From 3bb007d596ef09ff5f7c52114e5a36b6180530e4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 30 Sep 2024 09:18:47 +0200 Subject: [PATCH 14/14] Fixes --- src/HiddenMarkovModels.jl | 2 +- src/types/abstract_hmm.jl | 2 +- src/utils/lightcategorical.jl | 2 +- src/utils/lightdiagnormal.jl | 8 ++++---- test/correctness.jl | 10 ++++++++-- test/distributions.jl | 4 ++-- test/runtests.jl | 2 +- 7 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 2cfa2029..ecfd99b2 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof using DocStringExtensions using FillArrays: Fill -using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent +using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent using Random: Random, AbstractRNG, default_rng using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals using StatsAPI: StatsAPI, fit, fit! diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 6f3f8e53..c9f0c675 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -125,7 +125,7 @@ function obs_logdensities!( logb::AbstractVector{T}, hmm::AbstractHMM, obs, control ) where {T} dists = obs_distributions(hmm, control) - @inbounds @simd for i in eachindex(logb, dists) + @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end @argcheck maximum(logb) < typemax(T) diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index fd96dd29..605f3c28 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -53,7 +53,7 @@ function StatsAPI.fit!( @argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p) w_tot = sum(w) fill!(dist.p, zero(T1)) - @inbounds @simd for i in eachindex(x, w) + @simd for i in eachindex(x, w) dist.p[x[i]] += w[i] end dist.p ./= w_tot diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 6c84ac44..05851672 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -46,7 +46,7 @@ function DensityInterface.logdensityof( ) where {T1,T2,T3} l = zero(promote_type(T1, T2, T3, eltype(x))) l -= sum(dist.logσ) + log2π * length(x) / 2 - @inbounds @simd for i in eachindex(x, dist.μ, dist.σ) + @simd for i in eachindex(x, dist.μ, dist.σ) l -= abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i])) end return l @@ -58,11 +58,11 @@ function StatsAPI.fit!( w_tot = sum(w) fill!(dist.μ, zero(T1)) fill!(dist.σ, zero(T2)) - @inbounds @simd for i in eachindex(x, w) - dist.μ .+= x[i] .* w[i] + @simd for i in eachindex(x, w) + axpy!(w[i], x[i], dist.μ) end dist.μ ./= w_tot - @inbounds @simd for i in eachindex(x, w) + @simd for i in eachindex(x, w) dist.σ .+= abs2.(x[i] .- dist.μ) .* w[i] end dist.σ .= sqrt.(dist.σ ./ w_tot) diff --git a/test/correctness.jl b/test/correctness.jl index 1bf560d0..716e8552 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -11,8 +11,6 @@ using Test TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") -rng = StableRNG(63) - ## Settings T, K = 100, 200 @@ -31,6 +29,7 @@ p_guess = [[0.7, 0.3], [0.3, 0.7]] σ = ones(2) +rng = StableRNG(63) control_seqs = [fill(nothing, rand(rng, T:(2T))) for k in 1:K]; control_seq = reduce(vcat, control_seqs); seq_ends = cumsum(length.(control_seqs)); @@ -44,6 +43,7 @@ seq_ends = cumsum(length.(control_seqs)); hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE == "HMMBase" test_identical_hmmbase(rng, hmm, T; hmm_guess) else @@ -62,6 +62,7 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE == "HMMBase" test_identical_hmmbase(rng, hmm, T; hmm_guess) else @@ -77,6 +78,7 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE != "HMMBase" test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) @@ -91,6 +93,7 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE != "HMMBase" test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) @@ -105,6 +108,7 @@ end hmm = HMM(init, sparse(trans), dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE == "HMMBase" test_identical_hmmbase(rng, hmm, T; hmm_guess) else @@ -121,6 +125,7 @@ end hmm = transpose_hmm(HMM(init, trans, dists)) hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess)) + rng = StableRNG(63) if TEST_SUITE == "HMMBase" test_identical_hmmbase(rng, hmm, T; hmm_guess) else @@ -137,6 +142,7 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) + rng = StableRNG(63) if TEST_SUITE == "HMMBase" test_identical_hmmbase(rng, hmm, T; hmm_guess) else diff --git a/test/distributions.jl b/test/distributions.jl index 544ba063..dfbc8822 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -25,7 +25,7 @@ end end @test val_count ./ length(x) ≈ p atol = 2e-2 # Fitting - dist_est = deepcopy(dist) + dist_est = LightCategorical(rand_prob_vec(rng, 10)) w = ones(length(x)) fit!(dist_est, x, w) @test dist_est.p ≈ p atol = 2e-2 @@ -43,7 +43,7 @@ end @test mean(x) ≈ μ atol = 2e-2 @test std(x) ≈ σ atol = 2e-2 # Fitting - dist_est = deepcopy(dist) + dist_est = LightDiagNormal(randn(rng, 10), rand(rng, 10)) w = ones(length(x)) fit!(dist_est, x, w) @test dist_est.μ ≈ μ atol = 2e-2 diff --git a/test/runtests.jl b/test/runtests.jl index e41cf6ae..39a06cbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,7 +50,7 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) end end - @testset "Correctness - $TEST_SUITE" begin + @testset verbose = true "Correctness - $TEST_SUITE" begin include("correctness.jl") end end