First pass for ARHMM #123

Open · wants to merge 8 commits into base: gd/ar
Changes from 5 commits
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -3,6 +3,7 @@ on:
push:
branches:
- main
- gd/ar
tags: ["*"]
pull_request:
concurrency:
8 changes: 5 additions & 3 deletions examples/autodiff.jl
@@ -29,7 +29,8 @@ rng = StableRNG(63);
To play around with automatic differentiation, we define a simple controlled HMM.
=#

struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: AbstractHMM
struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <:
AbstractHMM{false}
init::V1
trans::M2
means::V3
@@ -40,15 +41,16 @@ Both its transition matrix and its vector of observation means result from a convex combination.
The coefficient $\lambda$ of this convex combination is given as a control.
=#

HMMs.initialization(hmm::DiffusionHMM, λ::Number) = hmm.init
HMMs.initialization(hmm::DiffusionHMM) = hmm.init

function HMMs.transition_matrix(hmm::DiffusionHMM, λ::Number)
N = length(hmm)
N = size(hmm.trans, 2)
return (1 - λ) * hmm.trans + λ * ones(N, N) / N
end

function HMMs.obs_distributions(hmm::DiffusionHMM, λ::Number)
return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:length(hmm)]
return [Normal((1 - λ) * hmm.means[i] + λ * 0) for i in 1:size(hmm, λ)]
end

#=
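A hedged usage sketch of the revised `DiffusionHMM` (the `rand` and `logdensityof` signatures follow the package's documented controlled-HMM API; the concrete numbers are invented for illustration):

```julia
using HiddenMarkovModels, Distributions, StableRNGs

rng = StableRNG(63)
hmm = DiffusionHMM([0.6, 0.4], [0.8 0.2; 0.2 0.8], [-1.0, 1.0])

# One control λₜ per time step: λ = 0 keeps the base parameters,
# λ = 1 gives a uniform transition matrix and zero means.
control_seq = rand(rng, 100)
sim = rand(rng, hmm, control_seq)
logdensityof(hmm, sim.obs_seq, control_seq)
```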
15 changes: 9 additions & 6 deletions examples/controlled.jl
@@ -22,10 +22,10 @@ rng = StableRNG(63);

#=
A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM.
We can represent it with the following subtype of `AbstractHMM` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state.
We can represent it with the following subtype of `AbstractHMM{false}` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state.
=#

struct ControlledGaussianHMM{T} <: AbstractHMM
struct ControlledGaussianHMM{T} <: AbstractHMM{false}
init::Vector{T}
trans::Matrix{T}
dist_coeffs::Vector{Vector{T}}
@@ -36,16 +36,19 @@ In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one.
=#

function HMMs.initialization(hmm::ControlledGaussianHMM, control)
return hmm.init
end
function HMMs.initialization(hmm::ControlledGaussianHMM)
return hmm.init
end

function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control)
return hmm.trans
end

function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)]
function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control)
return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:size(hmm, control)]
end

#=
@@ -97,7 +100,7 @@ function StatsAPI.fit!(
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
N = length(hmm)
N = size(hmm, control_seq[1])

hmm.init .= 0
hmm.trans .= 0
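For orientation, a hedged end-to-end sketch of this controlled model (mirroring the example's own flow; sizes and coefficients are invented, and `baum_welch` dispatches to the custom `fit!` above):

```julia
using HiddenMarkovModels, LinearAlgebra, StableRNGs

rng = StableRNG(63)
hmm = ControlledGaussianHMM([0.5, 0.5], [0.9 0.1; 0.1 0.9], [[1.0, -1.0], [-1.0, 1.0]])

control_seq = [rand(rng, 2) for _ in 1:500]  # one control vector per step
sim = rand(rng, hmm, control_seq)

# Start Baum-Welch from a perturbed guess.
hmm_guess = ControlledGaussianHMM([0.5, 0.5], [0.8 0.2; 0.2 0.8], [randn(rng, 2), randn(rng, 2)])
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, sim.obs_seq, control_seq)
```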
10 changes: 5 additions & 5 deletions examples/interfaces.jl
@@ -138,22 +138,22 @@ test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
#=
In some scenarios, the vanilla Baum-Welch algorithm is not exactly what we want.
For instance, we might have a prior on the parameters of our model, which we want to apply during the fitting step of the iterative procedure.
Then we need to create a new type that satisfies the `AbstractHMM` interface.
Then we need to create a new type that satisfies the `AbstractHMM{ar}` interface.

Let's make a simpler version of the built-in `HMM`, with a prior saying that each transition has already been observed a certain number of times.
Such a prior can be very useful to regularize estimation and avoid numerical instabilities.
It amounts to drawing every row of the transition matrix from a Dirichlet distribution, where each Dirichlet parameter is one plus the number of times the corresponding transition has been observed.
=#

struct PriorHMM{T,D} <: AbstractHMM
struct PriorHMM{T,D} <: AbstractHMM{false}
init::Vector{T}
trans::Matrix{T}
dists::Vector{D}
trans_prior_count::Int
end

#=
The basic requirements for `AbstractHMM` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref).
The basic requirements for `AbstractHMM{false}` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref).
=#

HiddenMarkovModels.initialization(hmm::PriorHMM) = hmm.init
@@ -166,7 +166,7 @@ If we forget to implement this, the loglikelihood computed in Baum-Welch will be missing the prior term.
=#

function DensityInterface.logdensityof(hmm::PriorHMM)
prior = Dirichlet(fill(hmm.trans_prior_count + 1, length(hmm)))
prior = Dirichlet(fill(hmm.trans_prior_count + 1, size(hmm, nothing)))
return sum(logdensityof(prior, row) for row in eachrow(transition_matrix(hmm)))
end

Expand Down Expand Up @@ -204,7 +204,7 @@ function StatsAPI.fit!(
hmm.init ./= sum(hmm.init)
hmm.trans ./= sum(hmm.trans; dims=2)

for i in 1:length(hmm)
for i in 1:size(hmm, nothing)
## weigh each sample by the marginal probability of being in state i
weight_seq = fb_storage.γ[i, :]
## fit observation distribution i using those weights
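To make the pseudocount concrete, a hedged sketch of the Dirichlet-MAP step it implies inside `fit!` (whether the elided code does exactly this is an assumption):

```julia
# With hmm.trans holding expected transition counts from forward-backward,
# adding the pseudocount before row-normalizing yields the MAP estimate of
# Dirichlet(1 + trans_prior_count + counts), whose mode is ∝ (prior + counts).
hmm.trans .+= hmm.trans_prior_count
hmm.trans ./= sum(hmm.trans; dims=2)
```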
19 changes: 11 additions & 8 deletions examples/temporal.jl
@@ -23,10 +23,10 @@ rng = StableRNG(63);
#=
We focus on the particular case of a periodic HMM with period `L`.
It has only one initialization vector, but `L` transition matrices and `L` vectors of observation distributions.
As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM`.
As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM{ar}`.
=#

struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM
struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM{false}
init::Vector{T}
trans_per::NTuple{L,Matrix{T}}
dists_per::NTuple{L,Vector{D}}
@@ -41,14 +41,17 @@ period(::PeriodicHMM{T,D,L}) where {T,D,L} = L
function HMMs.initialization(hmm::PeriodicHMM)
return hmm.init
end
function HMMs.initialization(hmm::PeriodicHMM, control::Integer)
return hmm.init
end

function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer)
l = (t - 1) % period(hmm) + 1
function HMMs.transition_matrix(hmm::PeriodicHMM, control::Integer)
l = (control - 1) % period(hmm) + 1
return hmm.trans_per[l]
end

function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer)
l = (t - 1) % period(hmm) + 1
function HMMs.obs_distributions(hmm::PeriodicHMM, control::Integer)
l = (control - 1) % period(hmm) + 1
return hmm.dists_per[l]
end

@@ -100,7 +103,7 @@ vcat(obs_seq', best_state_seq')
# ## Learning

#=
When estimating parameters for a custom subtype of `AbstractHMM`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument.
When estimating parameters for a custom subtype of `AbstractHMM{false}`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument.
The key is to split the observations according to which periodic parameter they belong to.
=#

@@ -112,7 +115,7 @@ function StatsAPI.fit!(
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)
L, N = period(hmm), size(hmm, control_seq[1])

hmm.init .= zero(T)
for l in 1:L
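A hedged usage note: for this periodic model the control at time `t` is the time index itself, so `(control - 1) % L + 1` selects the within-period parameters. Reusing the example's `rng` and a `PeriodicHMM` plus guess built as above:

```julia
T = 1000
control_seq = 1:T  # the time index doubles as the control
sim = rand(rng, hmm, control_seq)
hmm_est, _ = baum_welch(hmm_guess, sim.obs_seq, control_seq; seq_ends=(T,))
```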
2 changes: 1 addition & 1 deletion src/HiddenMarkovModels.jl
@@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using DocStringExtensions
using FillArrays: Fill
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent, diagm
using Random: Random, AbstractRNG, default_rng
using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals
using StatsAPI: StatsAPI, fit, fit!
6 changes: 3 additions & 3 deletions src/inference/chainrules.jl
@@ -6,9 +6,9 @@
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
init = initialization(hmm)
init = initialization(hmm, control_seq[1])

trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t
transition_matrix(hmm, control_seq[t])
t == 1 ? diagm(ones(size(hmm, control_seq[t]))) : transition_matrix(hmm, control_seq[t]) # I didn't understand what this is doing, but my best guess is that it returns the transition matrix for each step `t` to `t+1`. If so, then, as in forward.jl, line 106, the control matches `t+1`. To avoid messing up the logic, I made the first matrix the identity, so the following matrices are P(X_{t+1}|X_{t},U_{t+1}).

end
logB = mapreduce(hcat, eachindex(obs_seq, control_seq)) do t
logdensityof.(obs_distributions(hmm, control_seq[t]), (obs_seq[t],))
@@ -30,7 +30,7 @@
fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends)
(; logL, α, γ, Bβ) = fb_storage
N, T = length(hmm), length(obs_seq)
N, T = size(hmm, control_seq[1]), length(obs_seq)

R = eltype(α)

Δinit = zeros(R, N)
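To spell out the convention the inline comment is reaching for (my reading of the change, not a confirmed design note): after the shift, the matrix stacked at slot $t$ is

$$
A^{(t)} =
\begin{cases}
I_N, & t = 1,\\
P(X_t \mid X_{t-1}, U_t), & t > 1,
\end{cases}
$$

so every slot maps the state marginal at $t-1$ to the one at $t$, and the identity pads the slot with no incoming transition.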
14 changes: 9 additions & 5 deletions src/inference/forward.jl
@@ -54,7 +54,7 @@ function initialize_forward(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
α = Matrix{R}(undef, N, T)
logL = Vector{R}(undef, K)
@@ -69,10 +69,11 @@ function _forward_digest_observation!(
hmm::AbstractHMM,
obs,
control,
prev_obs,
)
a, b = current_state_marginals, current_obs_likelihoods

obs_logdensities!(b, hmm, obs, control)
obs_logdensities!(b, hmm, obs, control, prev_obs)
logm = maximum(b)
b .= exp.(b .- logm)

@@ -99,12 +100,15 @@ function _forward!(
αₜ = view(α, :, t)
Bₜ = view(B, :, t)
if t == t1
copyto!(αₜ, initialization(hmm))
copyto!(αₜ, initialization(hmm, control_seq[t]))
else
αₜ₋₁ = view(α, :, t - 1)
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1])
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t]) # If the control influences the transition *into* time `t` (and not out of it), the control passed here must be the one at `t`. With `control_seq[t - 1]` we would be using the previous state's control: the transition matrix would be P(X_{t}|X_{t-1},U_{t-1}) and not P(X_{t}|X_{t-1},U_{t}) as it should be. E.g., if `t == t1 + 1`, then `αₜ₋₁ = view(α, :, t1)` and the function would use P(X_{t1+1}|X_{t1},U_{t1}) instead of P(X_{t1+1}|X_{t1},U_{t1+1}). Same at logdensity.jl, line 37; forward_backward.jl, line 53.
end
cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t])
prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t)
cₜ, logLₜ = _forward_digest_observation!(
αₜ, Bₜ, hmm, obs_seq[t], control_seq[t], prev_obs
)
c[t] = cₜ
logL[k] += logLₜ
end
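This hunk calls a `previous_obs` helper whose definition is not shown; a plausible sketch, dispatching on the autoregressive type parameter of `AbstractHMM{ar}` (an assumption about the design, not the PR's actual code):

```julia
# Autoregressive models condition the emission at time t on obs_seq[t - 1];
# ordinary HMMs have nothing to condition on, signalled by `missing`.
previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1]
previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = missing
```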
8 changes: 4 additions & 4 deletions src/inference/forward_backward.jl
@@ -8,9 +8,9 @@ function initialize_forward_backward(
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals=true,
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
trans = transition_matrix(hmm, control_seq[1])
trans = transition_matrix(hmm, control_seq[2])
M = typeof(similar(trans, R))

γ = Matrix{R}(undef, N, T)
@@ -50,7 +50,7 @@ function _forward_backward!(
Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1)
βₜ = view(β, :, t)
Bβₜ₊₁ = view(Bβ, :, t + 1)
predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t])
predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t + 1]) # See forward.jl, line 106.
lmul!(c[t], βₜ)
end
Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1)
@@ -61,7 +61,7 @@
# Transition marginals
if transition_marginals
for t in t1:(t2 - 1)
trans = transition_matrix(hmm, control_seq[t])
trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106.
mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1))
end
ξ[t2] .= zero(R)
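For reference, the transition marginal assembled by `mul_rows_cols!` under the shifted convention is (standard forward-backward algebra, my transcription):

$$
\xi_t(i, j) = \alpha_t(i)\, P(X_{t+1} = j \mid X_t = i, U_{t+1})\, B_{t+1}(j)\, \beta_{t+1}(j).
$$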
17 changes: 10 additions & 7 deletions src/inference/logdensity.jl
Original file line number Diff line number Diff line change
@@ -29,17 +29,20 @@
logL = zero(R)
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
# Initialization
init = initialization(hmm)
# Initialization: P(X_{1}|U_{1})
init = initialization(hmm, control_seq[t1])

logL += log(init[state_seq[t1]])
# Transitions
# Transitions: P(X_{t+1}|X_{t},U_{t+1})
for t in t1:(t2 - 1)
trans = transition_matrix(hmm, control_seq[t])
trans = transition_matrix(hmm, control_seq[t + 1]) # See forward.jl, line 106.

logL += log(trans[state_seq[t], state_seq[t + 1]])
end
# Observations
for t in t1:t2
dists = obs_distributions(hmm, control_seq[t])
# First observation: P(Y_{1}|X_{1},U_{1})
dists = obs_distributions(hmm, control_seq[t1], missing)
logL += logdensityof(dists[state_seq[t1]], obs_seq[t1])

# Observations: P(Y_{t}|Y_{t-1},X_{t},U_{t})
for t in (t1 + 1):t2
dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t))

logL += logdensityof(dists[state_seq[t]], obs_seq[t])
end
end
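Collecting the three comments, the quantity this loop accumulates for a single sequence ($t_1 = 1$, $t_2 = T$) is

$$
\log L = \log P(X_1 \mid U_1)
+ \sum_{t=1}^{T-1} \log P(X_{t+1} \mid X_t, U_{t+1})
+ \log P(Y_1 \mid X_1, U_1)
+ \sum_{t=2}^{T} \log P(Y_t \mid Y_{t-1}, X_t, U_t).
$$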
12 changes: 7 additions & 5 deletions src/inference/viterbi.jl
@@ -26,7 +26,7 @@ function initialize_viterbi(
control_seq::AbstractVector;
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
N, T, K = size(hmm, control_seq[1]), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
q = Vector{Int}(undef, T)
logL = Vector{R}(undef, K)
@@ -48,14 +48,16 @@ function _viterbi!(
t1, t2 = seq_limits(seq_ends, k)

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1])
loginit = log_initialization(hmm)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing)
loginit = log_initialization(hmm, control_seq[t1])
ϕ[:, t1] .= loginit .+ logBₜ₁

for t in (t1 + 1):t2
logBₜ = view(logB, :, t)
obs_logdensities!(logBₜ, hmm, obs_seq[t], control_seq[t])
logtrans = log_transition_matrix(hmm, control_seq[t - 1])
obs_logdensities!(
logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t)
)
logtrans = log_transition_matrix(hmm, control_seq[t]) # See forward.jl, line 106.
ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1)
ψₜ = view(ψ, :, t)
argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁)
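Under the same alignment, the recursion these lines implement is (my transcription of the code):

$$
\phi_t(j) = \log P(Y_t \mid Y_{t-1}, X_t = j, U_t) + \max_i \big[\, \log P(X_t = j \mid X_{t-1} = i, U_t) + \phi_{t-1}(i) \,\big],
$$

with $\psi_t(j)$ recording the argmax for the backtracking pass.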