From 06215d13bdacd2bef8783a529f47b4b597a8ef1f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:31:35 +0100 Subject: [PATCH 1/8] Start autoregressive prototype --- src/inference/forward.jl | 8 ++++++-- src/inference/logdensity.jl | 2 +- src/inference/viterbi.jl | 6 ++++-- src/types/abstract_hmm.jl | 16 ++++++++++------ src/types/hmm.jl | 2 +- test/autoregressive.jl | 20 ++++++++++++++++++++ 6 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 test/autoregressive.jl diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 5e6e2518..59ea410d 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -69,10 +69,11 @@ function _forward_digest_observation!( hmm::AbstractHMM, obs, control, + prev_obs, ) a, b = current_state_marginals, current_obs_likelihoods - obs_logdensities!(b, hmm, obs, control) + obs_logdensities!(b, hmm, obs, control, prev_obs) logm = maximum(b) b .= exp.(b .- logm) @@ -104,7 +105,10 @@ function _forward!( αₜ₋₁ = view(α, :, t - 1) predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) end - cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t]) + prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t) + cₜ, logLₜ = _forward_digest_observation!( + αₜ, Bₜ, hmm, obs_seq[t], control_seq[t], prev_obs + ) c[t] = cₜ logL[k] += logLₜ end diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index ce153ff2..7a174be9 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -39,7 +39,7 @@ function joint_logdensityof( end # Observations for t in t1:t2 - dists = obs_distributions(hmm, control_seq[t]) + dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) logL += logdensityof(dists[state_seq[t]], obs_seq[t]) end end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index f4f52701..58c9a887 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -48,13 +48,15 @@ function _viterbi!( t1, t2 = seq_limits(seq_ends, k) logBₜ₁ = view(logB, :, t1) - obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1]) + obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing) loginit = log_initialization(hmm) ϕ[:, t1] .= loginit .+ logBₜ₁ for t in (t1 + 1):t2 logBₜ = view(logB, :, t) - obs_logdensities!(logBₜ, hmm, obs_seq[t], control_seq[t]) + obs_logdensities!( + logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t) + ) logtrans = log_transition_matrix(hmm, control_seq[t - 1]) ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 4fe1387c..dbacbb5e 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -23,7 +23,7 @@ Any `AbstractHMM` which satisfies the interface can be given to the following fu - [`forward_backward`](@ref) - [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented) """ -abstract type AbstractHMM end +abstract type AbstractHMM{ar} end @inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity() @@ -46,7 +46,7 @@ It is typically a promotion between the element type of the initialization, the function Base.eltype(hmm::AbstractHMM, obs, control) init_type = eltype(initialization(hmm)) trans_type = eltype(transition_matrix(hmm, control)) - dist = obs_distributions(hmm, control)[1] + dist = obs_distributions(hmm, control, obs)[1] logdensity_type = typeof(logdensityof(dist, obs)) return promote_type(init_type, trans_type, logdensity_type) end @@ -112,6 +112,10 @@ function obs_distributions end transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) +obs_distributions(hmm::AbstractHMM, control, ::Missing) = obs_distributions(hmm, control) + +previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = missing +previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] """ StatsAPI.fit!( @@ -128,9 +132,9 @@ StatsAPI.fit! ## Fill logdensities function obs_logdensities!( - logb::AbstractVector{T}, hmm::AbstractHMM, obs, control + logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs ) where {T} - dists = obs_distributions(hmm, control) + dists = obs_distributions(hmm, control, prev_obs) @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end @@ -164,13 +168,13 @@ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVe ) end - dists1 = obs_distributions(hmm, control_seq[1]) + dists1 = obs_distributions(hmm, control_seq[1], missing) obs1 = rand(rng, dists1[state1]) obs_seq = Vector{typeof(obs1)}(undef, T) obs_seq[1] = obs1 for t in 2:T - dists = obs_distributions(hmm, control_seq[t]) + dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) obs_seq[t] = rand(rng, dists[state_seq[t]]) end return (; state_seq=state_seq, obs_seq=obs_seq) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 9c1f47b4..9ecbe1ad 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -13,7 +13,7 @@ struct HMM{ VD<:AbstractVector, Vl<:AbstractVector, Ml<:AbstractMatrix, -} <: AbstractHMM +} <: AbstractHMM{false} "initial state probabilities" init::V "state transition probabilities" diff --git a/test/autoregressive.jl b/test/autoregressive.jl new file mode 100644 index 00000000..9b697a6c --- /dev/null +++ b/test/autoregressive.jl @@ -0,0 +1,20 @@ +using Distributions +using HiddenMarkovModels +const HMMs = HiddenMarkovModels + +struct AutoRegressiveGaussianHMM{T} <: AbstractHMM{true} + init::Vector{T} + trans::Matrix{T} + a::Vector{T} + b::Vector{T} +end + +const ARGHMM = AutoRegressiveGaussianHMM + +HMMs.initialization(hmm::ARGHMM) = hmm.init +HMMs.transition_matrix(hmm::ARGHMM) = hmm.trans + +function HMMs.obs_distributions(hmm::ARGHMM, _control, prev_obs) + (; a, b) = hmm + return [Normal(a[i] * prev_obs + b[i], 1.0) for i in 1:length(hmm)] +end From dd09a01a8fa8aca86b4bf0abc008d1cf8e3979e1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:33:09 +0100 Subject: [PATCH 2/8] Nothing or missing --- src/types/abstract_hmm.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index dbacbb5e..13a64b8a 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -112,9 +112,11 @@ function obs_distributions end transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) -obs_distributions(hmm::AbstractHMM, control, ::Missing) = obs_distributions(hmm, control) +function obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing}) + return obs_distributions(hmm, control) +end -previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = missing +previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] """ From 0b82437351fd8d6b8e648034ea26f6814edda542 Mon Sep 17 00:00:00 2001 From: Fausto Marques Pinheiro Junior Date: Fri, 22 Nov 2024 21:26:14 -0300 Subject: [PATCH 3/8] Second pass on ARHMM --- src/inference/chainrules.jl | 6 +++--- src/inference/forward.jl | 6 +++--- src/inference/forward_backward.jl | 8 ++++---- src/inference/logdensity.jl | 15 +++++++++------ src/inference/viterbi.jl | 6 +++--- src/types/abstract_hmm.jl | 30 +++++++++++++++--------------- test/discretecontrolarhmm.jl | 22 ++++++++++++++++++++++ 7 files changed, 59 insertions(+), 34 deletions(-) create mode 100644 test/discretecontrolarhmm.jl diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index 8816120f..b5a66f57 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -6,9 +6,9 @@ function _params_and_loglikelihoods( control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) - init = initialization(hmm) + init = initialization(hmm, control) trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t - transition_matrix(hmm, control_seq[t]) + t == 1 ? diagm(ones(size(hmm, t))) : transition_matrix(hmm, control_seq[t]) # I did't understand what this is doing, but my best guess is that it returns the transition matrix for each moment `t` to `t+1`. If this is the case, then, like forward.jl, line 106, the control variable matches `t+1`. To avoid messing up the logic, I just made the first matrix to be the identity matrix, and the following matrices are P(X_{t+1}|X_{t},U_{t+1}). end logB = mapreduce(hcat, eachindex(obs_seq, control_seq)) do t logdensityof.(obs_distributions(hmm, control_seq[t]), (obs_seq[t],)) @@ -30,7 +30,7 @@ function ChainRulesCore.rrule( fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) (; logL, α, γ, Bβ) = fb_storage - N, T = length(hmm), length(obs_seq) + N, T = size(hmm, control_seq[1]), length(obs_seq) R = eltype(α) Δinit = zeros(R, N) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 59ea410d..ff024e2b 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -54,7 +54,7 @@ function initialize_forward( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) α = Matrix{R}(undef, N, T) logL = Vector{R}(undef, K) @@ -100,10 +100,10 @@ function _forward!( αₜ = view(α, :, t) Bₜ = view(B, :, t) if t == t1 - copyto!(αₜ, initialization(hmm)) + copyto!(αₜ, initialization(hmm, control_seq[t])) else αₜ₋₁ = view(α, :, t - 1) - predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) + predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t]) # If `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`), then the associated control must be at `t+1`, right? If `control_seq[t-1]`, then we're using the control associated with the previous state and not the correct control, aren't we? The transition matrix would be P(X_{t}|X_{t-1},U_{t-1}) and not P(X_{t}|X_{t-1},U_{t}) as it should be. E.g., if `t == t1 + 1`, then `αₜ₋₁ = view(α, :, t1)` and the function would use the transition matrix P(X_{t1+1}|X_{t1},U_{t1}) instead of P(X_{t1+1}|X_{t1},U_{t1+1}). Same at logdensity.jl, line 37; forward_backward.jl, line 53. end prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t) cₜ, logLₜ = _forward_digest_observation!( diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index ba51ad46..a6cefa4e 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -8,9 +8,9 @@ function initialize_forward_backward( seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals=true, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) - trans = transition_matrix(hmm, control_seq[1]) + trans = transition_matrix(hmm, control_seq[2]) M = typeof(similar(trans, R)) γ = Matrix{R}(undef, N, T) @@ -50,7 +50,7 @@ function _forward_backward!( Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1) βₜ = view(β, :, t) Bβₜ₊₁ = view(Bβ, :, t + 1) - predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t]) + predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t+1]) # See forward.jl, line 106. lmul!(c[t], βₜ) end Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1) @@ -61,7 +61,7 @@ function _forward_backward!( # Transition marginals if transition_marginals for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq[t+1]) # See forward.jl, line 106. mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1)) end ξ[t2] .= zero(R) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index 7a174be9..5a33877b 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -29,16 +29,19 @@ function joint_logdensityof( logL = zero(R) for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) - # Initialization - init = initialization(hmm) + # Initialization: P(X_{1}|U_{1}) + init = initialization(hmm, control_seq[t1]) logL += log(init[state_seq[t1]]) - # Transitions + # Transitions: P(X_{t+1}|X_{t},U_{t+1}) for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq[t+1]) # See forward.jl, line 106. logL += log(trans[state_seq[t], state_seq[t + 1]]) end - # Observations - for t in t1:t2 + # Priori: P(Y_{1}|X_{1},U_{1}) + dists = obs_distributions(hmm, control_seq[t1], missing) + logL += logdensityof(dists[state_seq[t1]], obs_seq[t1]) + # Observations: P(Y_{t}|Y_{t-1},X_{t},U_{t}) + for t in (t1+1):t2 dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) logL += logdensityof(dists[state_seq[t]], obs_seq[t]) end diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 58c9a887..1a21f609 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -26,7 +26,7 @@ function initialize_viterbi( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) @@ -49,7 +49,7 @@ function _viterbi!( logBₜ₁ = view(logB, :, t1) obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing) - loginit = log_initialization(hmm) + loginit = log_initialization(hmm, control_seq[t1]) ϕ[:, t1] .= loginit .+ logBₜ₁ for t in (t1 + 1):t2 @@ -57,7 +57,7 @@ function _viterbi!( obs_logdensities!( logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t) ) - logtrans = log_transition_matrix(hmm, control_seq[t - 1]) + logtrans = log_transition_matrix(hmm, control_seq[t]) # See forward.jl, line 106. ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 13a64b8a..2eeab438 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -30,11 +30,11 @@ abstract type AbstractHMM{ar} end ## Interface """ - length(hmm) + size(hmm, control) Return the number of states of `hmm`. """ -Base.length(hmm::AbstractHMM) = length(initialization(hmm)) +Base.size(hmm::AbstractHMM, control) = size(transition_matrix(hmm, control), 2) """ eltype(hmm, obs, control) @@ -53,19 +53,23 @@ end """ initialization(hmm) + initialization(hmm, control) -Return the vector of initial state probabilities for `hmm`. +Return the vector of initial state probabilities for `hmm` (possibly when `control` is applied). """ function initialization end +initialization(hmm::AbstractHMM, ::Nothing) = initialization(hmm) + """ log_initialization(hmm) + log_initialization(hmm, control) -Return the vector of initial state log-probabilities for `hmm`. +Return the vector of initial state log-probabilities for `hmm` (possibly when `control` is applied). Falls back on `initialization`. """ -log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) +log_initialization(hmm::AbstractHMM, control) = elementwise_log(initialization(hmm, control)) """ transition_matrix(hmm) @@ -89,9 +93,7 @@ Falls back on `transition_matrix`. !!! note When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ -function log_transition_matrix(hmm::AbstractHMM, control) - return elementwise_log(transition_matrix(hmm, control)) -end +log_transition_matrix(hmm::AbstractHMM, control) = elementwise_log(transition_matrix(hmm, control)) """ obs_distributions(hmm) @@ -110,11 +112,9 @@ function obs_distributions end ## Fallbacks for no control transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) -log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) +log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) # Is this function needed? If `log_transition_matrix(hmm, nothing)`, then `transition_matrix(hmm, nothing)` returns `transition_matrix(hmm)`. obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) -function obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing}) - return obs_distributions(hmm, control) -end +obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing}) = obs_distributions(hmm, control) previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] @@ -156,9 +156,9 @@ Return a named tuple `(; state_seq, obs_seq)`. """ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector) T = length(control_seq) - dummy_log_probas = fill(-Inf, length(hmm)) + dummy_log_probas = fill(-Inf, size(hmm, control_seq[1])) - init = initialization(hmm) + init = initialization(hmm, control) state_seq = Vector{Int}(undef, T) state1 = rand(rng, LightCategorical(init, dummy_log_probas)) state_seq[1] = state1 @@ -172,7 +172,7 @@ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVe dists1 = obs_distributions(hmm, control_seq[1], missing) obs1 = rand(rng, dists1[state1]) - obs_seq = Vector{typeof(obs1)}(undef, T) + obs_seq = Vector{typeof(obs1)}(undef, T) # If the `typeof(obs1)` is only known at runtime, does it makes any difference over Vector{Any}? obs_seq[1] = obs1 for t in 2:T diff --git a/test/discretecontrolarhmm.jl b/test/discretecontrolarhmm.jl new file mode 100644 index 00000000..bf4193f1 --- /dev/null +++ b/test/discretecontrolarhmm.jl @@ -0,0 +1,22 @@ +using Distributions +using HiddenMarkovModels +const HMMs = HiddenMarkovModels + +struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true} + # Initial distribution P(X_{1}|U_{1}), one vector for each control + init::Vector{Vector{T}} + # Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control + trans::Vector{Matrix{T}} + # Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation + dists::Vector{Vector{Matrix{T}}} + # Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control + prior::Vector{Matrix{T}} +end + +HMMs.initialization(hmm::DiscreteCARHMM, control) = hmm.init[control] + +HMMs.transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control] + +HMMs.obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) = [Categorical(hmm.dists[control][prev_obs][i,:]) for i in 1:length(hmm, control)] + +HMMs.obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) = [Categorical(hmm.prior[control][i,:]) for i in 1:length(hmm, control)] \ No newline at end of file From 631d838bf17adae48480fc71cdc28441ae5c8b3b Mon Sep 17 00:00:00 2001 From: Fausto Marques Pinheiro Junior Date: Mon, 25 Nov 2024 17:49:41 -0300 Subject: [PATCH 4/8] Finishing touches in the second pass for ARHMM --- examples/autodiff.jl | 8 +++++--- examples/controlled.jl | 15 +++++++++------ examples/interfaces.jl | 10 +++++----- examples/temporal.jl | 19 +++++++++++-------- src/HiddenMarkovModels.jl | 2 +- src/inference/chainrules.jl | 2 +- src/inference/forward_backward.jl | 4 ++-- src/inference/logdensity.jl | 6 +++--- src/types/abstract_hmm.jl | 22 +++++++++++++--------- src/types/hmm.jl | 2 +- test/discretecontrolarhmm.jl | 20 ++++++++++++-------- 11 files changed, 63 insertions(+), 47 deletions(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 75905b17..1f9f4e05 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -29,7 +29,8 @@ rng = StableRNG(63); To play around with automatic differentiation, we define a simple controlled HMM. =# -struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: AbstractHMM +struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: + AbstractHMM{false} init::V1 trans::M2 means::V3 @@ -40,15 +41,16 @@ Both its transition matrix and its vector of observation means result from a con The coefficient $\lambda$ of this convex combination is given as a control. =# +HMMs.initialization(hmm::DiffusionHMM, λ::Number) = hmm.init HMMs.initialization(hmm::DiffusionHMM) = hmm.init function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number) - N = length(hmm) + N = size(hmm.trans, 2) return (1 - λ) * hmm.trans + λ * ones(N, N) / N end function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number) - return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:length(hmm)] + return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:size(hmm, λ)] end #= diff --git a/examples/controlled.jl b/examples/controlled.jl index ffd0ab98..4d910efc 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -22,10 +22,10 @@ rng = StableRNG(63); #= A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM. -We can represent it with the following subtype of `AbstractHMM` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state. +We can represent it with the following subtype of `AbstractHMM{false}` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state. =# -struct ControlledGaussianHMM{T} <: AbstractHMM +struct ControlledGaussianHMM{T} <: AbstractHMM{false} init::Vector{T} trans::Matrix{T} dist_coeffs::Vector{Vector{T}} @@ -36,16 +36,19 @@ In state $i$ with a vector of controls $u$, our observation is given by the line Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one. =# +function HMMs.initialization(hmm::ControlledGaussianHMM, control) + return hmm.init +end function HMMs.initialization(hmm::ControlledGaussianHMM) return hmm.init end -function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector) +function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control) return hmm.trans end -function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector) - return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)] +function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control) + return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:size(hmm, control)] end #= @@ -97,7 +100,7 @@ function StatsAPI.fit!( seq_ends, ) where {T} (; γ, ξ) = fb_storage - N = length(hmm) + N = size(hmm, control_seq[1]) hmm.init .= 0 hmm.trans .= 0 diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 4498478f..cfcfa3eb 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -138,14 +138,14 @@ test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src #= In some scenarios, the vanilla Baum-Welch algorithm is not exactly what we want. For instance, we might have a prior on the parameters of our model, which we want to apply during the fitting step of the iterative procedure. -Then we need to create a new type that satisfies the `AbstractHMM` interface. +Then we need to create a new type that satisfies the `AbstractHMM{ar}` interface. Let's make a simpler version of the built-in `HMM`, with a prior saying that each transition has already been observed a certain number of times. Such a prior can be very useful to regularize estimation and avoid numerical instabilities. It amounts to drawing every row of the transition matrix from a Dirichlet distribution, where each Dirichlet parameter is one plus the number of times the corresponding transition has been observed. =# -struct PriorHMM{T,D} <: AbstractHMM +struct PriorHMM{T,D} <: AbstractHMM{false} init::Vector{T} trans::Matrix{T} dists::Vector{D} @@ -153,7 +153,7 @@ struct PriorHMM{T,D} <: AbstractHMM end #= -The basic requirements for `AbstractHMM` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref). +The basic requirements for `AbstractHMM{false}` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref). =# HiddenMarkovModels.initialization(hmm::PriorHMM) = hmm.init @@ -166,7 +166,7 @@ If we forget to implement this, the loglikelihood computed in Baum-Welch will be =# function DensityInterface.logdensityof(hmm::PriorHMM) - prior = Dirichlet(fill(hmm.trans_prior_count + 1, length(hmm))) + prior = Dirichlet(fill(hmm.trans_prior_count + 1, size(hmm, nothing))) return sum(logdensityof(prior, row) for row in eachrow(transition_matrix(hmm))) end @@ -204,7 +204,7 @@ function StatsAPI.fit!( hmm.init ./= sum(hmm.init) hmm.trans ./= sum(hmm.trans; dims=2) - for i in 1:length(hmm) + for i in 1:size(hmm, nothing) ## weigh each sample by the marginal probability of being in state i weight_seq = fb_storage.γ[i, :] ## fit observation distribution i using those weights diff --git a/examples/temporal.jl b/examples/temporal.jl index 1cad38f6..ef0e545f 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -23,10 +23,10 @@ rng = StableRNG(63); #= We focus on the particular case of a periodic HMM with period `L`. It has only one initialization vector, but `L` transition matrices and `L` vectors of observation distributions. -As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM`. +As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM{ar}`. =# -struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM +struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM{false} init::Vector{T} trans_per::NTuple{L,Matrix{T}} dists_per::NTuple{L,Vector{D}} @@ -41,14 +41,17 @@ period(::PeriodicHMM{T,D,L}) where {T,D,L} = L function HMMs.initialization(hmm::PeriodicHMM) return hmm.init end +function HMMs.initialization(hmm::PeriodicHMM, control::Integer) + return hmm.init +end -function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer) - l = (t - 1) % period(hmm) + 1 +function HMMs.transition_matrix(hmm::PeriodicHMM, control::Integer) + l = (control - 1) % period(hmm) + 1 return hmm.trans_per[l] end -function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer) - l = (t - 1) % period(hmm) + 1 +function HMMs.obs_distributions(hmm::PeriodicHMM, control::Integer) + l = (control - 1) % period(hmm) + 1 return hmm.dists_per[l] end @@ -100,7 +103,7 @@ vcat(obs_seq', best_state_seq') # ## Learning #= -When estimating parameters for a custom subtype of `AbstractHMM`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument. +When estimating parameters for a custom subtype of `AbstractHMM{false}`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument. The key is to split the observations according to which periodic parameter they belong to. =# @@ -112,7 +115,7 @@ function StatsAPI.fit!( seq_ends, ) where {T} (; γ, ξ) = fb_storage - L, N = period(hmm), length(hmm) + L, N = period(hmm), size(hmm, control_seq[1]) hmm.init .= zero(T) for l in 1:L diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index aa319be6..46b24da8 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof using DocStringExtensions using FillArrays: Fill -using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent +using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent, diagm using Random: Random, AbstractRNG, default_rng using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals using StatsAPI: StatsAPI, fit, fit! diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index b5a66f57..c8a1bb10 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -6,7 +6,7 @@ function _params_and_loglikelihoods( control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) - init = initialization(hmm, control) + init = initialization(hmm, control_seq[1]) trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t t == 1 ? diagm(ones(size(hmm, t))) : transition_matrix(hmm, control_seq[t]) # I did't understand what this is doing, but my best guess is that it returns the transition matrix for each moment `t` to `t+1`. If this is the case, then, like forward.jl, line 106, the control variable matches `t+1`. To avoid messing up the logic, I just made the first matrix to be the identity matrix, and the following matrices are P(X_{t+1}|X_{t},U_{t+1}). end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index a6cefa4e..afe224f4 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -50,7 +50,7 @@ function _forward_backward!( Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1) βₜ = view(β, :, t) Bβₜ₊₁ = view(Bβ, :, t + 1) - predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t+1]) # See forward.jl, line 106. + predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t + 1]) # See forward.jl, line 106. lmul!(c[t], βₜ) end Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1) @@ -61,7 +61,7 @@ function _forward_backward!( # Transition marginals if transition_marginals for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t+1]) # See forward.jl, line 106. + trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1)) end ξ[t2] .= zero(R) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index 5a33877b..b2a2f43a 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -34,14 +34,14 @@ function joint_logdensityof( logL += log(init[state_seq[t1]]) # Transitions: P(X_{t+1}|X_{t},U_{t+1}) for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t+1]) # See forward.jl, line 106. + trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. logL += log(trans[state_seq[t], state_seq[t + 1]]) end # Priori: P(Y_{1}|X_{1},U_{1}) - dists = obs_distributions(hmm, control_seq[t1], missing) + dists = obs_distributions(hmm, control_seq[t1], missing) logL += logdensityof(dists[state_seq[t1]], obs_seq[t1]) # Observations: P(Y_{t}|Y_{t-1},X_{t},U_{t}) - for t in (t1+1):t2 + for t in (t1 + 1):t2 dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) logL += logdensityof(dists[state_seq[t]], obs_seq[t]) end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 2eeab438..8bdf5644 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -46,8 +46,8 @@ It is typically a promotion between the element type of the initialization, the function Base.eltype(hmm::AbstractHMM, obs, control) init_type = eltype(initialization(hmm)) trans_type = eltype(transition_matrix(hmm, control)) - dist = obs_distributions(hmm, control, obs)[1] - logdensity_type = typeof(logdensityof(dist, obs)) + dists = obs_distributions(hmm, control, obs) + logdensity_type = typeof(logdensityof(dists[1], obs)) return promote_type(init_type, trans_type, logdensity_type) end @@ -59,8 +59,6 @@ Return the vector of initial state probabilities for `hmm` (possibly when `contr """ function initialization end -initialization(hmm::AbstractHMM, ::Nothing) = initialization(hmm) - """ log_initialization(hmm) log_initialization(hmm, control) @@ -69,7 +67,8 @@ Return the vector of initial state log-probabilities for `hmm` (possibly when `c Falls back on `initialization`. """ -log_initialization(hmm::AbstractHMM, control) = elementwise_log(initialization(hmm, control)) +log_initialization(hmm::AbstractHMM, control) = + elementwise_log(initialization(hmm, control)) """ transition_matrix(hmm) @@ -93,11 +92,13 @@ Falls back on `transition_matrix`. !!! note When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ -log_transition_matrix(hmm::AbstractHMM, control) = elementwise_log(transition_matrix(hmm, control)) +log_transition_matrix(hmm::AbstractHMM, control) = + elementwise_log(transition_matrix(hmm, control)) """ obs_distributions(hmm) obs_distributions(hmm, control) + obs_distributions(hmm, control, obs) Return a vector of observation distributions, one for each state of `hmm` (possibly when `control` is applied). @@ -111,10 +112,13 @@ function obs_distributions end ## Fallbacks for no control +initialization(hmm::AbstractHMM, ::Nothing) = initialization(hmm) transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) # Is this function needed? If `log_transition_matrix(hmm, nothing)`, then `transition_matrix(hmm, nothing)` returns `transition_matrix(hmm)`. obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) -obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing}) = obs_distributions(hmm, control) +function obs_distributions(hmm::AbstractHMM, control, ::Any) + return obs_distributions(hmm, control) +end previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] @@ -158,13 +162,13 @@ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVe T = length(control_seq) dummy_log_probas = fill(-Inf, size(hmm, control_seq[1])) - init = initialization(hmm, control) + init = initialization(hmm, control_seq[1]) state_seq = Vector{Int}(undef, T) state1 = rand(rng, LightCategorical(init, dummy_log_probas)) state_seq[1] = state1 @views for t in 1:(T - 1) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. state_seq[t + 1] = rand( rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) ) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 9ecbe1ad..f9d89f16 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -90,7 +90,7 @@ function StatsAPI.fit!( sum_to_one!(hmm.init) foreach(sum_to_one!, eachrow(hmm.trans)) # Fit observations - for i in 1:length(hmm) + for i in 1:size(hmm, nothing) fit_in_sequence!(hmm.dists, i, obs_seq, view(γ, i, :)) end # Update logs diff --git a/test/discretecontrolarhmm.jl b/test/discretecontrolarhmm.jl index bf4193f1..1a5b5233 100644 --- a/test/discretecontrolarhmm.jl +++ b/test/discretecontrolarhmm.jl @@ -3,20 +3,24 @@ using HiddenMarkovModels const HMMs = HiddenMarkovModels struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true} - # Initial distribution P(X_{1}|U_{1}), one vector for each control + # Initial distribution P(X_{1}|U_{1}), one vector for each control init::Vector{Vector{T}} - # Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control + # Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control trans::Vector{Matrix{T}} - # Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation - dists::Vector{Vector{Matrix{T}}} - # Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control - prior::Vector{Matrix{T}} + # Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation + dists::Vector{Vector{Matrix{T}}} + # Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control + prior::Vector{Matrix{T}} end HMMs.initialization(hmm::DiscreteCARHMM, control) = hmm.init[control] HMMs.transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control] -HMMs.obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) = [Categorical(hmm.dists[control][prev_obs][i,:]) for i in 1:length(hmm, control)] +function HMMs.obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) + return [Categorical(hmm.dists[control][prev_obs][i, :]) for i in 1:size(hmm, control)] +end -HMMs.obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) = [Categorical(hmm.prior[control][i,:]) for i in 1:length(hmm, control)] \ No newline at end of file +function HMMs.obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) + return [Categorical(hmm.prior[control][i, :]) for i in 1:size(hmm, control)] +end From c2ee9d491717a2d46ad7c2c4776a831e391342c5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 07:38:12 +0100 Subject: [PATCH 5/8] Test on gd/ar --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 975ccb9b..e75f5890 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,6 +3,7 @@ on: push: branches: - main + - gd/ar tags: ["*"] pull_request: concurrency: From b6eb4ccbe7a7d452ca8115fc5aca2e8e4da0db9d Mon Sep 17 00:00:00 2001 From: Fausto Marques Pinheiro Junior Date: Mon, 9 Dec 2024 17:33:39 -0300 Subject: [PATCH 6/8] Third second pass for ARHMM --- examples/autodiff.jl | 5 ++--- examples/controlled.jl | 11 ++++------- examples/interfaces.jl | 2 +- examples/temporal.jl | 13 +++++-------- src/inference/chainrules.jl | 6 +++--- src/inference/forward.jl | 6 +++--- src/inference/forward_backward.jl | 6 +++--- src/inference/logdensity.jl | 8 ++++---- src/inference/viterbi.jl | 4 ++-- src/types/abstract_hmm.jl | 20 ++++++++++---------- src/types/hmm.jl | 2 +- 11 files changed, 38 insertions(+), 45 deletions(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 1f9f4e05..e498ed19 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -41,16 +41,15 @@ Both its transition matrix and its vector of observation means result from a con The coefficient $\lambda$ of this convex combination is given as a control. =# -HMMs.initialization(hmm::DiffusionHMM, λ::Number) = hmm.init HMMs.initialization(hmm::DiffusionHMM) = hmm.init function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number) - N = size(hmm.trans, 2) + N = length(hmm) return (1 - λ) * hmm.trans + λ * ones(N, N) / N end function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number) - return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:size(hmm, λ)] + return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:length(hmm)] end #= diff --git a/examples/controlled.jl b/examples/controlled.jl index 4d910efc..a0698632 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -36,19 +36,16 @@ In state $i$ with a vector of controls $u$, our observation is given by the line Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one. =# -function HMMs.initialization(hmm::ControlledGaussianHMM, control) - return hmm.init -end function HMMs.initialization(hmm::ControlledGaussianHMM) return hmm.init end -function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control) +function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector) return hmm.trans end -function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control) - return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:size(hmm, control)] +function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector) + return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)] end #= @@ -100,7 +97,7 @@ function StatsAPI.fit!( seq_ends, ) where {T} (; γ, ξ) = fb_storage - N = size(hmm, control_seq[1]) + N = length(hmm) hmm.init .= 0 hmm.trans .= 0 diff --git a/examples/interfaces.jl b/examples/interfaces.jl index cfcfa3eb..bfd03997 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -166,7 +166,7 @@ If we forget to implement this, the loglikelihood computed in Baum-Welch will be =# function DensityInterface.logdensityof(hmm::PriorHMM) - prior = Dirichlet(fill(hmm.trans_prior_count + 1, size(hmm, nothing))) + prior = Dirichlet(fill(hmm.trans_prior_count + 1, length(hmm))) return sum(logdensityof(prior, row) for row in eachrow(transition_matrix(hmm))) end diff --git a/examples/temporal.jl b/examples/temporal.jl index ef0e545f..cf5a1d1f 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -41,17 +41,14 @@ period(::PeriodicHMM{T,D,L}) where {T,D,L} = L function HMMs.initialization(hmm::PeriodicHMM) return hmm.init end -function HMMs.initialization(hmm::PeriodicHMM, control::Integer) - return hmm.init -end -function HMMs.transition_matrix(hmm::PeriodicHMM, control::Integer) - l = (control - 1) % period(hmm) + 1 +function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer) + l = (t - 1) % period(hmm) + 1 return hmm.trans_per[l] end -function HMMs.obs_distributions(hmm::PeriodicHMM, control::Integer) - l = (control - 1) % period(hmm) + 1 +function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer) + l = (t - 1) % period(hmm) + 1 return hmm.dists_per[l] end @@ -115,7 +112,7 @@ function StatsAPI.fit!( seq_ends, ) where {T} (; γ, ξ) = fb_storage - L, N = period(hmm), size(hmm, control_seq[1]) + L, N = period(hmm), length(hmm) hmm.init .= zero(T) for l in 1:L diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index c8a1bb10..8816120f 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -6,9 +6,9 @@ function _params_and_loglikelihoods( control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) - init = initialization(hmm, control_seq[1]) + init = initialization(hmm) trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t - t == 1 ? diagm(ones(size(hmm, t))) : transition_matrix(hmm, control_seq[t]) # I did't understand what this is doing, but my best guess is that it returns the transition matrix for each moment `t` to `t+1`. If this is the case, then, like forward.jl, line 106, the control variable matches `t+1`. To avoid messing up the logic, I just made the first matrix to be the identity matrix, and the following matrices are P(X_{t+1}|X_{t},U_{t+1}). + transition_matrix(hmm, control_seq[t]) end logB = mapreduce(hcat, eachindex(obs_seq, control_seq)) do t logdensityof.(obs_distributions(hmm, control_seq[t]), (obs_seq[t],)) @@ -30,7 +30,7 @@ function ChainRulesCore.rrule( fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) (; logL, α, γ, Bβ) = fb_storage - N, T = size(hmm, control_seq[1]), length(obs_seq) + N, T = length(hmm), length(obs_seq) R = eltype(α) Δinit = zeros(R, N) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index ff024e2b..59ea410d 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -54,7 +54,7 @@ function initialize_forward( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) α = Matrix{R}(undef, N, T) logL = Vector{R}(undef, K) @@ -100,10 +100,10 @@ function _forward!( αₜ = view(α, :, t) Bₜ = view(B, :, t) if t == t1 - copyto!(αₜ, initialization(hmm, control_seq[t])) + copyto!(αₜ, initialization(hmm)) else αₜ₋₁ = view(α, :, t - 1) - predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t]) # If `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`), then the associated control must be at `t+1`, right? If `control_seq[t-1]`, then we're using the control associated with the previous state and not the correct control, aren't we? The transition matrix would be P(X_{t}|X_{t-1},U_{t-1}) and not P(X_{t}|X_{t-1},U_{t}) as it should be. E.g., if `t == t1 + 1`, then `αₜ₋₁ = view(α, :, t1)` and the function would use the transition matrix P(X_{t1+1}|X_{t1},U_{t1}) instead of P(X_{t1+1}|X_{t1},U_{t1+1}). Same at logdensity.jl, line 37; forward_backward.jl, line 53. + predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) end prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t) cₜ, logLₜ = _forward_digest_observation!( diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index afe224f4..973626cf 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -8,7 +8,7 @@ function initialize_forward_backward( seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals=true, ) - N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) trans = transition_matrix(hmm, control_seq[2]) M = typeof(similar(trans, R)) @@ -50,7 +50,7 @@ function _forward_backward!( Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1) βₜ = view(β, :, t) Bβₜ₊₁ = view(Bβ, :, t + 1) - predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t + 1]) # See forward.jl, line 106. + predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t]) lmul!(c[t], βₜ) end Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1) @@ -61,7 +61,7 @@ function _forward_backward!( # Transition marginals if transition_marginals for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. + trans = transition_matrix(hmm, control_seq[t]) mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1)) end ξ[t2] .= zero(R) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index b2a2f43a..6a73a024 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -29,12 +29,12 @@ function joint_logdensityof( logL = zero(R) for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) - # Initialization: P(X_{1}|U_{1}) - init = initialization(hmm, control_seq[t1]) + # Initialization + init = initialization(hmm) logL += log(init[state_seq[t1]]) - # Transitions: P(X_{t+1}|X_{t},U_{t+1}) + # Transitions for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. + trans = transition_matrix(hmm, control_seq[t]) logL += log(trans[state_seq[t], state_seq[t + 1]]) end # Priori: P(Y_{1}|X_{1},U_{1}) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 1a21f609..6d6c5e8c 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -49,7 +49,7 @@ function _viterbi!( logBₜ₁ = view(logB, :, t1) obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing) - loginit = log_initialization(hmm, control_seq[t1]) + loginit = log_initialization(hmm) ϕ[:, t1] .= loginit .+ logBₜ₁ for t in (t1 + 1):t2 @@ -57,7 +57,7 @@ function _viterbi!( obs_logdensities!( logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t) ) - logtrans = log_transition_matrix(hmm, control_seq[t]) # See forward.jl, line 106. + logtrans = log_transition_matrix(hmm, control_seq[t - 1]) ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 8bdf5644..94049976 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -30,11 +30,12 @@ abstract type AbstractHMM{ar} end ## Interface """ - size(hmm, control) + length(hmm) Return the number of states of `hmm`. """ -Base.size(hmm::AbstractHMM, control) = size(transition_matrix(hmm, control), 2) +Base.length(hmm::AbstractHMM) = length(initialization(hmm)) + """ eltype(hmm, obs, control) @@ -61,14 +62,13 @@ function initialization end """ log_initialization(hmm) - log_initialization(hmm, control) Return the vector of initial state log-probabilities for `hmm` (possibly when `control` is applied). Falls back on `initialization`. """ -log_initialization(hmm::AbstractHMM, control) = - elementwise_log(initialization(hmm, control)) +log_initialization(hmm::AbstractHMM) = + elementwise_log(initialization(hmm)) """ transition_matrix(hmm) @@ -114,7 +114,7 @@ function obs_distributions end initialization(hmm::AbstractHMM, ::Nothing) = initialization(hmm) transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) -log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) # Is this function needed? If `log_transition_matrix(hmm, nothing)`, then `transition_matrix(hmm, nothing)` returns `transition_matrix(hmm)`. +log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) function obs_distributions(hmm::AbstractHMM, control, ::Any) return obs_distributions(hmm, control) @@ -160,15 +160,15 @@ Return a named tuple `(; state_seq, obs_seq)`. """ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector) T = length(control_seq) - dummy_log_probas = fill(-Inf, size(hmm, control_seq[1])) + dummy_log_probas = fill(-Inf, length(hmm)) - init = initialization(hmm, control_seq[1]) + init = initialization(hmm) state_seq = Vector{Int}(undef, T) state1 = rand(rng, LightCategorical(init, dummy_log_probas)) state_seq[1] = state1 @views for t in 1:(T - 1) - trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106. + trans = transition_matrix(hmm, control_seq[t]) state_seq[t + 1] = rand( rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) ) @@ -176,7 +176,7 @@ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVe dists1 = obs_distributions(hmm, control_seq[1], missing) obs1 = rand(rng, dists1[state1]) - obs_seq = Vector{typeof(obs1)}(undef, T) # If the `typeof(obs1)` is only known at runtime, does it makes any difference over Vector{Any}? + obs_seq = Vector{typeof(obs1)}(undef, T) obs_seq[1] = obs1 for t in 2:T diff --git a/src/types/hmm.jl b/src/types/hmm.jl index f9d89f16..9ecbe1ad 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -90,7 +90,7 @@ function StatsAPI.fit!( sum_to_one!(hmm.init) foreach(sum_to_one!, eachrow(hmm.trans)) # Fit observations - for i in 1:size(hmm, nothing) + for i in 1:length(hmm) fit_in_sequence!(hmm.dists, i, obs_seq, view(γ, i, :)) end # Update logs From cf3c5f70b9b150c078f7307e3538c58396de4c26 Mon Sep 17 00:00:00 2001 From: Fausto Marques Pinheiro Junior Date: Mon, 9 Dec 2024 17:43:25 -0300 Subject: [PATCH 7/8] Third pass for ARHMM --- src/inference/viterbi.jl | 2 +- src/types/abstract_hmm.jl | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 6d6c5e8c..58c9a887 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -26,7 +26,7 @@ function initialize_viterbi( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 94049976..7e56d206 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -36,7 +36,6 @@ Return the number of states of `hmm`. """ Base.length(hmm::AbstractHMM) = length(initialization(hmm)) - """ eltype(hmm, obs, control) @@ -54,21 +53,19 @@ end """ initialization(hmm) - initialization(hmm, control) -Return the vector of initial state probabilities for `hmm` (possibly when `control` is applied). +Return the vector of initial state probabilities for `hmm`. """ function initialization end """ log_initialization(hmm) -Return the vector of initial state log-probabilities for `hmm` (possibly when `control` is applied). +Return the vector of initial state log-probabilities for `hmm`. Falls back on `initialization`. """ -log_initialization(hmm::AbstractHMM) = - elementwise_log(initialization(hmm)) +log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) """ transition_matrix(hmm) From 4c1b89093b82f0c5f712f14b2a40e7f0eb5a50bb Mon Sep 17 00:00:00 2001 From: Fausto Marques Pinheiro Junior Date: Mon, 9 Dec 2024 17:51:23 -0300 Subject: [PATCH 8/8] Fix a `size` -> `length` that I've missed --- examples/interfaces.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/interfaces.jl b/examples/interfaces.jl index bfd03997..6bd7f311 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -204,7 +204,7 @@ function StatsAPI.fit!( hmm.init ./= sum(hmm.init) hmm.trans ./= sum(hmm.trans; dims=2) - for i in 1:size(hmm, nothing) + for i in 1:length(hmm) ## weigh each sample by the marginal probability of being in state i weight_seq = fb_storage.γ[i, :] ## fit observation distribution i using those weights