From 82fd4651437911af4944b6df30e46ce6fd66c1ae Mon Sep 17 00:00:00 2001 From: THargreaves Date: Fri, 27 Sep 2024 15:34:36 +0100 Subject: [PATCH 1/3] Decompose forward function into initialize, predict, update --- src/inference/forward.jl | 70 +++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index c7d4883c..3be62097 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -47,43 +47,55 @@ function forward!( t1::Integer, t2::Integer; ) - (; α, B, c) = storage - # Initialization - Bₜ₁ = view(B, :, t1) - obs_logdensities!(Bₜ₁, hmm, obs_seq[t1], control_seq[t1]) - logm = maximum(Bₜ₁) - Bₜ₁ .= exp.(Bₜ₁ .- logm) + _initialize!(storage, hmm, t1) + logL = zero(eltype(storage)) - init = initialization(hmm) - αₜ₁ = view(α, :, t1) - αₜ₁ .= init .* Bₜ₁ - c[t1] = inv(sum(αₜ₁)) - lmul!(c[t1], αₜ₁) - - logL = -log(c[t1]) + logm - - # Loop - for t in t1:(t2 - 1) - Bₜ₊₁ = view(B, :, t + 1) - obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1]) - logm = maximum(Bₜ₊₁) - Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) - - trans = transition_matrix(hmm, control_seq[t]) - αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1) - mul!(αₜ₊₁, transpose(trans), αₜ) - αₜ₊₁ .*= Bₜ₊₁ - c[t + 1] = inv(sum(αₜ₊₁)) - lmul!(c[t + 1], αₜ₊₁) - - logL += -log(c[t + 1]) + logm + # Filter step loop + for t in t1:t2 + t > t1 && _predict!(storage, hmm, control_seq, t) + logL = _update!(storage, logL, hmm, obs_seq, control_seq, t) end @argcheck isfinite(logL) return logL end +function _initialize!(storage, hmm, t1) + (; α) = storage + αₜ₁ = view(α, :, t1) + αₜ₁ .= initialization(hmm) + return nothing +end + +function _predict!(storage, hmm, control_seq, t) + (; α) = storage + αₜ₋₁, αₜ = view(α, :, t - 1), view(α, :, t) + + trans = transition_matrix(hmm, control_seq[t]) + mul!(αₜ, transpose(trans), αₜ₋₁) + + return nothing +end + +function _update!(storage, logL, hmm, obs_seq, control_seq, t) + (; α, B, c) = storage + Bₜ = view(B, :, t) + αₜ = view(α, :, t) + + obs_logdensities!(Bₜ, hmm, obs_seq[t], control_seq[t]) + logm = maximum(Bₜ) + Bₜ .= exp.(Bₜ .- logm) + + αₜ .*= Bₜ + c[t] = inv(sum(αₜ)) + lmul!(c[t], αₜ) + + logL += -log(c[t]) + logm + + return logL +end + """ $(SIGNATURES) """ From 21012a6d125c44dab3602d68029244b3a09893b6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:24:00 +0200 Subject: [PATCH 2/3] Fixes --- src/HiddenMarkovModels.jl | 1 + src/inference/forward.jl | 97 ++++++++++++++------------------------- src/inference/predict.jl | 17 +++++++ src/types/abstract_hmm.jl | 6 +++ 4 files changed, 58 insertions(+), 63 deletions(-) create mode 100644 src/inference/predict.jl diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index ecfd99b2..aa319be6 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -38,6 +38,7 @@ include("utils/lightdiagnormal.jl") include("utils/lightcategorical.jl") include("utils/limits.jl") +include("inference/predict.jl") include("inference/forward.jl") include("inference/viterbi.jl") include("inference/forward_backward.jl") diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 197717ea..615c70a8 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -63,6 +63,27 @@ function initialize_forward( return ForwardStorage(α, logL, B, c) end +function _forward_digest_observation!( + current_state_marginals::AbstractVector{<:Real}, + current_obs_likelihoods::AbstractVector{<:Real}, + hmm::AbstractHMM, + obs, + control, +) + a, b = current_state_marginals, current_obs_likelihoods + + obs_logdensities!(b, hmm, obs, control) + logm = maximum(b) + b .= exp.(b .- logm) + + a .*= b + c = inv(sum(a)) + lmul!(c, a) + + logL = -log(c) + logm + return c, logL +end + function _forward!( storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, @@ -73,75 +94,25 @@ function _forward!( ) (; α, B, c, logL) = storage t1, t2 = seq_limits(seq_ends, k) - - # Initialization - _initialize!(storage, hmm, t1) - logL = zero(eltype(storage)) - - init = initialization(hmm) - αₜ₁ = view(α, :, t1) - αₜ₁ .= init .* Bₜ₁ - c[t1] = inv(sum(αₜ₁)) - lmul!(c[t1], αₜ₁) - - logL[k] = -log(c[t1]) + logm - - # Loop - for t in t1:(t2 - 1) - Bₜ₊₁ = view(B, :, t + 1) - obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1]) - logm = maximum(Bₜ₊₁) - Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) - - trans = transition_matrix(hmm, control_seq[t]) - αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1) - mul!(αₜ₊₁, transpose(trans), αₜ) - αₜ₊₁ .*= Bₜ₊₁ - c[t + 1] = inv(sum(αₜ₊₁)) - lmul!(c[t + 1], αₜ₊₁) - - logL[k] += -log(c[t + 1]) + logm + logL[k] = zero(eltype(logL)) + for t in t1:t2 + αₜ = view(α, :, t) + Bₜ = view(B, :, t) + if t == t1 + copyto!(αₜ, initialization(hmm)) + else + αₜ₋₁ = view(α, :, t - 1) + predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) + end + cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t]) + c[t] = cₜ + logL[k] += logLₜ end @argcheck isfinite(logL[k]) return nothing end -function _initialize!(storage, hmm, t1) - (; α) = storage - αₜ₁ = view(α, :, t1) - αₜ₁ .= initialization(hmm) - return nothing -end - -function _predict!(storage, hmm, control_seq, t) - (; α) = storage - αₜ₋₁, αₜ = view(α, :, t - 1), view(α, :, t) - - trans = transition_matrix(hmm, control_seq[t]) - mul!(αₜ, transpose(trans), αₜ₋₁) - - return nothing -end - -function _update!(storage, logL, hmm, obs_seq, control_seq, t) - (; α, B, c) = storage - Bₜ = view(B, :, t) - αₜ = view(α, :, t) - - obs_logdensities!(Bₜ, hmm, obs_seq[t], control_seq[t]) - logm = maximum(Bₜ) - Bₜ .= exp.(Bₜ .- logm) - - αₜ .*= Bₜ - c[t] = inv(sum(αₜ)) - lmul!(c[t], αₜ) - - logL += -log(c[t]) + logm - - return logL -end - """ $(SIGNATURES) """ diff --git a/src/inference/predict.jl b/src/inference/predict.jl new file mode 100644 index 00000000..9459fa8b --- /dev/null +++ b/src/inference/predict.jl @@ -0,0 +1,17 @@ +function predict_next_state!( + next_state_marginals::AbstractVector{<:Real}, + hmm::AbstractHMM, + current_state_marginals::AbstractVector{<:Real}, + control=nothing, +) + trans = transition_matrix(hmm, control) + mul!(next_state_marginals, transpose(trans), current_state_marginals) + return next_state_marginals +end + +function predict_next_state( + hmm::AbstractHMM, current_state_marginals::AbstractVector{<:Real}, control=nothing +) + next_state_marginals = similar(current_state_marginals) + return predict_next_state!(next_state_marginals, hmm, current_state_marginals, control) +end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index c9f0c675..4fe1387c 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -72,6 +72,9 @@ log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) transition_matrix(hmm, control) Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied). + +!!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ function transition_matrix end @@ -82,6 +85,9 @@ function transition_matrix end Return the matrix of state transition log-probabilities for `hmm` (possibly when `control` is applied). Falls back on `transition_matrix`. + +!!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ function log_transition_matrix(hmm::AbstractHMM, control) return elementwise_log(transition_matrix(hmm, control)) From 6b2a3d4952d2c36e79a011f759fa31fa975d1b5c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:32:24 +0200 Subject: [PATCH 3/3] Remove unused predict --- src/inference/predict.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/inference/predict.jl b/src/inference/predict.jl index 9459fa8b..7240af6a 100644 --- a/src/inference/predict.jl +++ b/src/inference/predict.jl @@ -8,10 +8,3 @@ function predict_next_state!( mul!(next_state_marginals, transpose(trans), current_state_marginals) return next_state_marginals end - -function predict_next_state( - hmm::AbstractHMM, current_state_marginals::AbstractVector{<:Real}, control=nothing -) - next_state_marginals = similar(current_state_marginals) - return predict_next_state!(next_state_marginals, hmm, current_state_marginals, control) -end