Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass for ARHMM #123

Open
wants to merge 8 commits into
base: gd/ar
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ on:
push:
branches:
- main
- gd/ar
tags: ["*"]
pull_request:
concurrency:
Expand Down
3 changes: 2 additions & 1 deletion examples/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ rng = StableRNG(63);
To play around with automatic differentiation, we define a simple controlled HMM.
=#

struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <: AbstractHMM
struct DiffusionHMM{V1<:AbstractVector,M2<:AbstractMatrix,V3<:AbstractVector} <:
AbstractHMM{false}
init::V1
trans::M2
means::V3
Expand Down
4 changes: 2 additions & 2 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ rng = StableRNG(63);

#=
A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM.
We can represent it with the following subtype of `AbstractHMM` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state.
We can represent it with the following subtype of `AbstractHMM{false}` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state.
=#

struct ControlledGaussianHMM{T} <: AbstractHMM
struct ControlledGaussianHMM{T} <: AbstractHMM{false}
init::Vector{T}
trans::Matrix{T}
dist_coeffs::Vector{Vector{T}}
Expand Down
6 changes: 3 additions & 3 deletions examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,22 @@ test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
#=
In some scenarios, the vanilla Baum-Welch algorithm is not exactly what we want.
For instance, we might have a prior on the parameters of our model, which we want to apply during the fitting step of the iterative procedure.
Then we need to create a new type that satisfies the `AbstractHMM` interface.
Then we need to create a new type that satisfies the `AbstractHMM{ar}` interface.

Let's make a simpler version of the built-in `HMM`, with a prior saying that each transition has already been observed a certain number of times.
Such a prior can be very useful to regularize estimation and avoid numerical instabilities.
It amounts to drawing every row of the transition matrix from a Dirichlet distribution, where each Dirichlet parameter is one plus the number of times the corresponding transition has been observed.
=#

struct PriorHMM{T,D} <: AbstractHMM
struct PriorHMM{T,D} <: AbstractHMM{false}
init::Vector{T}
trans::Matrix{T}
dists::Vector{D}
trans_prior_count::Int
end

#=
The basic requirements for `AbstractHMM` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref).
The basic requirements for `AbstractHMM{false}` are the following three functions: [`initialization`](@ref), [`transition_matrix`](@ref) and [`obs_distributions`](@ref).
=#

