Skip to content

Commit

Permalink
Force controls everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Feb 24, 2024
1 parent 64262cd commit 58da66e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.4.2"
version = "0.5.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
5 changes: 3 additions & 2 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ struct ControlledGaussianHMM{T} <: AbstractHMM
end

#=
In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one.
=#

function HMMs.initialization(hmm::ControlledGaussianHMM)
return hmm.init
end

function HMMs.transition_matrix(hmm::ControlledGaussianHMM)
function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
return hmm.trans
end

Expand Down
5 changes: 3 additions & 2 deletions src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ function obs_distributions end

## Fallbacks for no control

transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm)
obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm)
transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm)
log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm)
obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm)

"""
StatsAPI.fit!(
Expand Down

0 comments on commit 58da66e

Please sign in to comment.