diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 59ea410..31dbee5 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -54,7 +54,7 @@ function initialize_forward( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) α = Matrix{R}(undef, N, T) logL = Vector{R}(undef, K) @@ -100,7 +100,7 @@ function _forward!( αₜ = view(α, :, t) Bₜ = view(B, :, t) if t == t1 - copyto!(αₜ, initialization(hmm)) + copyto!(αₜ, initialization(hmm, control_seq[t])) else αₜ₋₁ = view(α, :, t - 1) predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index ba51ad4..07efd71 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -8,7 +8,7 @@ function initialize_forward_backward( seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals=true, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) trans = transition_matrix(hmm, control_seq[1]) M = typeof(similar(trans, R)) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index 7a174be..b38a350 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -30,7 +30,7 @@ function joint_logdensityof( for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) # Initialization - init = initialization(hmm) + init = initialization(hmm, control_seq[t1]) logL += log(init[state_seq[t1]]) # Transitions for t in t1:(t2 - 1) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 58c9a88..61e3391 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -26,7 +26,7 @@ function initialize_viterbi( control_seq::AbstractVector; seq_ends::AbstractVectorOrNTuple{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) @@ -49,7 +49,7 @@ function _viterbi!( logBₜ₁ = view(logB, :, t1) obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing) - loginit = log_initialization(hmm) + loginit = log_initialization(hmm, control_seq[t1]) ϕ[:, t1] .= loginit .+ logBₜ₁ for t in (t1 + 1):t2 diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 13a64b8..da1f9b3 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -1,3 +1,6 @@ +############################################################################ +# 1. TYPE # +############################################################################ """ AbstractHMM @@ -23,19 +26,29 @@ Any `AbstractHMM` which satisfies the interface can be given to the following fu - [`forward_backward`](@ref) - [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented) """ -abstract type AbstractHMM{ar} end +abstract type AbstractHMM{T} end @inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity() -## Interface +############################################################################ +# 2. INTERFACE # +############################################################################ + +#------------------------------ 2.1. length -------------------------------# """ length(hmm) Return the number of states of `hmm`. """ Base.length(hmm::AbstractHMM) = length(initialization(hmm)) + +Base.length(hmm::AbstractHMM, control) = length(initialization(hmm, control)) + +Base.length(hmm::AbstractHMM, ::Nothing) = length(initialization(hmm)) + +#------------------------------ 2.2. eltype -------------------------------# """ eltype(hmm, obs, control) @@ -44,13 +57,15 @@ Return a type that can accommodate forward-backward computations for `hmm` on ob It is typically a promotion between the element type of the initialization, the element type of the transition matrix, and the type of an observation logdensity evaluated at `obs`. """ function Base.eltype(hmm::AbstractHMM, obs, control) - init_type = eltype(initialization(hmm)) - trans_type = eltype(transition_matrix(hmm, control)) - dist = obs_distributions(hmm, control, obs)[1] - logdensity_type = typeof(logdensityof(dist, obs)) - return promote_type(init_type, trans_type, logdensity_type) + init_type = eltype(initialization(hmm, control)) + trans_type = eltype(transition_matrix(hmm, control)) + dist = obs_distributions(hmm, control, obs)[1] + logdensity_type = typeof(logdensityof(dist, obs)) + return promote_type(init_type, trans_type, logdensity_type) end + +#--------------------------- 2.3. initialization --------------------------# """ initialization(hmm) @@ -58,6 +73,14 @@ Return the vector of initial state probabilities for `hmm`. """ function initialization end +initialization(hmm::AbstractHMM) = hmm.init + +initialization(hmm::AbstractHMM, control) = hmm.init + +initialization(hmm::AbstractHMM, ::Nothing) = hmm.init + + +#------------------------ 2.4. log_initialization -------------------------# """ log_initialization(hmm) @@ -66,7 +89,13 @@ Return the vector of initial state log-probabilities for `hmm`. Falls back on `initialization`. """ log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) + +log_initialization(hmm::AbstractHMM, control) = elementwise_log(initialization(hmm, control)) + +log_initialization(hmm::AbstractHMM, ::Nothing) = elementwise_log(initialization(hmm)) + +#------------------------- 2.5. transition_matrix -------------------------# """ transition_matrix(hmm) transition_matrix(hmm, control) @@ -74,10 +103,18 @@ log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied). !!! note - When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ function transition_matrix end +transition_matrix(hmm::AbstractHMM) = hmm.trans + +transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) + +transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) + + +#----------------------- 2.6. log_transition_matrix -----------------------# """ log_transition_matrix(hmm) log_transition_matrix(hmm, control) @@ -89,10 +126,15 @@ Falls back on `transition_matrix`. !!! note When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ -function log_transition_matrix(hmm::AbstractHMM, control) - return elementwise_log(transition_matrix(hmm, control)) -end +log_transition_matrix(hmm::AbstractHMM) = elementwise_log(transition_matrix(hmm)) + +log_transition_matrix(hmm::AbstractHMM, control) = elementwise_log(transition_matrix(hmm, control)) + +log_transition_matrix(hmm::AbstractHMM, ::Nothing) = elementwise_log(transition_matrix(hmm)) + + +#------------------------- 2.7. obs_distributions -------------------------# """ obs_distributions(hmm) obs_distributions(hmm, control) @@ -107,18 +149,27 @@ These distribution objects should implement """ function obs_distributions end -## Fallbacks for no control +obs_distributions(hmm::AbstractHMM) = [hmm.dists[i] for i ∈ 1:length(hmm)] -transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) -log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) +obs_distributions(hmm::AbstractHMM, control, prev_obs) = obs_distributions(hmm, control, prev_obs) + +### Fallback when it is not autoregressive and there is no control obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) -function obs_distributions(hmm::AbstractHMM, control, ::Union{Nothing,Missing}) - return obs_distributions(hmm, control) -end +### Fallback when it is autoregressive, but previous observation is Missing +obs_distributions(hmm::AbstractHMM, control, ::Missing) = obs_distributions(hmm, control) + +### Fallback when it is autoregressive and there is no control, but observation is missing +obs_distributions(hmm::AbstractHMM, ::Nothing, ::Missing) = obs_distributions(hmm) + + +#---------------------------- 2.8. previous_obs ---------------------------# previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing + previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] + +#--------------------------- 2.9. StatsAPI.fit! ---------------------------# """ StatsAPI.fit!( hmm, fb_storage::ForwardBackwardStorage, @@ -131,21 +182,25 @@ This function is allowed to reuse `fb_storage` as a scratch space, so its conten """ StatsAPI.fit! -## Fill logdensities - + +#------------------------- 2.10. obs_logdensities! ------------------------# function obs_logdensities!( - logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs -) where {T} - dists = obs_distributions(hmm, control, prev_obs) - @simd for i in eachindex(logb, dists) - logb[i] = logdensityof(dists[i], obs) - end - @argcheck maximum(logb) < typemax(T) - return nothing + logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs + ) where {T} + dists = obs_distributions(hmm, control, prev_obs) + @simd for i in eachindex(logb, dists) + logb[i] = logdensityof(dists[i], obs) + end + @argcheck maximum(logb) < typemax(T) + return nothing end -## Sampling - + +############################################################################ +# 3. SAMPLING # +############################################################################ + +# <------------- Didn't touch it yet! """ rand([rng,] hmm, T) rand([rng,] hmm, control_seq) @@ -155,50 +210,93 @@ Simulate `hmm` for `T` time steps, or when the sequence `control_seq` is applied Return a named tuple `(; state_seq, obs_seq)`. """ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector) - T = length(control_seq) - dummy_log_probas = fill(-Inf, length(hmm)) - - init = initialization(hmm) - state_seq = Vector{Int}(undef, T) - state1 = rand(rng, LightCategorical(init, dummy_log_probas)) - state_seq[1] = state1 - - @views for t in 1:(T - 1) - trans = transition_matrix(hmm, control_seq[t]) - state_seq[t + 1] = rand( - rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) - ) - end - - dists1 = obs_distributions(hmm, control_seq[1], missing) - obs1 = rand(rng, dists1[state1]) - obs_seq = Vector{typeof(obs1)}(undef, T) - obs_seq[1] = obs1 - - for t in 2:T - dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) - obs_seq[t] = rand(rng, dists[state_seq[t]]) - end - return (; state_seq=state_seq, obs_seq=obs_seq) + T = length(control_seq) + dummy_log_probas = fill(-Inf, length(hmm)) + + init = initialization(hmm, control) + state_seq = Vector{Int}(undef, T) + state1 = rand(rng, LightCategorical(init, dummy_log_probas)) + state_seq[1] = state1 + + @views for t in 1:(T - 1) + trans = transition_matrix(hmm, control_seq[t]) + state_seq[t + 1] = rand( + rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) + ) + end + + dists1 = obs_distributions(hmm, control_seq[1], missing) + obs1 = rand(rng, dists1[state1]) + obs_seq = Vector{typeof(obs1)}(undef, T) + obs_seq[1] = obs1 + + for t in 2:T + dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) + obs_seq[t] = rand(rng, dists[state_seq[t]]) + end + return (; state_seq=state_seq, obs_seq=obs_seq) end function Random.rand(hmm::AbstractHMM, control_seq::AbstractVector) - return rand(default_rng(), hmm, control_seq) + return rand(default_rng(), hmm, control_seq) end function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) - return rand(rng, hmm, Fill(nothing, T)) + return rand(rng, hmm, Fill(nothing, T)) end function Random.rand(hmm::AbstractHMM, T::Integer) - return rand(hmm, Fill(nothing, T)) + return rand(hmm, Fill(nothing, T)) end -## Prior - + +############################################################################ +# 4. PRIOR # +############################################################################ """ logdensityof(hmm) Return the prior loglikelihood associated with the parameters of `hmm`. """ DensityInterface.logdensityof(hmm::AbstractHMM) = false + + +############################################################################ +# 5. ARHMM EXAMPLE # +############################################################################ +## Test scruct for Discrete ARHMM with control +struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true} + # Initial distribution P(X_{1}|U_{1}), one vector for each control + init::Vector{Vector{T}} + # Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control + trans::Vector{Matrix{T}} + # Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation + dists::Vector{Vector{Matrix{T}}} + # Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control + prior::Vector{Matrix{T}} +end + +initialization(hmm::DiscreteCARHMM, control) = hmm.init[control] + +transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control] + +obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) = [Categorical(hmm.dists[control][prev_obs][i,:]) for i in 1:length(hmm, control)] + +obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) = [Categorical(hmm.prior[control][i,:]) for i in 1:length(hmm, control)] + + +## Test scruct for Discrete ARHMM without control +struct DiscreteARHMM{T<:Number} <: AbstractHMM{true} + # Initial distribution P(X_{1}) + init::Vector{T} + # Transition matrix P(X_{t}|X_{t-1}) + trans::Matrix{T} + # Emission matriz P(Y_{t}|X_{t}) + dists::Vector{Matrix{T}} + # Prior Distribution for P(Y_{1}|X_{1}) + prior::Matrix{T} +end + +obs_distributions(hmm::DiscreteARHMM, ::Nothing, prev_obs) = [Categorical(hmm.dists[prev_obs][i,:]) for i in 1:length(hmm)] + +obs_distributions(hmm::DiscreteARHMM, ::Nothing, ::Missing) = [Categorical(hmm.prior[i,:]) for i in 1:length(hmm)] diff --git a/test/arhmm_testing_plutonotebook.jl b/test/arhmm_testing_plutonotebook.jl new file mode 100644 index 0000000..4e4ab82 --- /dev/null +++ b/test/arhmm_testing_plutonotebook.jl @@ -0,0 +1,1638 @@ +### A Pluto.jl notebook ### +# v0.20.3 + +using Markdown +using InteractiveUtils + +# ╔═╡ 455866a4-a5ef-11ef-04ff-f347c85da0d8 +begin + using ArgCheck: @argcheck + using Base: RefValue + using Base.Threads: @threads + using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad + using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof + using DocStringExtensions + using FillArrays: Fill + using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent + using Random: Random, AbstractRNG, default_rng + using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals + using StatsAPI: StatsAPI, fit, fit! + using StatsFuns: log2π +end + +# ╔═╡ b8b4c6ec-a8f3-4609-a0c0-15652c0d15be +begin + using Distributions + using PlutoUI + TableOfContents() +end + +# ╔═╡ a5c89117-9642-4c28-bf5e-7c93184934b4 +md""" +## 1. abstract_hmm.jl +""" + +# ╔═╡ 5d513c86-0595-4fee-8312-9927ea731fde +md""" +## 2. linalg.jl +""" + +# ╔═╡ b6abd844-345b-4c36-b6c0-699fa5c45fd9 +md""" +## 3. limits.jl +""" + +# ╔═╡ ebfa27fc-bdb0-4aaa-8930-85025df33b7c +md""" +## 4. forward.jl +""" + +# ╔═╡ f2599655-855a-4c41-a6ed-c22c16d792a1 +md""" +## 5. logdensity.jl +""" + +# ╔═╡ e35f3103-c5ec-4a6f-bfc0-76326f8eed54 +md""" +## 6. forward_backward.jl.jl +""" + +# ╔═╡ 71193717-7f4e-4db6-9669-149e7d6405af +md""" +## 7. viterbi.jl +""" + +# ╔═╡ ffb943ac-b403-4f71-a19b-0875748674ed +md""" +## 8. Test +""" + +# ╔═╡ 3a00c710-e6e6-4b5e-875e-7bbcb173d8a3 +# Supposing control Uₜ, |Uₜ| = 2, state Xₜ, |Xₜ| = 2, and observation Yₜ, |Yₜ| = 3 + +# ╔═╡ 8d99e582-fed1-42bd-8196-5f73cffe5a05 +# Discrete ARHMM without control + +# ╔═╡ c3ed40cc-0921-411e-aa8b-547d35fa1842 +uᵢ = [0.5, 0.5] + +# ╔═╡ c24abed8-65bc-4bb8-9444-64377c90be21 +tᵢⱼ = [0.3 0.7; 0.4 0.6] + +# ╔═╡ 4b6ecf0d-6330-4fc5-b31a-f49bbc03d397 +eʳᵢⱼ = [[0.25 0.5 0.25; 0.4 0.4 0.2], [0.7 0.2 0.1; 0.3 0.3 0.4], [0.4 0.2 0.4; 0.1 0.1 0.8]] + +# ╔═╡ 1c825864-2332-4e99-8d3a-4583f4fb56e2 +pᵢⱼ = [0.4 0.4 0.2; 0.3 0.3 0.4] + +# ╔═╡ 3b669bd8-21e1-42d5-b2d6-45340c57e1e2 +# Discrete ARHMM with control + +# ╔═╡ dfa9f848-b928-4299-8347-2ac492958df7 +uᵏᵢ = [[0.1, 0.9], [0.5, 0.5]] + +# ╔═╡ b454b317-29a9-41ec-82a7-9c48a88f2576 +tᵏᵢⱼ = [[0.3 0.7; 0.4 0.6], [0.75 0.25; 0.5 0.5]] + +# ╔═╡ b1377b4b-ce15-4eb8-83cb-740a58673ac6 +eᵏʳᵢⱼ = [[[0.25 0.5 0.25; 0.4 0.4 0.2], [0.7 0.2 0.1; 0.3 0.3 0.4], [0.4 0.2 0.4; 0.1 0.1 0.8]], [[0.3 0.2 0.5; 0.2 0.4 0.4], [0.5 0.25 0.25; 0.5 0.4 0.1], [0.25 0.25 0.5; 0.2 0.6 0.2]]] + +# ╔═╡ e7acfa37-371b-437c-aa78-47e37500d414 +begin + nstates = 2 + ncontrol = 2 + nobservations = 3 + N = 20 + true_control_seq = rand([1, 2], N) + true_state_seq = Int64[] + push!(true_state_seq, rand(Categorical(uᵏᵢ[true_control_seq[1]]))) + for i ∈ 2:N + push!(true_state_seq, rand(Categorical(tᵏᵢⱼ[true_control_seq[i]][true_state_seq[i-1],:]))) + end + true_obs_seq = Int64[] + push!(true_obs_seq, rand(Categorical(eᵏʳᵢⱼ[true_control_seq[1]][rand([1,2,3])][true_state_seq[1],:]))) + for i ∈ 2:N + push!(true_obs_seq, rand(Categorical(eᵏʳᵢⱼ[true_control_seq[i]][true_obs_seq[i-1]][true_state_seq[i],:]))) + end +end + +# ╔═╡ 4d76c441-429f-4500-9c26-6247abad17ed +begin + + const AbstractVectorOrNTuple{T} = Union{AbstractVector{T},NTuple{N,T}} where {N} + + sum_to_one!(x) = ldiv!(sum(x), x) + + mynonzeros(x::AbstractArray) = x + mynonzeros(x::AbstractSparseArray) = nonzeros(x) + + mynnz(x::AbstractArray) = length(mynonzeros(x)) + + elementwise_log(x::AbstractArray) = log.(x) + + function elementwise_log(A::SparseMatrixCSC) + return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval)) + end + + """ + mul_rows_cols!(B, l, A, r) + + Perform the in-place operation `B .= l .* A .* transpose(r)`. + """ + function mul_rows_cols!( + B::AbstractMatrix, l::AbstractVector, A::AbstractMatrix, r::AbstractVector + ) + B .= l .* A .* transpose(r) + return B + end + + function mul_rows_cols!( + B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector + ) + @argcheck axes(A, 1) == eachindex(r) + @argcheck axes(A, 2) == eachindex(l) + @argcheck size(A) == size(B) + @argcheck nnz(B) == nnz(A) + Brv = rowvals(B) + Bnz = nonzeros(B) + Anz = nonzeros(A) + @simd for j in axes(B, 2) + @argcheck nzrange(B, j) == nzrange(A, j) + @simd for k in nzrange(B, j) + i = Brv[k] + Bnz[k] = l[i] * Anz[k] * r[j] + end + end + return B + end + + """ + argmaxplus_transmul!(y, ind, A, x) + + Perform the in-place multiplication `transpose(A) * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each component of `y` in `ind`. + """ + function argmaxplus_transmul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::AbstractMatrix, + x::AbstractVector, + ) where {R} + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for i in axes(A, 1) + z = A[i, j] + x[i] + if z > y[j] + y[j] = z + ind[j] = i + end + end + end + return y + end + + function argmaxplus_transmul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::SparseMatrixCSC, + x::AbstractVector, + ) where {R} + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) + Anz = nonzeros(A) + Arv = rowvals(A) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for k in nzrange(A, j) + i = Arv[k] + z = Anz[k] + x[i] + if z > y[j] + y[j] = z + ind[j] = i + end + end + end + return y + end + +end + +# ╔═╡ 6444a4ab-5992-4a39-a2e9-8301f9a6ba91 +begin + ############################################################################ + # 1. TYPE # + ############################################################################ + """ + AbstractHMM + + Abstract supertype for an HMM amenable to simulation, inference and learning. + + # Interface + + To create your own subtype of `AbstractHMM`, you need to implement the following methods: + + - [`initialization`](@ref) + - [`transition_matrix`](@ref) + - [`obs_distributions`](@ref) + - [`fit!`](@ref) (for learning) + + # Applicable functions + + Any `AbstractHMM` which satisfies the interface can be given to the following functions: + + - [`rand`](@ref) + - [`logdensityof`](@ref) + - [`forward`](@ref) + - [`viterbi`](@ref) + - [`forward_backward`](@ref) + - [`baum_welch`](@ref) (if `[fit!](@ref)` is implemented) + """ + abstract type AbstractHMM{T} end + + @inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity() + + + ############################################################################ + # 2. INTERFACE # + ############################################################################ + + + #------------------------------ 2.1. length -------------------------------# + """ + length(hmm) + + Return the number of states of `hmm`. + """ + Base.length(hmm::AbstractHMM) = length(initialization(hmm)) + + Base.length(hmm::AbstractHMM, control) = length(initialization(hmm, control)) + + Base.length(hmm::AbstractHMM, ::Nothing) = length(initialization(hmm)) + + + #------------------------------ 2.2. eltype -------------------------------# + """ + eltype(hmm, obs, control) + + Return a type that can accommodate forward-backward computations for `hmm` on observations similar to `obs`. + + It is typically a promotion between the element type of the initialization, the element type of the transition matrix, and the type of an observation logdensity evaluated at `obs`. + """ + function Base.eltype(hmm::AbstractHMM, obs, control) + init_type = eltype(initialization(hmm, control)) + trans_type = eltype(transition_matrix(hmm, control)) + dist = obs_distributions(hmm, control, obs)[1] + logdensity_type = typeof(logdensityof(dist, obs)) + return promote_type(init_type, trans_type, logdensity_type) + end + + + #--------------------------- 2.3. initialization --------------------------# + """ + initialization(hmm) + + Return the vector of initial state probabilities for `hmm`. + """ + function initialization end + + initialization(hmm::AbstractHMM) = hmm.init + + initialization(hmm::AbstractHMM, control) = hmm.init + + initialization(hmm::AbstractHMM, ::Nothing) = hmm.init + + + #------------------------ 2.4. log_initialization -------------------------# + """ + log_initialization(hmm) + + Return the vector of initial state log-probabilities for `hmm`. + + Falls back on `initialization`. + """ + log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) + + log_initialization(hmm::AbstractHMM, control) = elementwise_log(initialization(hmm, control)) + + log_initialization(hmm::AbstractHMM, ::Nothing) = elementwise_log(initialization(hmm)) + + + #------------------------- 2.5. transition_matrix -------------------------# + """ + transition_matrix(hmm) + transition_matrix(hmm, control) + + Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied). + + !!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). + """ + function transition_matrix end + + transition_matrix(hmm::AbstractHMM) = hmm.trans + + transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) + + transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) + + + #----------------------- 2.6. log_transition_matrix -----------------------# + """ + log_transition_matrix(hmm) + log_transition_matrix(hmm, control) + + Return the matrix of state transition log-probabilities for `hmm` (possibly when `control` is applied). + + Falls back on `transition_matrix`. + + !!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). + """ + log_transition_matrix(hmm::AbstractHMM) = elementwise_log(transition_matrix(hmm)) + + log_transition_matrix(hmm::AbstractHMM, control) = elementwise_log(transition_matrix(hmm, control)) + + log_transition_matrix(hmm::AbstractHMM, ::Nothing) = elementwise_log(transition_matrix(hmm)) + + + + #------------------------- 2.7. obs_distributions -------------------------# + """ + obs_distributions(hmm) + obs_distributions(hmm, control) + + Return a vector of observation distributions, one for each state of `hmm` (possibly when `control` is applied). + + These distribution objects should implement + + - `Random.rand(rng, dist)` for sampling + - `DensityInterface.logdensityof(dist, obs)` for inference + - `StatsAPI.fit!(dist, obs_seq, weight_seq)` for learning + """ + function obs_distributions end + + obs_distributions(hmm::AbstractHMM) = [hmm.dists[i] for i ∈ 1:length(hmm)] + + obs_distributions(hmm::AbstractHMM, control, prev_obs) = obs_distributions(hmm, control, prev_obs) + + ### Fallback when it is not autoregressive and there is no control + obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) + + ### Fallback when it is autoregressive, but previous observation is Missing + obs_distributions(hmm::AbstractHMM, control, ::Missing) = obs_distributions(hmm, control) + + ### Fallback when it is autoregressive and there is no control, but observation is missing + obs_distributions(hmm::AbstractHMM, ::Nothing, ::Missing) = obs_distributions(hmm) + + + #---------------------------- 2.8. previous_obs ---------------------------# + previous_obs(::AbstractHMM{false}, obs_seq::AbstractVector, t::Integer) = nothing + + previous_obs(::AbstractHMM{true}, obs_seq::AbstractVector, t::Integer) = obs_seq[t - 1] + + + #--------------------------- 2.9. StatsAPI.fit! ---------------------------# + """ + StatsAPI.fit!( + hmm, fb_storage::ForwardBackwardStorage, + obs_seq, [control_seq]; seq_ends, + ) + + Update `hmm` in-place based on information generated during forward-backward. + + This function is allowed to reuse `fb_storage` as a scratch space, so its contents should not be trusted afterwards. + """ + StatsAPI.fit! + + + #------------------------- 2.10. obs_logdensities! ------------------------# + function obs_logdensities!( + logb::AbstractVector{T}, hmm::AbstractHMM, obs, control, prev_obs + ) where {T} + dists = obs_distributions(hmm, control, prev_obs) + @simd for i in eachindex(logb, dists) + logb[i] = logdensityof(dists[i], obs) + end + @argcheck maximum(logb) < typemax(T) + return nothing + end + + + ############################################################################ + # 3. SAMPLING # + ############################################################################ + + # <------------- Didn't touch it yet! + """ + rand([rng,] hmm, T) + rand([rng,] hmm, control_seq) + + Simulate `hmm` for `T` time steps, or when the sequence `control_seq` is applied. + + Return a named tuple `(; state_seq, obs_seq)`. + """ + function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector) + T = length(control_seq) + dummy_log_probas = fill(-Inf, length(hmm)) + + init = initialization(hmm, control) + state_seq = Vector{Int}(undef, T) + state1 = rand(rng, LightCategorical(init, dummy_log_probas)) + state_seq[1] = state1 + + @views for t in 1:(T - 1) + trans = transition_matrix(hmm, control_seq[t]) + state_seq[t + 1] = rand( + rng, LightCategorical(trans[state_seq[t], :], dummy_log_probas) + ) + end + + dists1 = obs_distributions(hmm, control_seq[1], missing) + obs1 = rand(rng, dists1[state1]) + obs_seq = Vector{typeof(obs1)}(undef, T) + obs_seq[1] = obs1 + + for t in 2:T + dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) + obs_seq[t] = rand(rng, dists[state_seq[t]]) + end + return (; state_seq=state_seq, obs_seq=obs_seq) + end + + function Random.rand(hmm::AbstractHMM, control_seq::AbstractVector) + return rand(default_rng(), hmm, control_seq) + end + + function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, T::Integer) + return rand(rng, hmm, Fill(nothing, T)) + end + + function Random.rand(hmm::AbstractHMM, T::Integer) + return rand(hmm, Fill(nothing, T)) + end + + + ############################################################################ + # 4. PRIOR # + ############################################################################ + """ + logdensityof(hmm) + + Return the prior loglikelihood associated with the parameters of `hmm`. + """ + DensityInterface.logdensityof(hmm::AbstractHMM) = false + + + ############################################################################ + # 5. ARHMM EXAMPLE # + ############################################################################ + ## Test scruct for Discrete ARHMM with control + struct DiscreteCARHMM{T<:Number} <: AbstractHMM{true} + # Initial distribution P(X_{1}|U_{1}), one vector for each control + init::Vector{Vector{T}} + # Transition matrix P(X_{t}|X_{t-1}, U_{t}), one matrix for each control + trans::Vector{Matrix{T}} + # Emission matriz P(Y_{t}|X_{t}, U_{t}), one matriz for each control and each possible observation + dists::Vector{Vector{Matrix{T}}} + # Prior Distribution for P(Y_{1}|X_{1}, U_{1}), one matriz for each control + prior::Vector{Matrix{T}} + end + + initialization(hmm::DiscreteCARHMM, control) = hmm.init[control] + + transition_matrix(hmm::DiscreteCARHMM, control) = hmm.trans[control] + + obs_distributions(hmm::DiscreteCARHMM, control, prev_obs) = [Categorical(hmm.dists[control][prev_obs][i,:]) for i in 1:length(hmm, control)] + + obs_distributions(hmm::DiscreteCARHMM, control, ::Missing) = [Categorical(hmm.prior[control][i,:]) for i in 1:length(hmm, control)] + + + ## Test scruct for Discrete ARHMM without control + struct DiscreteARHMM{T<:Number} <: AbstractHMM{true} + # Initial distribution P(X_{1}) + init::Vector{T} + # Transition matrix P(X_{t}|X_{t-1}) + trans::Matrix{T} + # Emission matriz P(Y_{t}|X_{t}) + dists::Vector{Matrix{T}} + # Prior Distribution for P(Y_{1}|X_{1}) + prior::Matrix{T} + end + + obs_distributions(hmm::DiscreteARHMM, ::Nothing, prev_obs) = [Categorical(hmm.dists[prev_obs][i,:]) for i in 1:length(hmm)] + + obs_distributions(hmm::DiscreteARHMM, ::Nothing, ::Missing) = [Categorical(hmm.prior[i,:]) for i in 1:length(hmm)] +end + +# ╔═╡ c90ce230-7640-4337-b6ae-e9369564e8e7 +begin + function predict_next_state!( + next_state_marginals::AbstractVector{<:Real}, + hmm::AbstractHMM, + current_state_marginals::AbstractVector{<:Real}, + control=nothing, + ) + trans = transition_matrix(hmm, control) + mul!(next_state_marginals, transpose(trans), current_state_marginals) + return next_state_marginals + end + + function predict_previous_state!( + previous_state_marginals::AbstractVector{<:Real}, + hmm::AbstractHMM, + current_state_marginals::AbstractVector{<:Real}, + control=nothing, + ) + trans = transition_matrix(hmm, control) + mul!(previous_state_marginals, trans, current_state_marginals) + return previous_state_marginals + end +end + +# ╔═╡ 5e9a1a49-0356-48ec-ad85-ed5c49653d2d +hmm₀ = DiscreteARHMM(uᵢ, tᵢⱼ, eʳᵢⱼ, pᵢⱼ) + +# ╔═╡ dc54a9a0-4fd6-4f64-9067-775d7674ebc1 +""" +$(SIGNATURES) + +Return a tuple `(t1, t2)` giving the begin and end indices of subsequence `k` within a set of sequences ending at `seq_ends`. +""" +function seq_limits(seq_ends::AbstractVectorOrNTuple{Int}, k::Integer) + if k == 1 + return 1, seq_ends[k] + else + return seq_ends[k - 1] + 1, seq_ends[k] + end +end + +# ╔═╡ f4fe5934-b4a3-412c-89e2-688bdffd7dd2 +begin + struct ForwardStorage{R} + "posterior last state marginals `α[i] = ℙ(X[T]=i | Y[1:T])`" + α::Matrix{R} + "one loglikelihood per observation sequence" + logL::Vector{R} + B::Matrix{R} + c::Vector{R} + end + + struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} + "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" + γ::Matrix{R} + "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`" + ξ::Vector{M} + "one loglikelihood per observation sequence" + logL::Vector{R} + B::Matrix{R} + α::Matrix{R} + c::Vector{R} + β::Matrix{R} + Bβ::Matrix{R} + end + + Base.eltype(::ForwardBackwardStorage{R}) where {R} = R + + const ForwardOrForwardBackwardStorage{R} = Union{ + ForwardStorage{R},ForwardBackwardStorage{R} + } + + """ + $(SIGNATURES) + """ + function initialize_forward( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + ) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq[1], control_seq[1]) + α = Matrix{R}(undef, N, T) + logL = Vector{R}(undef, K) + B = Matrix{R}(undef, N, T) + c = Vector{R}(undef, T) + return ForwardStorage(α, logL, B, c) + end + + function _forward_digest_observation!( + current_state_marginals::AbstractVector{<:Real}, + current_obs_likelihoods::AbstractVector{<:Real}, + hmm::AbstractHMM, + obs, + control, + prev_obs, + ) + a, b = current_state_marginals, current_obs_likelihoods + + obs_logdensities!(b, hmm, obs, control, prev_obs) + logm = maximum(b) + b .= exp.(b .- logm) + + a .*= b + c = inv(sum(a)) + lmul!(c, a) + + logL = -log(c) + logm + return c, logL + end + + function _forward!( + storage::ForwardOrForwardBackwardStorage, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector, + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, + ) + (; α, B, c, logL) = storage + t1, t2 = seq_limits(seq_ends, k) + logL[k] = zero(eltype(logL)) + for t in t1:t2 + αₜ = view(α, :, t) + Bₜ = view(B, :, t) + if t == t1 + copyto!(αₜ, initialization(hmm, control_seq[t])) + else + αₜ₋₁ = view(α, :, t - 1) + predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) + end + prev_obs = t == t1 ? missing : previous_obs(hmm, obs_seq, t) + cₜ, logLₜ = _forward_digest_observation!( + αₜ, Bₜ, hmm, obs_seq[t], control_seq[t], prev_obs + ) + c[t] = cₜ + logL[k] += logLₜ + end + + @argcheck isfinite(logL[k]) + return nothing + end + + """ + $(SIGNATURES) + """ + function forward!( + storage::ForwardOrForwardBackwardStorage, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + ) + if seq_ends isa NTuple{1} + for k in eachindex(seq_ends) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) + end + else + @threads for k in eachindex(seq_ends) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) + end + end + return nothing + end + + """ + $(SIGNATURES) + + Apply the forward algorithm to infer the current state after sequence `obs_seq` for `hmm`. + + Return a tuple `(storage.α, storage.logL)` where `storage` is of type [`ForwardStorage`](@ref). + """ + function forward( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), + ) + storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends) + forward!(storage, hmm, obs_seq, control_seq; seq_ends) + return storage.α, storage.logL + end + + nothing +end + +# ╔═╡ 4bcf1079-13c6-418e-a37e-ce71e5cc74f1 +begin + """ + $(SIGNATURES) + + Run the forward algorithm to compute the loglikelihood of `obs_seq` for `hmm`, integrating over all possible state sequences. + """ + function DensityInterface.logdensityof( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), + ) + _, logL = forward(hmm, obs_seq, control_seq; seq_ends) + return sum(logL) + end + + """ + $(SIGNATURES) + + Run the forward algorithm to compute the the joint loglikelihood of `obs_seq` and `state_seq` for `hmm`. + """ + function joint_logdensityof( + hmm::AbstractHMM, + obs_seq::AbstractVector, + state_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), + ) + R = eltype(hmm, obs_seq[1], control_seq[1]) + logL = zero(R) + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + # Initialization + init = initialization(hmm, control_seq[t1]) + logL += log(init[state_seq[t1]]) + # Transitions + for t in t1:(t2 - 1) + trans = transition_matrix(hmm, control_seq[t]) + logL += log(trans[state_seq[t], state_seq[t + 1]]) + end + # Observations + for t in t1:t2 + dists = obs_distributions(hmm, control_seq[t], previous_obs(hmm, obs_seq, t)) + logL += logdensityof(dists[state_seq[t]], obs_seq[t]) + end + end + return logL + end + + nothing +end + +# ╔═╡ 0a903f68-79dc-4a45-b0fe-e5e17eff3b6f +begin + """ + $(SIGNATURES) + """ + function initialize_forward_backward( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + transition_marginals=true, + ) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq[1], control_seq[1]) + trans = transition_matrix(hmm, control_seq[1]) + M = typeof(similar(trans, R)) + + γ = Matrix{R}(undef, N, T) + ξ = Vector{M}(undef, T) + if transition_marginals + for t in 1:T + ξ[t] = similar(transition_matrix(hmm, control_seq[t]), R) + end + end + logL = Vector{R}(undef, K) + B = Matrix{R}(undef, N, T) + α = Matrix{R}(undef, N, T) + c = Vector{R}(undef, T) + β = Matrix{R}(undef, N, T) + Bβ = Matrix{R}(undef, N, T) + return ForwardBackwardStorage{R,M}(γ, ξ, logL, B, α, c, β, Bβ) + end + + function _forward_backward!( + storage::ForwardBackwardStorage{R}, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector, + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer; + transition_marginals::Bool=true, + ) where {R} + (; α, β, c, γ, ξ, B, Bβ) = storage + t1, t2 = seq_limits(seq_ends, k) + + # Forward (fill B, α, c and logL) + _forward!(storage, hmm, obs_seq, control_seq, seq_ends, k) + + # Backward + β[:, t2] .= c[t2] + for t in (t2 - 1):-1:t1 + Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1) + βₜ = view(β, :, t) + Bβₜ₊₁ = view(Bβ, :, t + 1) + predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t]) + lmul!(c[t], βₜ) + end + Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1) + + # State marginals + γ[:, t1:t2] .= view(α, :, t1:t2) .* view(β, :, t1:t2) ./ view(c, t1:t2)' + + # Transition marginals + if transition_marginals + for t in t1:(t2 - 1) + trans = transition_matrix(hmm, control_seq[t]) + mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1)) + end + ξ[t2] .= zero(R) + end + + return nothing + end + + """ + $(SIGNATURES) + """ + function forward_backward!( + storage::ForwardBackwardStorage, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + transition_marginals::Bool=true, + ) + if seq_ends isa NTuple{1} + for k in eachindex(seq_ends) + _forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals + ) + end + else + @threads for k in eachindex(seq_ends) + _forward_backward!( + storage, hmm, obs_seq, control_seq, seq_ends, k; transition_marginals + ) + end + end + return nothing + end + + """ + $(SIGNATURES) + + Apply the forward-backward algorithm to infer the posterior state and transition marginals during sequence `obs_seq` for `hmm`. + + Return a tuple `(storage.γ, storage.logL)` where `storage` is of type [`ForwardBackwardStorage`](@ref). + """ + function forward_backward( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), + ) + transition_marginals = false + storage = initialize_forward_backward( + hmm, obs_seq, control_seq; seq_ends, transition_marginals + ) + forward_backward!(storage, hmm, obs_seq, control_seq; seq_ends, transition_marginals) + return storage.γ, storage.logL + end + + nothing +end + +# ╔═╡ 7b2b0560-726d-4bc0-acfa-16361da4f0b2 +begin + """ + $(TYPEDEF) + + # Fields + + Only the fields with a description are part of the public API. + + $(TYPEDFIELDS) + """ + struct ViterbiStorage{R} + "most likely state sequence `q[t] = argmaxᵢ ℙ(X[t]=i | Y[1:T])`" + q::Vector{Int} + "one joint loglikelihood per pair of observation sequence and most likely state sequence" + logL::Vector{R} + logB::Matrix{R} + ϕ::Matrix{R} + ψ::Matrix{Int} + end + + """ + $(SIGNATURES) + """ + function initialize_viterbi( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + ) + N, T, K = length(hmm, control_seq[1]), length(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq[1], control_seq[1]) + q = Vector{Int}(undef, T) + logL = Vector{R}(undef, K) + logB = Matrix{R}(undef, N, T) + ϕ = Matrix{R}(undef, N, T) + ψ = Matrix{Int}(undef, N, T) + return ViterbiStorage(q, logL, logB, ϕ, ψ) + end + + function _viterbi!( + storage::ViterbiStorage{R}, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector, + seq_ends::AbstractVectorOrNTuple{Int}, + k::Integer, + ) where {R} + (; q, logB, ϕ, ψ, logL) = storage + t1, t2 = seq_limits(seq_ends, k) + + logBₜ₁ = view(logB, :, t1) + obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1], missing) + loginit = log_initialization(hmm, control_seq[t1]) + ϕ[:, t1] .= loginit .+ logBₜ₁ + + for t in (t1 + 1):t2 + logBₜ = view(logB, :, t) + obs_logdensities!( + logBₜ, hmm, obs_seq[t], control_seq[t], previous_obs(hmm, obs_seq, t) + ) + logtrans = log_transition_matrix(hmm, control_seq[t - 1]) + ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) + ψₜ = view(ψ, :, t) + argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) + ϕₜ .+= logBₜ + end + + ϕₜ₂ = view(ϕ, :, t2) + q[t2] = argmax(ϕₜ₂) + logL[k] = ϕ[q[t2], t2] + for t in (t2 - 1):-1:t1 + q[t] = ψ[q[t + 1], t + 1] + end + + @argcheck isfinite(logL[k]) + return nothing + end + + """ + $(SIGNATURES) + """ + function viterbi!( + storage::ViterbiStorage{R}, + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector; + seq_ends::AbstractVectorOrNTuple{Int}, + ) where {R} + if seq_ends isa NTuple{1} + for k in eachindex(seq_ends) + _viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) + end + else + @threads for k in eachindex(seq_ends) + _viterbi!(storage, hmm, obs_seq, control_seq, seq_ends, k) + end + end + return nothing + end + + """ + $(SIGNATURES) + + Apply the Viterbi algorithm to infer the most likely state sequence corresponding to `obs_seq` for `hmm`. + + Return a tuple `(storage.q, storage.logL)` where `storage` is of type [`ViterbiStorage`](@ref). + """ + function viterbi( + hmm::AbstractHMM, + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), + ) + storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) + viterbi!(storage, hmm, obs_seq, control_seq; seq_ends) + return storage.q, storage.logL + end + + nothing +end + +# ╔═╡ 519901db-5fb9-4323-822a-ff04e28285c9 +forward(hmm₀, true_obs_seq) + +# ╔═╡ 24b2383c-cf58-46a8-84f7-d00ddc1bbbe4 +forward_backward(hmm₀, true_obs_seq) + +# ╔═╡ be40ed94-e6e2-4838-98db-83d1de9eb830 +viterbi(hmm₀, true_obs_seq) + +# ╔═╡ 19b64180-184c-435f-9be6-40dd9f482f78 +pᵏᵢⱼ = [[0.4 0.4 0.2; 0.3 0.3 0.4], [0.25 0.5 0.25; 0.7 0.15 0.15]] + +# ╔═╡ 1bb32d8a-1961-4709-93ff-aa6f988f0ce5 +hmm = DiscreteCARHMM(uᵏᵢ, tᵏᵢⱼ, eᵏʳᵢⱼ, pᵏᵢⱼ) + +# ╔═╡ 3968aa53-4022-4299-b3ea-5d4086872935 +forward(hmm, true_obs_seq, true_control_seq) + +# ╔═╡ 48a686a4-a996-4c1b-8c48-f72d2b8aff4f +forward_backward(hmm, true_obs_seq, true_control_seq) + +# ╔═╡ 7d6012de-53b6-491d-95e4-eff55c2b80b5 +viterbi(hmm, true_obs_seq, true_control_seq) + +# ╔═╡ 3b034316-b7c8-48cc-911a-fdcea177cf2e +md""" +## 9. Comparison +""" + +# ╔═╡ 05e3d2f4-f793-44ce-a9f8-23e02b6dc303 +#https://cran.r-project.org/web/packages/LMest/vignettes/vignetteLMest.html + +# ╔═╡ f3bc9ea7-9e41-47f4-9d3d-fbc7d35bf5a2 +#https://github.com/TheoMichelot/hmmTMB + +# ╔═╡ 78c2591b-063a-4917-936d-28f17686eb7b +#https://github.com/hmmlearn/hmmlearn + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" + +[compat] +ArgCheck = "~2.3.0" +ChainRulesCore = "~1.25.0" +DensityInterface = "~0.4.0" +Distributions = "~0.25.113" +DocStringExtensions = "~0.9.3" +FillArrays = "~1.13.0" +PlutoUI = "~0.7.60" +StatsAPI = "~1.7.0" +StatsFuns = "~1.3.2" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.11.1" +manifest_format = "2.0" +project_hash = "cb45b4653fe301fab6134ac7f13a025a6c4ad1f3" + +[[deps.AbstractPlutoDingetjes]] +deps = ["Pkg"] +git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a" +uuid = "6e696c72-6542-2067-7265-42206c756150" +version = "1.3.2" + +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.2" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.25.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" + +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + +[[deps.Distributions]] +deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "3101c32aab536e7a27b1763c0797dba151b899ad" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.113" +weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.13.0" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.HypergeometricFunctions]] +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "b1c2585431c382e3fe5805874bda6aea90a95de9" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.25" + +[[deps.Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "179267cfa5e712760cd43dcae385d7ea90cc25a4" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.5" + +[[deps.HypertextLiteral]] +deps = ["Tricks"] +git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.5" + +[[deps.IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.5" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" + +[[deps.InverseFunctions]] +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.17" +weakdeps = ["Dates", "Test"] + + [deps.InverseFunctions.extensions] + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.6.1" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.6.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.7.2+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" + +[[deps.MIMEs]] +git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" +uuid = "6c6e2e6c-3030-632d-7369-2d6c69616d65" +version = "0.1.4" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.6+0" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.12.12" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.27+1" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.31" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.11.0" + + [deps.Pkg.extensions] + REPLExt = "REPL" + + [deps.Pkg.weakdeps] + REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.PlutoUI]] +deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] +git-tree-sha1 = "eba4810d5e6a01f612b948c9fa94f905b49087b0" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.60" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" + +[[deps.PtrArrays]] +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.1" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.11.1" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.8.0" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.5.1+0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.11.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.Statistics]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StatsFuns]] +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.3.2" +weakdeps = ["ChainRulesCore", "InverseFunctions"] + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.7.0+0" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" + +[[deps.Tricks]] +git-tree-sha1 = "7822b97e99a1672bfb1b49b668a6d46d58d8cbcb" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.9" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.11.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.59.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" +""" + +# ╔═╡ Cell order: +# ╠═455866a4-a5ef-11ef-04ff-f347c85da0d8 +# ╠═b8b4c6ec-a8f3-4609-a0c0-15652c0d15be +# ╟─a5c89117-9642-4c28-bf5e-7c93184934b4 +# ╠═6444a4ab-5992-4a39-a2e9-8301f9a6ba91 +# ╟─5d513c86-0595-4fee-8312-9927ea731fde +# ╠═4d76c441-429f-4500-9c26-6247abad17ed +# ╟─b6abd844-345b-4c36-b6c0-699fa5c45fd9 +# ╠═dc54a9a0-4fd6-4f64-9067-775d7674ebc1 +# ╟─ebfa27fc-bdb0-4aaa-8930-85025df33b7c +# ╠═f4fe5934-b4a3-412c-89e2-688bdffd7dd2 +# ╠═c90ce230-7640-4337-b6ae-e9369564e8e7 +# ╟─f2599655-855a-4c41-a6ed-c22c16d792a1 +# ╠═4bcf1079-13c6-418e-a37e-ce71e5cc74f1 +# ╟─e35f3103-c5ec-4a6f-bfc0-76326f8eed54 +# ╠═0a903f68-79dc-4a45-b0fe-e5e17eff3b6f +# ╟─71193717-7f4e-4db6-9669-149e7d6405af +# ╠═7b2b0560-726d-4bc0-acfa-16361da4f0b2 +# ╟─ffb943ac-b403-4f71-a19b-0875748674ed +# ╠═3a00c710-e6e6-4b5e-875e-7bbcb173d8a3 +# ╠═e7acfa37-371b-437c-aa78-47e37500d414 +# ╠═8d99e582-fed1-42bd-8196-5f73cffe5a05 +# ╠═c3ed40cc-0921-411e-aa8b-547d35fa1842 +# ╠═c24abed8-65bc-4bb8-9444-64377c90be21 +# ╠═4b6ecf0d-6330-4fc5-b31a-f49bbc03d397 +# ╠═1c825864-2332-4e99-8d3a-4583f4fb56e2 +# ╠═5e9a1a49-0356-48ec-ad85-ed5c49653d2d +# ╠═519901db-5fb9-4323-822a-ff04e28285c9 +# ╠═24b2383c-cf58-46a8-84f7-d00ddc1bbbe4 +# ╠═be40ed94-e6e2-4838-98db-83d1de9eb830 +# ╠═3b669bd8-21e1-42d5-b2d6-45340c57e1e2 +# ╠═dfa9f848-b928-4299-8347-2ac492958df7 +# ╠═b454b317-29a9-41ec-82a7-9c48a88f2576 +# ╠═b1377b4b-ce15-4eb8-83cb-740a58673ac6 +# ╠═19b64180-184c-435f-9be6-40dd9f482f78 +# ╠═1bb32d8a-1961-4709-93ff-aa6f988f0ce5 +# ╠═3968aa53-4022-4299-b3ea-5d4086872935 +# ╠═48a686a4-a996-4c1b-8c48-f72d2b8aff4f +# ╠═7d6012de-53b6-491d-95e4-eff55c2b80b5 +# ╟─3b034316-b7c8-48cc-911a-fdcea177cf2e +# ╠═05e3d2f4-f793-44ce-a9f8-23e02b6dc303 +# ╠═f3bc9ea7-9e41-47f4-9d3d-fbc7d35bf5a2 +# ╠═78c2591b-063a-4917-936d-28f17686eb7b +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002