From b56d1b253c7a04a5b164cb12fd2eac1b5366295d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:53:41 +0100 Subject: [PATCH] Fix leaky tests --- docs/src/formulas.md | 34 ++++++++++++++++-------------- examples/basics.jl | 4 ++-- examples/interfaces.jl | 2 +- examples/types.jl | 4 ++-- libs/HMMTest/src/allocations.jl | 4 ++-- libs/HMMTest/src/coherence.jl | 37 ++++++++++++++++++++------------- libs/HMMTest/src/hmmbase.jl | 10 ++------- src/types/hmm.jl | 7 +++++++ src/utils/linalg.jl | 1 + test/correctness.jl | 8 +++---- 10 files changed, 61 insertions(+), 50 deletions(-) diff --git a/docs/src/formulas.md b/docs/src/formulas.md index 7c20cb04..082fdfc6 100644 --- a/docs/src/formulas.md +++ b/docs/src/formulas.md @@ -4,9 +4,11 @@ Suppose we are given observations $Y_1, ..., Y_T$, with hidden states $X_1, ..., Following [Rabiner1989](@cite), we use the following notations: * let $\pi \in \mathbb{R}^N$ be the initial state distribution $\pi_i = \mathbb{P}(X_1 = i)$ -* let $A \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j} = \mathbb{P}(X_{t+1}=j | X_t = i)$ +* let $A_t \in \mathbb{R}^{N \times N}$ be the transition matrix $a_{i,j,t} = \mathbb{P}(X_{t+1}=j | X_t = i)$ * let $B \in \mathbb{R}^{N \times T}$ be the matrix of statewise observation likelihoods $b_{i,t} = \mathbb{P}(Y_t | X_t = i)$ +The conditioning on the known controls $U_{1:T}$ is implicit throughout. + ## Vanilla forward-backward ### Recursion @@ -33,8 +35,8 @@ and satisfy the dynamic programming equations ```math \begin{align*} -\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j}\right) b_{j,t+1} \\ -\beta_{i,t} & = \sum_{j=1}^N a_{i,j} b_{j,t+1} \beta_{j,t+1} +\alpha_{j,t+1} & = \left(\sum_{i=1}^N \alpha_{i,t} a_{i,j,t}\right) b_{j,t+1} \\ +\beta_{i,t} & = \sum_{j=1}^N a_{i,j,t} b_{j,t+1} \beta_{j,t+1} \end{align*} ``` @@ -53,7 +55,7 @@ We notice that ```math \begin{align*} \alpha_{i,t} \beta_{i,t} & = \mathbb{P}(Y_{1:T}, X_t=i) \\ -\alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j) +\alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1} & = \mathbb{P}(Y_{1:T}, X_t=i, X_{t+1}=j) \end{align*} ``` @@ -62,7 +64,7 @@ Thus we deduce the one-state and two-state marginals ```math \begin{align*} \gamma_{i,t} & = \mathbb{P}(X_t=i | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} \beta_{i,t} \\ -\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} +\xi_{i,j,t} & = \mathbb{P}(X_t=i, X_{t+1}=j | Y_{1:T}) = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j,t} b_{j,t+1} \beta_{j,t+1} \end{align*} ``` @@ -75,7 +77,7 @@ According to [Qin2000](@cite), derivatives of the likelihood can be obtained as \frac{\partial \mathcal{L}}{\partial \pi_i} &= \beta_{i,1} b_{i,1} \\ \frac{\partial \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \alpha_{i,t} b_{j,t+1} \beta_{j,t+1} \\ \frac{\partial \mathcal{L}}{\partial b_{j,1}} &= \pi_j \beta_{j,1} \\ -\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t} +\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t} \end{align*} ``` @@ -98,8 +100,8 @@ and satisfy the dynamic programming equations ```math \begin{align*} -\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\ -\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t} +\hat{\alpha}_{j,t+1} & = \left(\sum_{i=1}^N \bar{\alpha}_{i,t} a_{i,j,t}\right) \frac{b_{j,t+1}}{m_{t+1}} & c_{t+1} & = \frac{1}{\sum_j \hat{\alpha}_{j,t+1}} & \bar{\alpha}_{j,t+1} = c_{t+1} \hat{\alpha}_{j,t+1} \\ +\hat{\beta}_{i,t} & = \sum_{j=1}^N a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} & && \bar{\beta}_{j,t} = c_t \hat{\beta}_{j,t} \end{align*} ``` @@ -140,9 +142,9 @@ We can now express the marginals using scaled variables: ```math \begin{align*} \xi_{i,j,t} & = \frac{1}{\mathcal{L}} \alpha_{i,t} a_{i,j} b_{j,t+1} \beta_{j,t+1} \\ -&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\ -&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ -&= \bar{\alpha}_{i,t} a_{i,j} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} +&= \frac{1}{\mathcal{L}} \left(\bar{\alpha}_{i,t} \prod_{s=1}^t \frac{m_s}{c_s}\right) a_{i,j,t} b_{j,t+1} \left(\bar{\beta}_{j,t+1} \frac{1}{c_{t+1}} \prod_{s=t+2}^T \frac{m_s}{c_s}\right) \\ +&= \frac{1}{\mathcal{L}} \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ +&= \bar{\alpha}_{i,t} a_{i,j,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \end{align*} ``` @@ -179,10 +181,10 @@ And for the statewise observation likelihoods, ```math \begin{align*} -\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j}\right) \beta_{j,t} \\ -&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\ -&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ -&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \bar{\beta}_{j,t} \frac{1}{m_t} \\ +\frac{\partial \mathcal{L}}{\partial b_{j,t}} &= \left(\sum_{i=1}^N \alpha_{i,t-1} a_{i,j,t-1}\right) \beta_{j,t} \\ +&= \sum_{i=1}^N \left(\bar{\alpha}_{i,t-1} \prod_{s=1}^{t-1} \frac{m_s}{c_s}\right) a_{i,j,t-1} \left(\bar{\beta}_{j,t} \frac{1}{c_t} \prod_{s=t+1}^T \frac{m_s}{c_s} \right) \\ +&= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \left(\prod_{s=1}^T \frac{m_s}{c_s}\right) \\ +&= \mathcal{L} \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \bar{\beta}_{j,t} \frac{1}{m_t} \\ \end{align*} ``` @@ -199,7 +201,7 @@ To sum up, \frac{\partial \log \mathcal{L}}{\partial \pi_i} &= \frac{b_{i,1}}{m_1} \bar{\beta}_{i,1} \\ \frac{\partial \log \mathcal{L}}{\partial a_{i,j}} &= \sum_{t=1}^{T-1} \bar{\alpha}_{i,t} \frac{b_{j,t+1}}{m_{t+1}} \bar{\beta}_{j,t+1} \\ \frac{\partial \log \mathcal{L}}{\partial \log b_{j,1}} &= \pi_j \frac{b_{j,1}}{m_1} \bar{\beta}_{j,1} = \frac{\bar{\alpha}_{j,1} \bar{\beta}_{j,1}}{c_1} = \gamma_{j,1} \\ -\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t} +\frac{\partial \log \mathcal{L}}{\partial \log b_{j,t}} &= \sum_{i=1}^N \bar{\alpha}_{i,t-1} a_{i,j,t-1} \frac{b_{j,t}}{m_t} \bar{\beta}_{j,t} = \frac{\bar{\alpha}_{j,t} \bar{\beta}_{j,t}}{c_t} = \gamma_{j,t} \end{align*} ``` diff --git a/examples/basics.jl b/examples/basics.jl index 79168429..034583c7 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -30,7 +30,7 @@ Any scalar- or vector-valued distribution from [Distributions.jl](https://github init = [0.6, 0.4] trans = [0.7 0.3; 0.3 0.7] dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)] -hmm = HMM(init, trans, dists); +hmm = HMM(init, trans, dists) # ## Simulation @@ -143,7 +143,7 @@ Since it is a local optimization procedure, it requires a starting point that is init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [MvNormal([-0.6, -0.7], I), MvNormal([0.6, 0.7], I)] +dists_guess = [MvNormal([-0.4, -0.7], I), MvNormal([0.4, 0.7], I)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #= diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 703f8c36..18e0d7b5 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -105,7 +105,7 @@ If we implement `fit!`, Baum-Welch also works seamlessly. init_guess = [0.5, 0.5] trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [StuffDist(-0.5), StuffDist(+0.5)] +dists_guess = [StuffDist(-0.7), StuffDist(+0.7)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #- diff --git a/examples/types.jl b/examples/types.jl index dc97c6eb..dc4c074a 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -51,7 +51,7 @@ Note that uncertainty on the transition parameters would throw an error because =# dists_guess = [Normal(-1.0 ± 0.1), Normal(1.0 ± 0.2)] -hmm_uncertain = HMM(init, trans, dists_guess); +hmm_uncertain = HMM(init, trans, dists_guess) #= Every quantity we compute with this new HMM will have propagated uncertainties around it. @@ -98,7 +98,7 @@ trans = sparse([ init = [0.2, 0.6, 0.2] dists = [Normal(-2.0), Normal(0.0), Normal(+2.0)] -hmm = HMM(init, trans, dists); +hmm = HMM(init, trans, dists) #= When we simulate it, the transitions outside of the nonzero coefficients simply cannot happen. diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index d3bb21b9..ea3aeef9 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -46,8 +46,8 @@ function test_allocations( ) HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) allocs_bw = @ballocated fit!( - $hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends - ) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm)) + hmm_guess_copy, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends + ) evals = 1 samples = 1 setup = (hmm_guess_copy = deepcopy($hmm_guess)) @test_broken allocs_bw == 0 end end diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index f5ca9147..c8983eea 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -1,27 +1,31 @@ infnorm(x) = maximum(abs, x) -function are_equal_hmms( +function test_equal_hmms( hmm1::AbstractHMM, hmm2::AbstractHMM, control_seq::AbstractVector; atol::Real, init::Bool, - test::Bool, + flip::Bool=false, ) - equal_check = true - if init init1 = initialization(hmm1) init2 = initialization(hmm2) - test && @test isapprox(init1, init2; atol, norm=infnorm) - equal_check = equal_check && isapprox(init1, init2; atol, norm=infnorm) + if flip + @test !isapprox(init1, init2; atol, norm=infnorm) + else + @test isapprox(init1, init2; atol, norm=infnorm) + end end for control in control_seq trans1 = transition_matrix(hmm1, control) trans2 = transition_matrix(hmm2, control) - test && @test isapprox(trans1, trans2; atol, norm=infnorm) - equal_check = equal_check && isapprox(trans1, trans2; atol, norm=infnorm) + if flip + @test !isapprox(trans1, trans2; atol, norm=infnorm) + else + @test isapprox(trans1, trans2; atol, norm=infnorm) + end end for control in control_seq @@ -29,18 +33,23 @@ function are_equal_hmms( dists2 = obs_distributions(hmm2, control) for (dist1, dist2) in zip(dists1, dists2) for field in fieldnames(typeof(dist1)) - if startswith(string(field), "log") + if startswith(string(field), "log") || + contains("σ", string(field)) || + contains("Σ", string(field)) continue end x1 = getfield(dist1, field) x2 = getfield(dist2, field) - test && @test isapprox(x1, x2; atol, norm=infnorm) - equal_check = equal_check && isapprox(x1, x2; atol, norm=infnorm) + if flip + @test !isapprox(x1, x2; atol, norm=infnorm) + else + @test isapprox(x1, x2; atol, norm=infnorm) + end end end end - return equal_check + return nothing end function test_coherent_algorithms( @@ -81,8 +90,8 @@ function test_coherent_algorithms( if !isnothing(hmm_guess) hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) @test all(>=(0), diff(logL_evolution)) - @test !are_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, test=false) - are_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init, test=true) + test_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, init, flip=true) + test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) end end end diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index 1ecf76fd..fe1e48e3 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -50,14 +50,8 @@ function test_identical_hmmbase( @test isapprox( logL_evolution[(begin + 1):end], 2 * logL_evolution_base[begin:(end - 1)] ) - are_equal_hmms( - hmm_est, - HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B), - [nothing]; - atol, - init=true, - test=true, - ) + hmm_est_base_converted = HMM(hmm_est_base.a, hmm_est_base.A, hmm_est_base.B) + test_equal_hmms(hmm_est, hmm_est_base_converted, [nothing]; atol, init=true) end end end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 19c51340..4ebf8e65 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -26,6 +26,13 @@ function Base.copy(hmm::HMM) return HMM(copy(hmm.init), copy(hmm.trans), copy(hmm.dists)) end +function Base.show(io::IO, hmm::HMM) + return print( + io, + "Hidden Markov Model with:\n - initialization: $(hmm.init)\n - transition matrix: $(hmm.trans)\n - observation distributions: $(hmm.dists)", + ) +end + initialization(hmm::HMM) = hmm.init transition_matrix(hmm::HMM) = hmm.trans obs_distributions(hmm::HMM) = hmm.dists diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 550c160b..e3fe2b6f 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -18,6 +18,7 @@ function mul_rows_cols!( @argcheck size(B) == size(A) == (length(l), length(r)) @argcheck nnz(B) == nnz(A) for j in axes(B, 2) + @argcheck nzrange(B, j) == nzrange(A, j) for k in nzrange(B, j) i = B.rowval[k] B.nzval[k] = l[i] * A.nzval[k] * r[j] diff --git a/test/correctness.jl b/test/correctness.jl index 639c3b7b..c7dc84a8 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -18,8 +18,8 @@ T, K = 100, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] -trans = [0.8 0.2; 0.2 0.8] -trans_guess = [0.7 0.3; 0.3 0.7] +trans = [0.7 0.3; 0.3 0.7] +trans_guess = [0.6 0.4; 0.4 0.6] p = [[0.8, 0.2], [0.2, 0.8]] p_guess = [[0.7, 0.3], [0.3, 0.7]] @@ -58,9 +58,7 @@ end hmm_guess = HMM(init_guess, trans_guess, dists_guess) test_identical_hmmbase(rng, hmm, T; hmm_guess) - @test_skip test_coherent_algorithms( - rng, hmm, control_seq; seq_ends, hmm_guess, init=false - ) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) end