HiddenMarkovModels.initialization(hmm::PriorHMM) = hmm.init
Expand Down
6 changes: 3 additions & 3 deletions examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ rng = StableRNG(63);
#=
We focus on the particular case of a periodic HMM with period `L`.
It has only one initialization vector, but `L` transition matrices and `L` vectors of observation distributions.
As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM`.
As in [Custom HMM structures](@ref), we need to subtype `AbstractHMM{ar}`.
=#

struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM
struct PeriodicHMM{T<:Number,D,L} <: AbstractHMM{false}
init::Vector{T}
trans_per::NTuple{L,Matrix{T}}
dists_per::NTuple{L,Vector{D}}
Expand Down Expand Up @@ -100,7 +100,7 @@ vcat(obs_seq', best_state_seq')
# ## Learning

#=
When estimating parameters for a custom subtype of `AbstractHMM`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument.
When estimating parameters for a custom subtype of `AbstractHMM{false}`, we have to override the fitting procedure after forward-backward, with an additional `control_seq` positional argument.
The key is to split the observations according to which periodic parameter they belong to.
=#

Expand Down
2 changes: 1 addition & 1 deletion src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using DocStringExtensions
using FillArrays: Fill
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent, diagm
using Random: Random, AbstractRNG, default_rng
using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals
using StatsAPI: StatsAPI, fit, fit!
Expand Down
8 changes: 6 additions & 2 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ function _forward_digest_observation!(
hmm::AbstractHMM,
obs,
control,
prev_obs,
)
a, b = current_state_marginals, current_obs_likelihoods

obs_logdensities!(b, hmm, obs, control)
obs_logdensities!(b, hmm, obs, control, prev_obs)
logm = maximum(b)
b .= exp.(b .- logm)

Expand Down Expand Up @@ -104,7 +105,10 @@ function _forward!(
αₜ₋₁ = view(α, :, t - 1)
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1])
end
cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t])
prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t)
cₜ, logLₜ = _forward_digest_observation!(
αₜ, Bₜ, hmm, obs_seq[t], control_seq[t], prev_obs
)
c[t] = cₜ
logL[k] += logLₜ
end
Expand Down
2 changes: 1 addition & 1 deletion src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function initialize_forward_backward(
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
trans = transition_matrix(hmm, control_seq[1])
trans = transition_matrix(hmm, control_seq[2])
M = typeof(similar(trans, R))

γ = Matrix{R}(undef, N, T)
Expand Down
9 changes: 6 additions & 3 deletions src/inference/logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ function joint_logdensityof(
trans = transition_matrix(hmm, control_seq[t])
logL += log(trans[state_seq[t], state_seq[t + 1]])
end
# Observations
for t in t1:t2
dists = obs_distributions(hmm, control_seq[t])
# Priori: P(Y_{1}|X_{1},U_{1})
dists = obs_distributions(hmm, control_seq[t1], missing)
logL += logdensityof(dists[state_seq[t1]], obs_seq[t1])
# Observations: P(Y_{t}|Y_{t-1},X_{t},U_{t})
for t in (t1 + 1):t2
dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t))
logL += logdensityof(dists[state_seq[t]], obs_seq[t])
end
end
Expand Down
6 changes: 4 additions & 2 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ function _viterbi!(
t1, t2 = seq_limits(seq_ends, k)

logBₜ₁ = view(logB, :, t1)
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1])
obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing)
loginit = log_initialization(hmm)
ϕ[:, t1] .= loginit .+ logBₜ₁

for t in (t1 + 1):t2
logBₜ = view(logB, :, t)
obs_logdensities!(logBₜ, hmm, obs_seq[t], control_seq[t])
obs_logdensities!(
logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t)
)
logtrans = log_transition_matrix(hmm, control_seq[t - 1])
ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1)
ψₜ = view(ψ, :, t)
Expand Down
27 changes: 17 additions & 10 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- [`forward_backward`](@ref)
- [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented)
"""
abstract type AbstractHMM end
abstract type AbstractHMM{ar} end

@inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity()

Expand All @@ -46,8 +46,8 @@
function Base.eltype(hmm::AbstractHMM, obs, control)
init_type = eltype(initialization(hmm))
trans_type = eltype(transition_matrix(hmm, control))
dist = obs_distributions(hmm, control)[1]
logdensity_type = typeof(logdensityof(dist, obs))
dists = obs_distributions(hmm, control, obs)
logdensity_type = typeof(logdensityof(dists[1], obs))
return promote_type(init_type, trans_type, logdensity_type)
end

Expand Down Expand Up @@ -89,13 +89,13 @@
!!! note
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
"""
function log_transition_matrix(hmm::AbstractHMM, control)
return elementwise_log(transition_matrix(hmm, control))
end
log_transition_matrix(hmm::AbstractHMM, control) =
elementwise_log(transition_matrix(hmm, control))

"""
obs_distributions(hmm)
obs_distributions(hmm, control)
obs_distributions(hmm, control, obs)

Return a vector of observation distributions, one for each state of `hmm` (possibly when `control` is applied).

Expand All @@ -109,9 +109,16 @@

## Fallbacks for no control

initialization(hmm::AbstractHMM, ::Nothing) = initialization(hmm)

Check warning on line 112 in src/types/abstract_hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/abstract_hmm.jl#L112

Added line #L112 was not covered by tests
transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm)
log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm)
obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm)
function obs_distributions(hmm::AbstractHMM, control, ::Any)
return obs_distributions(hmm, control)
end

previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing
previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1]

Check warning on line 121 in src/types/abstract_hmm.jl

View check run for this annotation

Codecov / codecov/patch

src/types/abstract_hmm.jl#L121

Added line #L121 was not covered by tests

"""
StatsAPI.fit!(
Expand All @@ -128,9 +135,9 @@
## Fill logdensities

function obs_logdensities!(
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs
) where {T}
dists = obs_distributions(hmm, control)
dists = obs_distributions(hmm, control, prev_obs)
@simd for i in eachindex(logb, dists)
logb[i] = logdensityof(dists[i], obs)
end
Expand Down Expand Up @@ -164,13 +171,13 @@
)
end

dists1 = obs_distributions(hmm, control_seq[1])
dists1 = obs_distributions(hmm, control_seq[1], missing)
obs1 = rand(rng, dists1[state1])
obs_seq = Vector{typeof(obs1)}(undef, T)
obs_seq[1] = obs1

for t in 2:T
dists = obs_distributions(hmm, control_seq[t])
dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t))
obs_seq[t] = rand(rng, dists[state_seq[t]])
end
return (; state_seq=state_seq, obs_seq=obs_seq)
Expand Down
2 changes: 1 addition & 1 deletion src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct HMM{
VD<:AbstractVector,
Vl<:AbstractVector,
Ml<:AbstractMatrix,
} <: AbstractHMM
} <: AbstractHMM{false}
"initial state probabilities"
init::V
"state transition probabilities"
Expand Down
20 changes: 20 additions & 0 deletions test/autoregressive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Distributions
using HiddenMarkovModels
const HMMs = HiddenMarkovModels

struct AutoRegressiveGaussianHMM{T} <: AbstractHMM{true}
init::Vector{T}
trans::Matrix{T}
a::Vector{T}
b::Vector{T}
end

const ARGHMM = AutoRegressiveGaussianHMM

HMMs.initialization(hmm::ARGHMM) = hmm.init
HMMs.transition_matrix(hmm::ARGHMM) = hmm.trans

function HMMs.obs_distributions(hmm::ARGHMM, _control, prev_obs)
(; a, b) = hmm
return [Normal(a[i] * prev_obs + b[i], 1.0) for i in 1:length(hmm)]
end
26 changes: 26 additions & 0 deletions test/discretecontrolarhmm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using Distributions
using HiddenMarkovModels
const HMMs = HiddenMarkovModels

struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true}
# Initial distribution P(X_{1}|U_{1}), one vector for each control
init::Vector{Vector{T}}
# Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control
trans::Vector{Matrix{T}}
# Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation
dists::Vector{Vector{Matrix{T}}}
# Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control
prior::Vector{Matrix{T}}
end

HMMs.initialization(hmm::DiscreteCARHMM, control) = hmm.init[control]

HMMs.transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control]

function HMMs.obs_distributions(hmm::DiscreteCARHMM, control, prev_obs)
return [Categorical(hmm.dists[control][prev_obs][i, :]) for i in 1:size(hmm, control)]
end

function HMMs.obs_distributions(hmm::DiscreteCARHMM, control, ::Missing)
return [Categorical(hmm.prior[control][i, :]) for i in 1:size(hmm, control)]
end
Loading