Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with matrix observations #81

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.HMMBenchmark]]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "FillArrays", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
path = "../libs/HMMBenchmark"
uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
version = "0.1.0"
Expand Down
6 changes: 3 additions & 3 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,6 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:500]; #src
control_seq = reduce(vcat, control_seqs); #src
seq_ends = cumsum(length.(control_seqs)); #src

test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_identical_hmmbase(rng, hmm; hmm_guess, T=100) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
4 changes: 2 additions & 2 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2])

# ## Tests #src

test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.08, init=false) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
6 changes: 3 additions & 3 deletions examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,6 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src
control_seq = reduce(vcat, control_seqs); #src
seq_ends = cumsum(length.(control_seqs)); #src

test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
4 changes: 2 additions & 2 deletions examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,5 @@ hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))

# ## Tests #src

test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.1, init=false) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.1, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
8 changes: 4 additions & 4 deletions examples/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ control_seqs = [fill(nothing, rand(rng, 100:200)) for k in 1:100]; #src
control_seq = reduce(vcat, control_seqs); #src
seq_ends = cumsum(length.(control_seqs)); #src

test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_identical_hmmbase(rng, hmm; hmm_guess, T=100) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.05, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
# https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
@test_skip test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
16 changes: 13 additions & 3 deletions ext/HiddenMarkovModelsDistributionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module HiddenMarkovModelsDistributionsExt

using HiddenMarkovModels: HiddenMarkovModels
using HiddenMarkovModels: HiddenMarkovModels, dcat
using Distributions:
Distributions,
Distribution,
Expand All @@ -15,6 +15,12 @@
return dists[i] = fit(D, x_nums, w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D}, i::Integer, x_mat::AbstractMatrix, w::AbstractVector
) where {D<:MultivariateDistribution}
return dists[i] = fit(D, x_mat, w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D},
i::Integer,
Expand All @@ -24,6 +30,12 @@
return dists[i] = fit(D, reduce(hcat, x_vecs), w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D}, i::Integer, x_tens::AbstractArray{Any,3}, w::AbstractVector
) where {D<:MatrixDistribution}
return dists[i] = fit(D, x_tens, w)
end

function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{D},
i::Integer,
Expand All @@ -33,6 +45,4 @@
return dists[i] = fit(D, reduce(dcat, x_mats), w)
end

dcat(M1, M2) = cat(M1, M2; dims=3)

end
1 change: 1 addition & 0 deletions libs/HMMBenchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
4 changes: 3 additions & 1 deletion libs/HMMBenchmark/src/HMMBenchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using BenchmarkTools: @benchmarkable, BenchmarkGroup
using CSV: CSV
using DataFrames: DataFrame
using Distributions: Normal, MvNormal
using FillArrays: Fill
using HiddenMarkovModels
using HiddenMarkovModels:
LightDiagNormal,
Expand All @@ -16,7 +17,8 @@ using HiddenMarkovModels:
forward!,
initialize_forward_backward,
forward_backward!,
baum_welch!
baum_welch!,
duration
using LinearAlgebra: BLAS, Diagonal, SymTridiagonal
using Pkg: Pkg
using Random: AbstractRNG
Expand Down
22 changes: 15 additions & 7 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,27 @@ function build_benchmarkables(
instance::Instance,
algos::Vector{String},
)
(; obs_dim, seq_length, nb_seqs, bw_iter) = instance
(; custom_dist, obs_dim, seq_length, nb_seqs, bw_iter) = instance

hmm = build_model(rng, implem; instance)
data = randn(rng, nb_seqs, seq_length, obs_dim)

if obs_dim == 1
obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
else
if custom_dist
obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs]
else
if obs_dim == 1
obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
else
obs_seqs = [collect(data[k, :, :]') for k in 1:nb_seqs]
end
end
if first(obs_seqs) isa AbstractVector
obs_seq = reduce(vcat, obs_seqs)
else
obs_seq = reduce(hcat, obs_seqs)
end
obs_seq = reduce(vcat, obs_seqs)
control_seq = fill(nothing, length(obs_seq))
seq_ends = cumsum(length.(obs_seqs))
control_seq = Fill(nothing, duration(obs_seq))
seq_ends = cumsum(duration.(obs_seqs))

benchs = Dict()

Expand Down
86 changes: 51 additions & 35 deletions libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,66 @@

function test_allocations_aux(
hmm::AbstractHMM,
obs_seq::AbstractVecOrMat,
control_seq::AbstractVecOrMat;
seq_ends::AbstractVector{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
t1, t2 = 1, seq_ends[1]

## Forward
f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends)
allocs_f = @ballocated HMMs.forward!($f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2) evals =
1 samples = 1
@test allocs_f == 0

## Viterbi
v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
allocs_v = @ballocated HMMs.viterbi!($v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2) evals =
1 samples = 1
@test allocs_v == 0

## Forward-backward
fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
allocs_fb = @ballocated HMMs.forward_backward!(
$fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
) evals = 1 samples = 1
@test allocs_fb == 0

## Baum-Welch
if !isnothing(hmm_guess)
fb_storage = HMMs.initialize_forward_backward(
hmm_guess, obs_seq, control_seq; seq_ends
)
HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends)
allocs_bw = @ballocated fit!(
$hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm))
@test_broken allocs_bw == 0
end
end

function test_allocations(
rng::AbstractRNG,
hmm::AbstractHMM,
hmm_guess::Union{Nothing,AbstractHMM}=nothing;
control_seq::AbstractVector,
control_seq::AbstractVecOrMat;
seq_ends::AbstractVector{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Allocations" begin
obs_seq = mapreduce(vcat, eachindex(seq_ends)) do k
t1, t2 = seq_limits(seq_ends, k)
rand(rng, hmm, control_seq[t1:t2]).obs_seq
end

t1, t2 = 1, seq_ends[1]

## Forward
f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends)
allocs_f = @ballocated HMMs.forward!(
$f_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
) evals = 1 samples = 1
@test allocs_f == 0

## Viterbi
v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
allocs_v = @ballocated HMMs.viterbi!(
$v_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
) evals = 1 samples = 1
@test allocs_v == 0

## Forward-backward
fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends)
allocs_fb = @ballocated HMMs.forward_backward!(
$fb_storage, $hmm, $obs_seq, $control_seq, $t1, $t2
) evals = 1 samples = 1
@test allocs_fb == 0

## Baum-Welch
if !isnothing(hmm_guess)
fb_storage = HMMs.initialize_forward_backward(
hmm_guess, obs_seq, control_seq; seq_ends
)
HMMs.forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends)
allocs_bw = @ballocated fit!(
$hmm_guess, $fb_storage, $obs_seq, $control_seq; seq_ends=$seq_ends
) evals = 1 samples = 1 setup = (hmm_guess = deepcopy($hmm))
@test_broken allocs_bw == 0
@testset "Sequence" begin
test_allocations_aux(hmm, obs_seq, control_seq; seq_ends, hmm_guess)
end
if first(obs_seq) isa AbstractVector
obs_mat = reduce(hcat, obs_seq)
@testset "Matrix" begin
test_allocations_aux(hmm, obs_mat, control_seq; seq_ends, hmm_guess)
end
end
end
end
77 changes: 50 additions & 27 deletions libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ infnorm(x) = maximum(abs, x)

