Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store transposed transition matrix to speed up forward #107

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ function forward!(
logm = maximum(Bₜ₊₁)
Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm)

trans = transition_matrix(hmm, control_seq[t])
transpose_trans = transpose_transition_matrix(hmm, control_seq[t])
αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1)
mul!(αₜ₊₁, transpose(trans), αₜ)
mul!(αₜ₊₁, transpose_trans, αₜ)
αₜ₊₁ .*= Bₜ₊₁
c[t + 1] = inv(sum(αₜ₊₁))
lmul!(c[t + 1], αₜ₊₁)
Expand Down
12 changes: 12 additions & 0 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
"""
function transition_matrix end

function transpose_transition_matrix(hmm::AbstractHMM, control)
return transpose(transition_matrix(hmm, control))
end

"""
log_transition_matrix(hmm)
log_transition_matrix(hmm, control)
Expand All @@ -87,6 +91,10 @@
return elementwise_log(transition_matrix(hmm, control))
end

function transpose_log_transition_matrix(hmm::AbstractHMM, control)
return transpose(log_transition_matrix(hmm, control))

Check warning on line 95 in src/types/abstract_hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/abstract_hmm.jl#L94-L95

Added lines #L94 - L95 were not covered by tests
end

"""
obs_distributions(hmm)
obs_distributions(hmm, control)
Expand All @@ -104,7 +112,11 @@
## Fallbacks for no control

transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm)
transpose_transition_matrix(hmm::AbstractHMM, ::Nothing) = transpose(transition_matrix(hmm))
log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm)
function transpose_log_transition_matrix(hmm::AbstractHMM, ::Nothing)
return transpose(log_transition_matrix(hmm))

Check warning on line 118 in src/types/abstract_hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/abstract_hmm.jl#L117-L118

Added lines #L117 - L118 were not covered by tests
end
obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm)

"""
Expand Down
25 changes: 21 additions & 4 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
VD<:AbstractVector,
Vl<:AbstractVector,
Ml<:AbstractMatrix,
Mt<:AbstractMatrix,
Mlt<:AbstractMatrix,
} <: AbstractHMM
"initial state probabilities"
init::V
Expand All @@ -24,14 +26,24 @@
loginit::Vl
"logarithms of state transition probabilities"
logtrans::Ml
transpose_trans::Mt
transpose_logtrans::Mlt

function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector)
log_init = elementwise_log(init)
log_trans = elementwise_log(trans)
loginit = elementwise_log(init)
logtrans = elementwise_log(trans)
transpose_trans = concrete_transpose(trans)
transpose_logtrans = concrete_transpose(logtrans)
hmm = new{
typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans)
typeof(init),
typeof(trans),
typeof(dists),
typeof(loginit),
typeof(logtrans),
typeof(transpose_trans),
typeof(transpose_logtrans),
}(
init, trans, dists, log_init, log_trans
init, trans, dists, loginit, logtrans, transpose_trans, transpose_logtrans
)
@argcheck valid_hmm(hmm)
return hmm
Expand All @@ -52,7 +64,9 @@
initialization(hmm::HMM) = hmm.init
log_initialization(hmm::HMM) = hmm.loginit
transition_matrix(hmm::HMM) = hmm.trans
transpose_transition_matrix(hmm::HMM) = hmm.transpose_trans

Check warning on line 67 in src/types/hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/hmm.jl#L67

Added line #L67 was not covered by tests
log_transition_matrix(hmm::HMM) = hmm.logtrans
transpose_log_transition_matrix(hmm::HMM) = hmm.transpose_logtrans

Check warning on line 69 in src/types/hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/hmm.jl#L69

Added line #L69 was not covered by tests
obs_distributions(hmm::HMM) = hmm.dists

## Fitting
Expand Down Expand Up @@ -90,6 +104,9 @@
# Update logs
hmm.loginit .= log.(hmm.init)
mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans))
# Update transposes (could be optimized)
copyto!(hmm.transpose_trans, transpose(hmm.trans))
copyto!(hmm.transpose_logtrans, transpose(hmm.logtrans))
# Safety check
@argcheck valid_hmm(hmm)
return nothing
Expand Down
3 changes: 3 additions & 0 deletions src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,6 @@ function argmaxplus_transmul!(
end
return y
end

concrete_transpose(A::AbstractMatrix) = convert(typeof(A), transpose(A))
concrete_transpose(A::Transpose) = parent(A)
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest"))
@testset "Code linting" begin
using Distributions
using Zygote
JET.test_package(HiddenMarkovModels; target_defined_modules=true)
if VERSION >= v"1.10"
JET.test_package(HiddenMarkovModels; target_defined_modules=true)
end
end

@testset "Distributions" begin
Expand Down
Loading