From 9b87ce8972e6128e599515e25a4e965026d608ab Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 09:10:39 +0200 Subject: [PATCH 1/4] Store transposed transition matrix to speed up forward --- src/inference/forward.jl | 4 ++-- src/types/abstract_hmm.jl | 12 ++++++++++++ src/types/hmm.jl | 25 +++++++++++++++++++++---- src/utils/linalg.jl | 3 +++ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index c7d4883c..36848f43 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -70,9 +70,9 @@ function forward!( logm = maximum(Bₜ₊₁) Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) - trans = transition_matrix(hmm, control_seq[t]) + transpose_trans = transpose_transition_matrix(hmm, control_seq[t]) αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1) - mul!(αₜ₊₁, transpose(trans), αₜ) + mul!(αₜ₊₁, transpose_trans, αₜ) αₜ₊₁ .*= Bₜ₊₁ c[t + 1] = inv(sum(αₜ₊₁)) lmul!(c[t + 1], αₜ₊₁) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 6f3f8e53..bfafb56f 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -75,6 +75,10 @@ Return the matrix of state transition probabilities for `hmm` (possibly when `co """ function transition_matrix end +function transpose_transition_matrix(hmm::AbstractHMM, control) + return transpose(transition_matrix(hmm, control)) +end + """ log_transition_matrix(hmm) log_transition_matrix(hmm, control) @@ -87,6 +91,10 @@ function log_transition_matrix(hmm::AbstractHMM, control) return elementwise_log(transition_matrix(hmm, control)) end +function transpose_log_transition_matrix(hmm::AbstractHMM, control) + return transpose(log_transition_matrix(hmm, control)) +end + """ obs_distributions(hmm) obs_distributions(hmm, control) @@ -104,7 +112,11 @@ function obs_distributions end ## Fallbacks for no control transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) +transpose_transition_matrix(hmm::AbstractHMM, ::Nothing) = transpose_transition_matrix(hmm) log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) +function transpose_log_transition_matrix(hmm::AbstractHMM, ::Nothing) + return transpose_log_transition_matrix(hmm) +end obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) """ diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 9de3cf2e..2b527af8 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -13,6 +13,8 @@ struct HMM{ VD<:AbstractVector, Vl<:AbstractVector, Ml<:AbstractMatrix, + Mt<:AbstractMatrix, + Mlt<:AbstractMatrix, } <: AbstractHMM "initial state probabilities" init::V @@ -24,14 +26,24 @@ struct HMM{ loginit::Vl "logarithms of state transition probabilities" logtrans::Ml + transpose_trans::Mt + transpose_logtrans::Mlt function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector) - log_init = elementwise_log(init) - log_trans = elementwise_log(trans) + loginit = elementwise_log(init) + logtrans = elementwise_log(trans) + transpose_trans = concrete_transpose(trans) + transpose_logtrans = concrete_transpose(logtrans) hmm = new{ - typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans) + typeof(init), + typeof(trans), + typeof(dists), + typeof(loginit), + typeof(logtrans), + typeof(transpose_trans), + typeof(transpose_logtrans), }( - init, trans, dists, log_init, log_trans + init, trans, dists, loginit, logtrans, transpose_trans, transpose_logtrans ) @argcheck valid_hmm(hmm) return hmm @@ -52,7 +64,9 @@ end initialization(hmm::HMM) = hmm.init log_initialization(hmm::HMM) = hmm.loginit transition_matrix(hmm::HMM) = hmm.trans +transpose_transition_matrix(hmm::HMM) = hmm.transpose_trans log_transition_matrix(hmm::HMM) = hmm.logtrans +transpose_log_transition_matrix(hmm::HMM) = hmm.transpose_logtrans obs_distributions(hmm::HMM) = hmm.dists ## Fitting @@ -90,6 +104,9 @@ function StatsAPI.fit!( # Update logs hmm.loginit .= log.(hmm.init) mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans)) + # Update transposes (could be optimized) + copyto!(hmm.transpose_trans, transpose(hmm.trans)) + copyto!(hmm.transpose_logtrans, transpose(hmm.logtrans)) # Safety check @argcheck valid_hmm(hmm) return nothing diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 506b29e4..9ce737dd 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -94,3 +94,6 @@ function argmaxplus_transmul!( end return y end + +concrete_transpose(A::AbstractMatrix) = convert(typeof(A), transpose(A)) +concrete_transpose(A::Transpose) = parent(A) From 0d5ecc2c7a81ec862e9850053c989b009ebcc975 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 09:28:54 +0200 Subject: [PATCH 2/4] Fix --- src/types/abstract_hmm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index bfafb56f..de41dfc2 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -112,10 +112,10 @@ function obs_distributions end ## Fallbacks for no control transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) -transpose_transition_matrix(hmm::AbstractHMM, ::Nothing) = transpose_transition_matrix(hmm) +transpose_transition_matrix(hmm::AbstractHMM, ::Nothing) = transpose(transition_matrix(hmm)) log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) function transpose_log_transition_matrix(hmm::AbstractHMM, ::Nothing) - return transpose_log_transition_matrix(hmm) + return transpose(log_transition_matrix(hmm)) end obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) From ced9ecf07da76badf592fe84cbb52f2ae84aae9f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 09:43:36 +0200 Subject: [PATCH 3/4] JET --- test/runtests.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 96b2f0fa..0a303b34 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,8 +21,9 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) @testset "Code linting" begin using Distributions - using Zygote - JET.test_package(HiddenMarkovModels; target_defined_modules=true) + if VERSION >= v"1.10" + JET.test_package(HiddenMarkovModels; target_defined_modules=true) + end end @testset "Distributions" begin From af9a64f1b3c04b53f93cfcb2704433bb2cadd7a0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 10:26:00 +0200 Subject: [PATCH 4/4] Zygote --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 0a303b34..0d2bac03 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,7 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) @testset "Code linting" begin using Distributions + using Zygote if VERSION >= v"1.10" JET.test_package(HiddenMarkovModels; target_defined_modules=true) end