function check_equal_hmms(
hmm1::AbstractHMM,
hmm2::AbstractHMM;
control_seq=[nothing],
hmm2::AbstractHMM,
control_seq=[nothing];
atol::Real=0.1,
init::Bool=true,
test::Bool=true,
Expand Down Expand Up @@ -45,21 +45,53 @@ end

function test_equal_hmms(
hmm1::AbstractHMM,
hmm2::AbstractHMM;
control_seq=[nothing],
hmm2::AbstractHMM,
control_seq=[nothing];
atol::Real=0.1,
init::Bool=true,
)
check_equal_hmms(hmm1, hmm2; control_seq, atol, init, test=true)
check_equal_hmms(hmm1, hmm2, control_seq; atol, init, test=true)
return nothing
end

function test_coherent_algorithms_aux(
hmm::AbstractHMM,
obs_seq::AbstractVecOrMat,
state_seq::AbstractVector{<:Integer},
control_seq::AbstractVecOrMat;
seq_ends::AbstractVector{Int},
hmm_guess::Union{Nothing,AbstractHMM},
atol::Real,
init::Bool,
)
logL = logdensityof(hmm, obs_seq, control_seq; seq_ends)
logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends)

q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends)
@test logL_viterbi > logL_joint
@test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends)

α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends)
@test logL_forward ≈ logL

γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends)
@test logL_forward_backward ≈ logL
@test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends))

if !isnothing(hmm_guess)
hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
@test all(>=(0), diff(logL_evolution))
@test !check_equal_hmms(hmm, hmm_guess, control_seq[1:2]; atol, test=false)
test_equal_hmms(hmm, hmm_est, control_seq[1:2]; atol, init)
end
end

function test_coherent_algorithms(
rng::AbstractRNG,
hmm::AbstractHMM,
hmm_guess::Union{Nothing,AbstractHMM}=nothing;
control_seq::AbstractVector,
control_seq::AbstractVecOrMat;
seq_ends::AbstractVector{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
atol::Real=0.1,
init::Bool=true,
)
Expand All @@ -75,27 +107,18 @@ function test_coherent_algorithms(
state_seq = reduce(vcat, state_seqs)
obs_seq = reduce(vcat, obs_seqs)

logL = logdensityof(hmm, obs_seq, control_seq; seq_ends)
logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends)

q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends)
@test logL_viterbi > logL_joint
@test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends)

α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends)
@test logL_forward ≈ logL

γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends)
@test logL_forward_backward ≈ logL
@test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends))

if !isnothing(hmm_guess)
hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
@test all(>=(0), diff(logL_evolution))
@test !check_equal_hmms(
hmm, hmm_guess; control_seq=control_seq[1:2], atol, test=false
@testset "Sequence" begin
test_coherent_algorithms_aux(
hmm, obs_seq, state_seq, control_seq; seq_ends, hmm_guess, atol, init
)
test_equal_hmms(hmm, hmm_est; control_seq=control_seq[1:2], atol, init)
end
if first(obs_seq) isa AbstractVector
obs_mat = reduce(hcat, obs_seq)
@testset "Matrix" begin
test_coherent_algorithms_aux(
hmm, obs_mat, state_seq, control_seq; seq_ends, hmm_guess, atol, init
)
end
end
end
end
4 changes: 2 additions & 2 deletions libs/HMMTest/src/hmmbase.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

function test_identical_hmmbase(
rng::AbstractRNG,
hmm::AbstractHMM,
hmm_guess::Union{Nothing,AbstractHMM}=nothing;
hmm::AbstractHMM;
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
T::Integer,
atol::Real=1e-5,
)
Expand Down
4 changes: 2 additions & 2 deletions libs/HMMTest/src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
function test_type_stability(
rng::AbstractRNG,
hmm::AbstractHMM,
hmm_guess::Union{Nothing,AbstractHMM}=nothing;
control_seq::AbstractVector,
control_seq::AbstractVecOrMat;
seq_ends::AbstractVector{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Type stability" begin
state_seq, obs_seq = rand(rng, hmm, control_seq)
Expand Down
Loading