diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index ca459dfe..3801b946 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -155,7 +155,7 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.HMMBenchmark]] -deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"] +deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "FillArrays", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"] path = "../libs/HMMBenchmark" uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" version = "0.1.0" diff --git a/examples/basics.jl b/examples/basics.jl index 23734d14..8e0a390d 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -219,6 +219,6 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:500]; #src control_seq = reduce(vcat, control_seqs); #src seq_ends = cumsum(length.(control_seqs)); #src -test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +test_identical_hmmbase(rng, hmm; hmm_guess, T=100) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/controlled.jl b/examples/controlled.jl index cdffafdd..dfb70660 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -148,5 +148,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2]) # ## Tests #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.08, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src +test_type_stability(rng, hmm, control_seq; 
seq_ends, hmm_guess) #src diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 3df60d58..f9d781dc 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -242,6 +242,6 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src control_seq = reduce(vcat, control_seqs); #src seq_ends = cumsum(length.(control_seqs)); #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src -test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src +test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/temporal.jl b/examples/temporal.jl index 29252730..132abc66 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -181,5 +181,5 @@ hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)) # ## Tests #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.1, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.1, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/types.jl b/examples/types.jl index 6634a150..42ebb4dc 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -94,8 +94,8 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src control_seq = reduce(vcat, control_seqs); #src seq_ends = cumsum(length.(control_seqs)); #src -test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src -test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src -test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src +test_identical_hmmbase(rng, hmm; hmm_guess, 
T=100) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false) #src +test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src -@test_skip test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src +@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/ext/HiddenMarkovModelsDistributionsExt.jl b/ext/HiddenMarkovModelsDistributionsExt.jl index d0ebe78e..2debfb15 100644 --- a/ext/HiddenMarkovModelsDistributionsExt.jl +++ b/ext/HiddenMarkovModelsDistributionsExt.jl @@ -1,6 +1,6 @@ module HiddenMarkovModelsDistributionsExt -using HiddenMarkovModels: HiddenMarkovModels +using HiddenMarkovModels: HiddenMarkovModels, dcat using Distributions: Distributions, Distribution, @@ -15,6 +15,12 @@ function HiddenMarkovModels.fit_in_sequence!( return dists[i] = fit(D, x_nums, w) end +function HiddenMarkovModels.fit_in_sequence!( + dists::AbstractVector{D}, i::Integer, x_mat::AbstractMatrix, w::AbstractVector +) where {D<:MultivariateDistribution} + return dists[i] = fit(D, x_mat, w) +end + function HiddenMarkovModels.fit_in_sequence!( dists::AbstractVector{D}, i::Integer, @@ -24,6 +30,12 @@ function HiddenMarkovModels.fit_in_sequence!( return dists[i] = fit(D, reduce(hcat, x_vecs), w) end +function HiddenMarkovModels.fit_in_sequence!( + dists::AbstractVector{D}, i::Integer, x_tens::AbstractArray{<:Any,3}, w::AbstractVector +) where {D<:MatrixDistribution} + return dists[i] = fit(D, x_tens, w) +end + function HiddenMarkovModels.fit_in_sequence!( dists::AbstractVector{D}, i::Integer, @@ -33,6 +45,4 @@ function HiddenMarkovModels.fit_in_sequence!( return dists[i] = fit(D, reduce(dcat, x_mats), w) end -dcat(M1, M2) = cat(M1, M2; dims=3) - end diff --git a/libs/HMMBenchmark/Project.toml b/libs/HMMBenchmark/Project.toml index c9a9a128..6d9e7b6e 100644 --- a/libs/HMMBenchmark/Project.toml +++ b/libs/HMMBenchmark/Project.toml @@ 
-8,6 +8,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/libs/HMMBenchmark/src/HMMBenchmark.jl b/libs/HMMBenchmark/src/HMMBenchmark.jl index b2b543fa..b1193aec 100644 --- a/libs/HMMBenchmark/src/HMMBenchmark.jl +++ b/libs/HMMBenchmark/src/HMMBenchmark.jl @@ -5,6 +5,7 @@ using BenchmarkTools: @benchmarkable, BenchmarkGroup using CSV: CSV using DataFrames: DataFrame using Distributions: Normal, MvNormal +using FillArrays: Fill using HiddenMarkovModels using HiddenMarkovModels: LightDiagNormal, @@ -16,7 +17,8 @@ using HiddenMarkovModels: forward!, initialize_forward_backward, forward_backward!, - baum_welch! 
+ baum_welch!, + duration using LinearAlgebra: BLAS, Diagonal, SymTridiagonal using Pkg: Pkg using Random: AbstractRNG diff --git a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl index 346d384d..5be91bb2 100644 --- a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl +++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl @@ -24,19 +24,27 @@ function build_benchmarkables( instance::Instance, algos::Vector{String}, ) - (; obs_dim, seq_length, nb_seqs, bw_iter) = instance + (; custom_dist, obs_dim, seq_length, nb_seqs, bw_iter) = instance hmm = build_model(rng, implem; instance) data = randn(rng, nb_seqs, seq_length, obs_dim) - if obs_dim == 1 - obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs] - else + if custom_dist obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs] + else + if obs_dim == 1 + obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs] + else + obs_seqs = [collect(data[k, :, :]') for k in 1:nb_seqs] + end + end + if first(obs_seqs) isa AbstractVector + obs_seq = reduce(vcat, obs_seqs) + else + obs_seq = reduce(hcat, obs_seqs) end - obs_seq = reduce(vcat, obs_seqs) - control_seq = fill(nothing, length(obs_seq)) - seq_ends = cumsum(length.(obs_seqs)) + control_seq = Fill(nothing, duration(obs_seq)) + seq_ends = cumsum(duration.(obs_seqs)) benchs = Dict() diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index 0dac7ce7..f427526c 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -1,10 +1,51 @@ +function test_allocations_aux( + hmm::AbstractHMM, + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; + seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, +) + t1, t2 = 1, seq_ends[1] + + ## Forward + f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends) + allocs_f = @ballocated HMMs.forward!($f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2) evals = + 1 samples 
= 1 + @test allocs_f == 0 + + ## Viterbi + v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) + allocs_v = @ballocated HMMs.viterbi!($v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2) evals = + 1 samples = 1 + @test allocs_v == 0 + + ## Forward-backward + fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) + allocs_fb = @ballocated HMMs.forward_backward!( + $fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 + ) evals = 1 samples = 1 + @test allocs_fb == 0 + + ## Baum-Welch + if !isnothing(hmm_guess) + fb_storage = HMMs.initialize_forward_backward( + hmm_guess, obs_seq, control_seq; seq_ends + ) + HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) + allocs_bw = @ballocated fit!( + $hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends + ) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm)) + @test_broken allocs_bw == 0 + end +end + function test_allocations( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Allocations" begin obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k @@ -12,39 +53,14 @@ function test_allocations( rand(rng, hmm, control_seq[t1:t2]).obs_seq end - t1, t2 = 1, seq_ends[1] - - ## Forward - f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends) - allocs_f = @ballocated HMMs.forward!( - $f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 - ) evals = 1 samples = 1 - @test allocs_f == 0 - - ## Viterbi - v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) - allocs_v = @ballocated HMMs.viterbi!( - $v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 - ) evals = 1 samples = 1 - @test allocs_v == 0 - - ## Forward-backward - fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) - allocs_fb = @ballocated 
HMMs.forward_backward!( - $fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2 - ) evals = 1 samples = 1 - @test allocs_fb == 0 - - ## Baum-Welch - if !isnothing(hmm_guess) - fb_storage = HMMs.initialize_forward_backward( - hmm_guess, obs_seq, control_seq; seq_ends - ) - HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) - allocs_bw = @ballocated fit!( - $hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends - ) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm)) - @test_broken allocs_bw == 0 + @testset "Sequence" begin + test_allocations_aux(hmm, obs_seq, control_seq; seq_ends, hmm_guess) + end + if first(obs_seq) isa AbstractVector + obs_mat = reduce(hcat, obs_seq) + @testset "Matrix" begin + test_allocations_aux(hmm, obs_mat, control_seq; seq_ends, hmm_guess) + end end end end diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index f0646e24..1b614ed7 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -2,8 +2,8 @@ infnorm(x) = maximum(abs, x) function check_equal_hmms( hmm1::AbstractHMM, - hmm2::AbstractHMM; - control_seq=[nothing], + hmm2::AbstractHMM, + control_seq=[nothing]; atol::Real=0.1, init::Bool=true, test::Bool=true, @@ -45,21 +45,53 @@ end function test_equal_hmms( hmm1::AbstractHMM, - hmm2::AbstractHMM; - control_seq=[nothing], + hmm2::AbstractHMM, + control_seq=[nothing]; atol::Real=0.1, init::Bool=true, ) - check_equal_hmms(hmm1, hmm2; control_seq, atol, init, test=true) + check_equal_hmms(hmm1, hmm2, control_seq; atol, init, test=true) return nothing end +function test_coherent_algorithms_aux( + hmm::AbstractHMM, + obs_seq::AbstractVecOrMat, + state_seq::AbstractVector{<:Integer}, + control_seq::AbstractVecOrMat; + seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}, + atol::Real, + init::Bool, +) + logL = logdensityof(hmm, obs_seq, control_seq; seq_ends) + logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends) 
+ + q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends) + @test logL_viterbi > logL_joint + @test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends) + + α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends) + @test logL_forward ≈ logL + + γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends) + @test logL_forward_backward ≈ logL + @test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends)) + + if !isnothing(hmm_guess) + hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) + @test all(>=(0), diff(logL_evolution)) + @test !check_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, test=false) + test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init) + end +end + function test_coherent_algorithms( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, atol::Real=0.1, init::Bool=true, ) @@ -75,27 +107,18 @@ function test_coherent_algorithms( state_seq = reduce(vcat, state_seqs) obs_seq = reduce(vcat, obs_seqs) - logL = logdensityof(hmm, obs_seq, control_seq; seq_ends) - logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends) - - q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends) - @test logL_viterbi > logL_joint - @test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends) - - α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends) - @test logL_forward ≈ logL - - γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends) - @test logL_forward_backward ≈ logL - @test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends)) - - if !isnothing(hmm_guess) - hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) - @test all(>=(0), diff(logL_evolution)) - @test 
!check_equal_hmms( - hmm, hmm_guess; control_seq=control_seq[1:2], atol, test=false + @testset "Sequence" begin + test_coherent_algorithms_aux( + hmm, obs_seq, state_seq, control_seq; seq_ends, hmm_guess, atol, init ) - test_equal_hmms(hmm, hmm_est; control_seq=control_seq[1:2], atol, init) + end + if first(obs_seq) isa AbstractVector + obs_mat = reduce(hcat, obs_seq) + @testset "Matrix" begin + test_coherent_algorithms_aux( + hmm, obs_mat, state_seq, control_seq; seq_ends, hmm_guess, atol, init + ) + end end end end diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index 76ba0731..3014ec4d 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -1,8 +1,8 @@ function test_identical_hmmbase( rng::AbstractRNG, - hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; + hmm::AbstractHMM; + hmm_guess::Union{Nothing,AbstractHMM}=nothing, T::Integer, atol::Real=1e-5, ) diff --git a/libs/HMMTest/src/jet.jl b/libs/HMMTest/src/jet.jl index 95d82789..97e70da8 100644 --- a/libs/HMMTest/src/jet.jl +++ b/libs/HMMTest/src/jet.jl @@ -2,9 +2,9 @@ function test_type_stability( rng::AbstractRNG, hmm::AbstractHMM, - hmm_guess::Union{Nothing,AbstractHMM}=nothing; - control_seq::AbstractVector, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, + hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Type stability" begin state_seq, obs_seq = rand(rng, hmm, control_seq) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index e12bcd6e..55e34e76 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -29,6 +29,7 @@ export seq_limits include("types/abstract_hmm.jl") +include("utils/defaults.jl") include("utils/linalg.jl") include("utils/check.jl") include("utils/probvec_transmat.jl") diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 924e2bf1..35b5b97b 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -20,14 +20,14 @@ function 
baum_welch!( fb_storage::ForwardBackwardStorage, logL_evolution::Vector, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, atol::Real, max_iterations::Integer, loglikelihood_increasing::Bool, ) - for iteration in 1:max_iterations + for _ in 1:max_iterations forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) push!(logL_evolution, logdensityof(hmm) + sum(fb_storage.logL)) fit!(hmm, fb_storage, obs_seq, control_seq; seq_ends) @@ -53,9 +53,9 @@ Return a tuple `(hmm_est, loglikelihood_evolution)` where `hmm_est` is the estim """ function baum_welch( hmm_guess::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), atol=1e-5, max_iterations=100, loglikelihood_increasing=true, diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index 424236e1..9bdd30fe 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -1,17 +1,15 @@ -_dcat(M1, M2) = cat(M1, M2; dims=3) - function _params_and_loglikelihoods( hmm::AbstractHMM, obs_seq::Vector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) init = initialization(hmm) - trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t - transition_matrix(hmm, control_seq[t]) + trans_by_time = mapreduce(dcat, 1:duration(obs_seq)) do t + transition_matrix(hmm, control_seq, t) end - logB = mapreduce(hcat, eachindex(obs_seq, control_seq)) do t - logdensityof.(obs_distributions(hmm, control_seq[t]), (obs_seq[t],)) + logB 
= mapreduce(hcat, 1:duration(obs_seq)) do t + logdensityof.(obs_distributions(hmm, control_seq, t), (at_time(obs_seq, t),)) end return init, trans_by_time, logB end @@ -20,9 +18,9 @@ function ChainRulesCore.rrule( rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) _, pullback = rrule_via_ad( rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends @@ -30,7 +28,7 @@ function ChainRulesCore.rrule( fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) (; logL, α, γ, Bβ) = fb_storage - N, T = length(hmm), length(obs_seq) + N, T = length(hmm), duration(obs_seq) R = eltype(α) Δinit = zeros(R, N) diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 82522bfb..93ab882c 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -23,12 +23,12 @@ $(SIGNATURES) """ function initialize_forward( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) - R = eltype(hmm, obs_seq[1], control_seq[1]) + N, T, K = length(hmm), duration(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq, control_seq, 1) α = Matrix{R}(undef, N, T) logL = Vector{R}(undef, K) B = Matrix{R}(undef, N, T) @@ -42,8 +42,8 @@ $(SIGNATURES) function forward!( storage, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector, + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat, t1::Integer, t2::Integer; ) @@ -51,7 +51,7 @@ function forward!( # Initialization Bₜ₁ = view(B, :, t1) - 
obs_logdensities!(Bₜ₁, hmm, obs_seq[t1], control_seq[t1]) + obs_logdensities!(Bₜ₁, hmm, obs_seq, control_seq, t1) logm = maximum(Bₜ₁) Bₜ₁ .= exp.(Bₜ₁ .- logm) @@ -66,11 +66,11 @@ function forward!( # Loop for t in t1:(t2 - 1) Bₜ₊₁ = view(B, :, t + 1) - obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1]) + obs_logdensities!(Bₜ₊₁, hmm, obs_seq, control_seq, t + 1) logm = maximum(Bₜ₊₁) Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq, t) αₜ₊₁ = view(α, :, t + 1) mul!(αₜ₊₁, trans', view(α, :, t)) αₜ₊₁ .*= Bₜ₊₁ @@ -89,8 +89,8 @@ $(SIGNATURES) function forward!( storage, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, ) (; α, logL) = storage @@ -111,9 +111,9 @@ Return a tuple `(storage.α, sum(storage.logL))` where `storage` is of type [`Fo """ function forward( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends) forward!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index a01ebc82..a919f9ed 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -28,21 +28,21 @@ $(SIGNATURES) """ function initialize_forward_backward( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, transition_marginals=true, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) - R = eltype(hmm, obs_seq[1], control_seq[1]) - trans = 
transition_matrix(hmm, control_seq[1]) + N, T, K = length(hmm), duration(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq, control_seq, 1) + trans = transition_matrix(hmm, control_seq, 1) M = typeof(mysimilar_mutable(trans, R)) γ = Matrix{R}(undef, N, T) ξ = Vector{M}(undef, T) if transition_marginals for t in 1:T - ξ[t] = mysimilar_mutable(transition_matrix(hmm, control_seq[t]), R) + ξ[t] = mysimilar_mutable(transition_matrix(hmm, control_seq, t), R) end end logL = Vector{R}(undef, K) @@ -60,8 +60,8 @@ $(SIGNATURES) function forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector, + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat, t1::Integer, t2::Integer; transition_marginals::Bool=true, @@ -74,7 +74,7 @@ function forward_backward!( # Backward β[:, t2] .= c[t2] for t in (t2 - 1):-1:t1 - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq, t) Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1) mul!(view(β, :, t), trans, view(Bβ, :, t + 1)) lmul!(c[t], view(β, :, t)) @@ -87,7 +87,7 @@ function forward_backward!( # Transition marginals if transition_marginals for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq, t) mul_rows_cols!(ξ[t], view(α, :, t), trans, view(Bβ, :, t + 1)) end ξ[t2] .= zero(R) @@ -102,8 +102,8 @@ $(SIGNATURES) function forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, transition_marginals::Bool=true, ) where {R} @@ -127,9 +127,9 @@ Return a tuple `(storage.γ, sum(storage.logL))` where `storage` is of type [`Fo """ function forward_backward( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - 
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) transition_marginals = false storage = initialize_forward_backward( diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index 4d82e152..f58ab421 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -5,9 +5,9 @@ Run the forward algorithm to compute the loglikelihood of `obs_seq` for `hmm`, i """ function DensityInterface.logdensityof( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) _, logL = forward(hmm, obs_seq, control_seq; seq_ends) return logL @@ -20,12 +20,12 @@ Run the forward algorithm to compute the the joint loglikelihood of `obs_seq` an """ function joint_logdensityof( hmm::AbstractHMM, - obs_seq::AbstractVector, + obs_seq::AbstractVecOrMat, state_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) - R = eltype(hmm, obs_seq[1], control_seq[1]) + R = eltype(hmm, obs_seq, control_seq, 1) logL = zero(R) for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) @@ -34,13 +34,13 @@ function joint_logdensityof( logL += log(init[state_seq[t1]]) # Transitions for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t]) + trans = transition_matrix(hmm, control_seq, t) logL += log(trans[state_seq[t], state_seq[t + 1]]) end # Observations for t in t1:t2 - dists = obs_distributions(hmm, control_seq[t]) - logL += 
logdensityof(dists[state_seq[t]], obs_seq[t]) + dists = obs_distributions(hmm, control_seq, t) + logL += logdensityof(dists[state_seq[t]], at_time(obs_seq, t)) end end return logL diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 2afc983f..749ceaa2 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -24,12 +24,12 @@ $(SIGNATURES) """ function initialize_viterbi( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, ) - N, T, K = length(hmm), length(obs_seq), length(seq_ends) - R = eltype(hmm, obs_seq[1], control_seq[1]) + N, T, K = length(hmm), duration(obs_seq), length(seq_ends) + R = eltype(hmm, obs_seq, control_seq, 1) q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) logB = Matrix{R}(undef, N, T) @@ -44,20 +44,20 @@ $(SIGNATURES) function viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector, + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat, t1::Integer, t2::Integer; ) where {R} (; q, logB, ϕ, ψ) = storage - obs_logdensities!(view(logB, :, t1), hmm, obs_seq[t1], control_seq[t1]) + obs_logdensities!(view(logB, :, t1), hmm, obs_seq, control_seq, t1) init = initialization(hmm) ϕ[:, t1] .= log.(init) .+ view(logB, :, t1) for t in (t1 + 1):t2 - obs_logdensities!(view(logB, :, t), hmm, obs_seq[t], control_seq[t]) - trans = transition_matrix(hmm, control_seq[t - 1]) + obs_logdensities!(view(logB, :, t), hmm, obs_seq, control_seq, t) + trans = transition_matrix(hmm, control_seq, t - 1) for j in 1:length(hmm) i_max = 1 score_max = ϕ[i_max, t - 1] + log(trans[i_max, j]) @@ -86,8 +86,8 @@ $(SIGNATURES) function viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, ) where {R} (; logL, ϕ) = 
storage @@ -108,9 +108,9 @@ Return a tuple `(storage.q, sum(storage.logL))` where `storage` is of type [`Vit """ function viterbi( hmm::AbstractHMM, - obs_seq::AbstractVector, - control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat=Fill(nothing, duration(obs_seq)); + seq_ends::AbstractVector{Int}=Fill(duration(obs_seq), 1), ) storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) viterbi!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/precompile.jl b/src/precompile.jl index 535deb05..0bd28387 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -5,11 +5,14 @@ dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N] hmm = HMM(init, trans, dists) state_seq, obs_seq = rand(hmm, T) + obs_mat = reduce(hcat, obs_seq) - logdensityof(hmm, obs_seq, state_seq) - logdensityof(hmm, obs_seq) - forward(hmm, obs_seq) - viterbi(hmm, obs_seq) - forward_backward(hmm, obs_seq) - baum_welch(hmm, obs_seq; max_iterations=1) + for obs_seq_or_mat in (obs_seq, obs_mat) + logdensityof(hmm, obs_seq_or_mat, state_seq) + logdensityof(hmm, obs_seq_or_mat) + forward(hmm, obs_seq_or_mat) + viterbi(hmm, obs_seq_or_mat) + forward_backward(hmm, obs_seq_or_mat) + baum_welch(hmm, obs_seq_or_mat; max_iterations=1) + end end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index f02d4ca5..5fad47d3 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -51,6 +51,12 @@ function Base.eltype(hmm::AbstractHMM, obs, control) return promote_type(init_type, trans_type, logdensity_type) end +function Base.eltype( + hmm::AbstractHMM, obs_seq::AbstractVecOrMat, control_seq::AbstractVecOrMat, t::Integer +) + return eltype(hmm, at_time(obs_seq, t), at_time(control_seq, t)) +end + """ initialization(hmm) @@ -66,6 +72,10 @@ Return the matrix of state transition probabilities for `hmm` (possibly when `co """ 
transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) +function transition_matrix(hmm::AbstractHMM, control_seq::AbstractVecOrMat, t::Integer) + return transition_matrix(hmm, at_time(control_seq, t)) +end + """ obs_distributions(hmm) obs_distributions(hmm, control) @@ -80,7 +90,18 @@ These distribution objects should implement """ obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) -function obs_logdensities!(logb::AbstractVector, hmm::AbstractHMM, obs, control) +function obs_distributions(hmm::AbstractHMM, control_seq::AbstractVecOrMat, t::Integer) + return obs_distributions(hmm, at_time(control_seq, t)) +end + +function obs_logdensities!( + logb::AbstractVector, + hmm::AbstractHMM, + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat, + t::Integer, +) + obs, control = at_time(obs_seq, t), at_time(control_seq, t) dists = obs_distributions(hmm, control) @inbounds for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) @@ -93,8 +114,8 @@ end fit!( hmm::AbstractHMM, fb_storage::ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVecOrMat; + control_seq::AbstractVecOrMat, seq_ends::AbstractVector{Int}, ) @@ -114,7 +135,7 @@ Simulate `hmm` for `T` time steps, or when the sequence `control_seq` is applied Return a named tuple `(; state_seq, obs_seq)`. 
""" -function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector) +function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVecOrMat) T = length(control_seq) dummy_log_probas = fill(-Inf, length(hmm)) @@ -142,7 +163,7 @@ function Random.rand(rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVe return (; state_seq=state_seq, obs_seq=obs_seq) end -function Random.rand(hmm::AbstractHMM, control_seq::AbstractVector) +function Random.rand(hmm::AbstractHMM, control_seq::AbstractVecOrMat) return rand(default_rng(), hmm, control_seq) end @@ -161,4 +182,4 @@ end Return the prior loglikelihood associated with the parameters of `hmm`. """ -DensityInterface.logdensityof(hmm::AbstractHMM) = 0 +DensityInterface.logdensityof(hmm::AbstractHMM) = false diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 5dcd6ec4..dc08f479 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -35,8 +35,8 @@ obs_distributions(hmm::HMM) = hmm.dists function StatsAPI.fit!( hmm::HMM, fb_storage::ForwardBackwardStorage, - obs_seq::AbstractVector, - control_seq::AbstractVector; + obs_seq::AbstractVecOrMat, + control_seq::AbstractVecOrMat; seq_ends::AbstractVector{Int}, ) (; γ, ξ) = fb_storage diff --git a/src/utils/defaults.jl b/src/utils/defaults.jl new file mode 100644 index 00000000..0270ca7f --- /dev/null +++ b/src/utils/defaults.jl @@ -0,0 +1,5 @@ +duration(seq::AbstractVector) = length(seq) +duration(seq::AbstractMatrix) = size(seq, 2) + +at_time(seq::AbstractVector, t::Integer) = seq[t] +at_time(seq::AbstractMatrix, t::Integer) = view(seq, :, t) diff --git a/src/utils/fit.jl b/src/utils/fit.jl index 0e0d7711..eab219a2 100644 --- a/src/utils/fit.jl +++ b/src/utils/fit.jl @@ -11,7 +11,9 @@ Override for Distributions.jl (in the package extension) dists[i] = fit(eltype(dists), x, w) """ -function fit_in_sequence!(dists::AbstractVector, i::Integer, x, w) +function fit_in_sequence!( + dists::AbstractVector, i::Integer, x::AbstractVecOrMat, 
w::AbstractVector +) fit!(dists[i], x, w) return nothing end diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 72a7abeb..2af2a37d 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -27,3 +27,5 @@ function mul_rows_cols!( end return nothing end + +dcat(M1::AbstractArray, M2::AbstractArray) = cat(M1, M2; dims=3) diff --git a/test/correctness.jl b/test/correctness.jl index c5f72d5a..e57fe3e8 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -42,12 +42,12 @@ seq_ends = cumsum(length.(control_seqs)); hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, hmm_guess; T) + test_identical_hmmbase(rng, hmm; hmm_guess, T) test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false + rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end @testset "DiagNormal" begin @@ -59,11 +59,11 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, hmm_guess; T) + test_identical_hmmbase(rng, hmm; hmm_guess, T) test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false + rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) end @testset "LightCategorical" begin @@ -74,10 +74,10 @@ end hmm_guess = HMM(init_guess, trans_guess, dists_guess) test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false + rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false ) - test_type_stability(rng, hmm, hmm_guess; 
control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end @test_skip @testset "LightDiagNormal" begin @@ -88,10 +88,10 @@ end hmm_guess = HMM(init_guess, trans_guess, dists_guess) test_coherent_algorithms( - rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false + rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false ) - test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) - test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end # Controlled @@ -123,6 +123,6 @@ end control_seq = reduce(vcat, control_seqs) seq_ends = cumsum(length.(control_seqs)) - test_coherent_algorithms(rng, hmm; control_seq, seq_ends, atol=0.05, init=false) - test_type_stability(rng, hmm; control_seq, seq_ends) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, atol=0.05, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends) end