diff --git a/.gitignore b/.gitignore
index 6156fbfb..f8e41336 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,4 +21,7 @@
 scratchpad.jl
 /benchmark/*.json
 /docs/src/index.md
-/docs/src/examples/*.md
\ No newline at end of file
+/docs/src/examples/*.md
+
+*.pdf
+*.png
\ No newline at end of file
diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl
index 8239c69a..07923e7f 100644
--- a/benchmark/benchmarks.jl
+++ b/benchmark/benchmarks.jl
@@ -5,7 +5,7 @@ using StableRNGs
 
 rng = StableRNG(63)
 
-algos = ["rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch"]
+algos = ["forward", "viterbi", "forward_backward", "baum_welch"]
 instances = [
     # compare state numbers
     Instance(; nb_states=4, obs_dim=1),
diff --git a/examples/basics.jl b/examples/basics.jl
index 23734d14..70f80c7e 100644
--- a/examples/basics.jl
+++ b/examples/basics.jl
@@ -54,7 +54,7 @@ In practical applications, the state sequence is not known, which is why we need
 # ## Inference
 
 #=
-The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$.
+The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\log \mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$ (in a vector of size 1).
 =#
 
 best_state_seq, best_joint_loglikelihood = viterbi(hmm, obs_seq);
@@ -66,7 +66,7 @@ As we can see, it is very close to the true state sequence, but not necessarily
 vcat(state_seq', best_state_seq')
 
 #=
-The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence.
+The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1).
 =#
 
 filtered_state_marginals, obs_seq_loglikelihood1 = forward(hmm, obs_seq);
@@ -79,7 +79,7 @@ This is particularly useful to infer the marginal distribution of the last state
 filtered_state_marginals[:, end]
 
 #=
-Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\mathbb{P}(Y_{1:T})$ of the observation sequence.
+Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1).
 =#
 
 smoothed_state_marginals, obs_seq_loglikelihood2 = forward_backward(hmm, obs_seq);
@@ -179,7 +179,7 @@ long_obs_seq_concat = reduce(vcat, long_obs_seqs)
 seq_ends = cumsum(length.(long_obs_seqs))
 
 #=
-The outputs of inference algorithms are then concatenated, and the associated loglikelihoods are summed over all sequences.
+The outputs of inference algorithms are then concatenated, and the associated loglikelihoods are split by sequence (in a vector of size `length(seq_ends)`).
 =#
 
 best_state_seq_concat, _ = viterbi(hmm, long_obs_seq_concat; seq_ends);
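The documentation hunks above capture the user-facing change of this diff: the inference wrappers now return one loglikelihood per sequence instead of a pre-summed scalar. A minimal sketch of how that reads at the call site, assuming an `hmm` built as earlier in this example (the sequence lengths are made up):

```julia
obs_seq1 = rand(hmm, 100).obs_seq
obs_seq2 = rand(hmm, 200).obs_seq
obs_seq = vcat(obs_seq1, obs_seq2)
seq_ends = [100, 300]  # cumulative sequence lengths, as in the docs above

# The second return value now has one entry per sequence:
best_state_seq, joint_logLs = viterbi(hmm, obs_seq; seq_ends)
α, logLs = forward(hmm, obs_seq; seq_ends)
length(logLs) == length(seq_ends)  # true
```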
diff --git a/libs/HMMBenchmark/src/HMMBenchmark.jl b/libs/HMMBenchmark/src/HMMBenchmark.jl
index b2b543fa..6bf86a9c 100644
--- a/libs/HMMBenchmark/src/HMMBenchmark.jl
+++ b/libs/HMMBenchmark/src/HMMBenchmark.jl
@@ -21,13 +21,13 @@ using LinearAlgebra: BLAS, Diagonal, SymTridiagonal
 using Pkg: Pkg
 using Random: AbstractRNG
 using SparseArrays: spdiagm
-using Statistics: mean, median, std
+using Statistics: mean, median, std, quantile
 
-export AbstractImplementation, Instance
-export define_suite, parse_results, print_julia_setup
+export Implementation, Instance, Params, HiddenMarkovModelsImplem
+export build_params, build_data, build_model, build_benchmarkables
+export define_suite, parse_results, read_results, print_julia_setup
 
 abstract type Implementation end
-Base.string(implem::Implementation) = string(typeof(implem))[begin:(end - length("Implem"))]
 
 include("instance.jl")
 include("params.jl")
diff --git a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl
index 346d384d..76fab3fd 100644
--- a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl
+++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl
@@ -1,8 +1,9 @@
 struct HiddenMarkovModelsImplem <: Implementation end
+Base.string(::HiddenMarkovModelsImplem) = "HiddenMarkovModels.jl"
 
-function build_model(rng::AbstractRNG, ::HiddenMarkovModelsImplem; instance::Instance)
+function build_model(::HiddenMarkovModelsImplem, instance::Instance, params::Params)
     (; custom_dist, nb_states, obs_dim) = instance
-    (; init, trans, means, stds) = build_params(rng; instance)
+    (; init, trans, means, stds) = params
 
     if custom_dist
         dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
@@ -19,15 +20,14 @@
 end
 
 function build_benchmarkables(
-    rng::AbstractRNG,
-    implem::HiddenMarkovModelsImplem;
+    implem::HiddenMarkovModelsImplem,
     instance::Instance,
-    algos::Vector{String},
+    params::Params,
+    data::AbstractArray{<:Real,3},
+    algos::Vector{String};
 )
     (; obs_dim, seq_length, nb_seqs, bw_iter) = instance
-
-    hmm = build_model(rng, implem; instance)
-    data = randn(rng, nb_seqs, seq_length, obs_dim)
+    hmm = build_model(implem, instance, params)
 
     if obs_dim == 1
         obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
@@ -40,18 +40,6 @@
 
     benchs = Dict()
 
-    if "rand" in algos
-        benchs["rand"] = @benchmarkable begin
-            [rand($rng, $hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
-        end evals = 1 samples = 100
-    end
-
-    if "logdensity" in algos
-        benchs["logdensity"] = @benchmarkable begin
-            logdensityof($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 100
-    end
-
     if "forward" in algos
        benchs["forward"] = @benchmarkable begin
            forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
@@ -112,7 +100,7 @@
                 loglikelihood_increasing=false,
             )
         end evals = 1 samples = 100 setup = (
-            hmm_guess = build_model($rng, $implem; instance=$instance);
+            hmm_guess = build_model($implem, $instance, $params);
             fb_storage = initialize_forward_backward(
                 hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
             );
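Throughout these benchmark files, the same BenchmarkTools pattern recurs: values are interpolated with `$` at definition time, and `setup` rebuilds fresh state before each sample, so Baum-Welch always starts from an untrained `hmm_guess` rather than warm-starting on its own output. A self-contained sketch of the pattern on a toy workload (not from this repository):

```julia
using BenchmarkTools

xs = randn(1000)

# `$xs` bakes the value into the benchmark expression; `setup` runs before
# each sample; `evals = 1` takes a single evaluation per sample, which is
# required whenever the first evaluation mutates the freshly set-up state.
b = @benchmarkable sort!(v) setup = (v = copy($xs)) evals = 1 samples = 100
run(b)
```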
diff --git a/libs/HMMBenchmark/src/instance.jl b/libs/HMMBenchmark/src/instance.jl
index c49596d9..9086722b 100644
--- a/libs/HMMBenchmark/src/instance.jl
+++ b/libs/HMMBenchmark/src/instance.jl
@@ -16,3 +16,8 @@
     vals = parse.(Int, split(s, " ")[2:2:end])
     return Instance(vals...)
 end
+
+function build_data(rng::AbstractRNG, instance::Instance)
+    (; nb_seqs, seq_length, obs_dim) = instance
+    return randn(rng, nb_seqs, seq_length, obs_dim)
+end
diff --git a/libs/HMMBenchmark/src/params.jl b/libs/HMMBenchmark/src/params.jl
index 21f59703..246b609d 100644
--- a/libs/HMMBenchmark/src/params.jl
+++ b/libs/HMMBenchmark/src/params.jl
@@ -1,11 +1,13 @@
-function build_initialization(rng::AbstractRNG; instance::Instance)
-    (; nb_states) = instance
-    init = ones(nb_states) / nb_states
-    return init
+Base.@kwdef struct Params{T,M<:AbstractMatrix{T}}
+    init::Vector{T}
+    trans::M
+    means::Matrix{T}
+    stds::Matrix{T}
 end
 
-function build_transition_matrix(rng::AbstractRNG; instance::Instance)
-    (; sparse, nb_states) = instance
+function build_params(rng::AbstractRNG, instance::Instance)
+    (; sparse, nb_states, obs_dim) = instance
+    init = ones(nb_states) / nb_states
     if sparse
         trans = spdiagm(
             0 => rand(rng, nb_states) / 2,
@@ -18,25 +20,7 @@
     for row in eachrow(trans)
         row ./= sum(row)
     end
-    return trans
-end
-
-function build_means(rng::AbstractRNG; instance::Instance)
-    (; obs_dim, nb_states) = instance
     means = randn(rng, obs_dim, nb_states)
-    return means
-end
-
-function build_stds(rng::AbstractRNG; instance::Instance)
-    (; obs_dim, nb_states) = instance
     stds = ones(obs_dim, nb_states)
-    return stds
-end
-
-function build_params(rng::AbstractRNG; instance::Instance)
-    init = build_initialization(rng; instance)
-    trans = build_transition_matrix(rng; instance)
-    means = build_means(rng; instance)
-    stds = build_stds(rng; instance)
-    return (; init, trans, means, stds)
+    return Params(; init, trans, means, stds)
 end
diff --git a/libs/HMMBenchmark/src/suite.jl b/libs/HMMBenchmark/src/suite.jl
index 4adc3879..dd7d969d 100644
--- a/libs/HMMBenchmark/src/suite.jl
+++ b/libs/HMMBenchmark/src/suite.jl
@@ -11,7 +11,9 @@ function define_suite(
     SUITE = BenchmarkGroup()
     for implem in implems
         for instance in instances
-            bench_tup = build_benchmarkables(rng, implem; instance, algos)
+            params = build_params(rng, instance)
+            data = build_data(rng, instance)
+            bench_tup = build_benchmarkables(implem, instance, params, data, algos)
             for (algo, bench) in pairs(bench_tup)
                 SUITE[string(implem)][string(instance)][algo] = bench
             end
@@ -20,8 +22,13 @@
     return SUITE
 end
 
+quantile75(x) = quantile(x, 0.75)
+quantile25(x) = quantile(x, 0.25)
+
 function parse_results(
-    results; path=nothing, aggregators=[minimum, median, maximum, mean, std]
+    results;
+    path=nothing,
+    aggregators=[minimum, median, maximum, mean, std, quantile25, quantile75],
 )
     data = DataFrame()
     for implem_str in identity.(keys(results))
@@ -48,3 +55,7 @@
     end
     return data
 end
+
+function read_results(path)
+    return CSV.read(path, DataFrame)
+end
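With `build_params` and `build_data` factored out, `define_suite` now draws the parameters and observations once per instance and hands the same objects to every implementation, so all backends are timed (and, below, correctness-checked) on identical inputs. A sketch of the resulting pipeline for a single backend, with illustrative instance settings:

```julia
using StableRNGs

rng = StableRNG(63)
instance = Instance(; nb_states=4, obs_dim=1)

params = build_params(rng, instance)  # shared init, transition matrix, emissions
data = build_data(rng, instance)      # shared (nb_seqs, seq_length, obs_dim) array

implem = HiddenMarkovModelsImplem()
hmm = build_model(implem, instance, params)
benchs = build_benchmarkables(implem, instance, params, data, ["forward"])
```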
"189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/libs/HMMComparison/src/HMMComparison.jl b/libs/HMMComparison/src/HMMComparison.jl index e8da5901..4b956de8 100644 --- a/libs/HMMComparison/src/HMMComparison.jl +++ b/libs/HMMComparison/src/HMMComparison.jl @@ -3,28 +3,24 @@ module HMMComparison using BenchmarkTools: BenchmarkGroup, @benchmarkable using CondaPkg: CondaPkg using Distributions: Normal, MvNormal +using HiddenMarkovModels: HiddenMarkovModels using HMMBase: HMMBase -using HMMBenchmark: - HMMBenchmark, - Instance, - Implementation, - HiddenMarkovModelsImplem, - build_params, - build_model, - build_benchmarkables using LinearAlgebra: Diagonal -using PythonCall: Py, pyimport, pybuiltins, pylist +using LogExpFunctions: logsumexp +using PythonCall: Py, PyArray, pyimport, pyconvert, pybuiltins, pylist using Random: AbstractRNG +using Reexport: @reexport using SparseArrays: spdiagm +@reexport using HMMBenchmark -export HiddenMarkovModelsImplem, - HMMBaseImplem, hmmlearnImplem, pomegranateImplem, dynamaxImplem -export build_model, build_benchmarkables +export HMMBaseImplem, hmmlearnImplem, pomegranateImplem, dynamaxImplem +export print_python_setup, compare_loglikelihoods include("hmmbase.jl") include("hmmlearn.jl") include("pomegranate.jl") include("dynamax.jl") include("setup.jl") +include("correctness.jl") end # module HMMComparison diff --git a/libs/HMMComparison/src/correctness.jl b/libs/HMMComparison/src/correctness.jl new file mode 100644 index 00000000..69132582 --- /dev/null +++ b/libs/HMMComparison/src/correctness.jl @@ -0,0 +1,78 @@ +function compare_loglikelihoods( + instance::Instance, params::Params, data::AbstractArray{<:Real,3} +) + torch = pyimport("torch") + jax = pyimport("jax") + jnp = pyimport("jax.numpy") + + torch.set_default_dtype(torch.float64) + jax.config.update("jax_enable_x64", true) + + (; obs_dim, seq_length, nb_seqs) = instance + + results = Dict{String,Any}() + + ## Data formats + + if obs_dim == 1 + obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs] + else + obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs] + end + obs_seq = reduce(vcat, obs_seqs) + control_seq = fill(nothing, length(obs_seq)) + seq_ends = cumsum(length.(obs_seqs)) + + if obs_dim == 1 + obs_mat = reduce(vcat, data[k, :, 1] for k in 1:nb_seqs) + else + obs_mat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs) + end + + obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs) + obs_mat_concat_py = Py(obs_mat_concat).to_numpy() + obs_mat_len_py = Py(fill(seq_length, nb_seqs)).to_numpy() + + obs_tens_torch_py = torch.tensor(Py(data).to_numpy()) + obs_tens_jax_py = jnp.array(Py(data).to_numpy()) + + # HiddenMarkovModels.jl + + implem1 = HiddenMarkovModelsImplem() + hmm1 = build_model(implem1, instance, params) + _, logLs1 = HiddenMarkovModels.forward(hmm1, obs_seq, control_seq; seq_ends) + results[string(implem1)] = logLs1 + + ## HMMBase.jl + + implem2 = HMMBaseImplem() + hmm2 = build_model(implem2, instance, params) + _, logL2 = HMMBase.forward(hmm2, obs_mat) + results[string(implem2)] = logL2 + + ## hmmlearn + + implem3 = hmmlearnImplem() + hmm3 = build_model(implem3, instance, params) + logL3 = hmm3.score(obs_mat_concat_py, obs_mat_len_py) + results[string(implem3)] = pyconvert(Number, logL3) + + ## pomegranate + + implem4 = pomegranateImplem() + hmm4 = build_model(implem4, instance, params) + logαs4 = 
diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl
index 946811bf..4a845911 100644
--- a/libs/HMMComparison/src/dynamax.jl
+++ b/libs/HMMComparison/src/dynamax.jl
@@ -1,14 +1,14 @@
 struct dynamaxImplem <: Implementation end
+Base.string(::dynamaxImplem) = "dynamax"
 
-function HMMBenchmark.build_model(
-    rng::AbstractRNG, implem::dynamaxImplem; instance::Instance
-)
-    np = pyimport("numpy")
+function HMMBenchmark.build_model(implem::dynamaxImplem, instance::Instance, params::Params)
+    jax = pyimport("jax")
     jnp = pyimport("jax.numpy")
     dynamax_hmm = pyimport("dynamax.hidden_markov_model")
+    jax.config.update("jax_enable_x64", true)
 
     (; nb_states, obs_dim) = instance
-    (; init, trans, means, stds) = build_params(rng; instance)
+    (; init, trans, means, stds) = params
 
     initial_probs = jnp.array(Py(init).to_numpy())
     transition_matrix = jnp.array(Py(trans).to_numpy())
@@ -16,42 +16,38 @@
     emission_scale_diags = jnp.array(Py(transpose(stds)).to_numpy())
 
     hmm = dynamax_hmm.DiagonalGaussianHMM(nb_states, obs_dim)
-    params, props = hmm.initialize(;
+    dyn_params, dyn_props = hmm.initialize(;
         initial_probs=initial_probs,
         transition_matrix=transition_matrix,
         emission_means=emission_means,
         emission_scale_diags=emission_scale_diags,
     )
-    return hmm, params, props
+    return hmm, dyn_params, dyn_props
 end
 
 function HMMBenchmark.build_benchmarkables(
-    rng::AbstractRNG, implem::dynamaxImplem; instance::Instance, algos::Vector{String}
+    implem::dynamaxImplem,
+    instance::Instance,
+    params::Params,
+    data::AbstractArray{<:Real,3},
+    algos::Vector{String},
 )
-    np = pyimport("numpy")
     jax = pyimport("jax")
     jnp = pyimport("jax.numpy")
-    (; obs_dim, seq_length, nb_seqs, bw_iter) = instance
+    jax.config.update("jax_enable_x64", true)
 
-    hmm, params, _ = build_model(rng, implem; instance)
-    data = randn(rng, nb_seqs, seq_length, obs_dim)
+    (; bw_iter) = instance
+    hmm, dyn_params, _ = build_model(implem, instance, params)
 
-    obs_tens_py = jnp.array(Py(data).to_numpy())
+    obs_tens_jax_py = jnp.array(Py(data).to_numpy())
 
     benchs = Dict()
 
-    if "logdensity" in algos
-        filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
-        benchs["logdensity"] = @benchmarkable begin
-            $(filter_vmap)($params, $obs_tens_py)
-        end evals = 1 samples = 100
-    end
-
     if "forward" in algos
         filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
         benchs["forward"] = @benchmarkable begin
-            $(filter_vmap)($params, $obs_tens_py)
+            $(filter_vmap)($dyn_params, $obs_tens_jax_py)
         end evals = 1 samples = 100
     end
 
@@ -60,7 +56,7 @@
             jax.vmap(hmm.most_likely_states; in_axes=pylist((pybuiltins.None, 0)))
         )
         benchs["viterbi"] = @benchmarkable begin
-            $(most_likely_states_vmap)($params, $obs_tens_py)
+            $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py)
         end evals = 1 samples = 100
     end
 
@@ -69,20 +65,24 @@
            jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0)))
        )
benchs["forward_backward"] = @benchmarkable begin - $(smoother_vmap)($params, $obs_tens_py) + $(smoother_vmap)($dyn_params, $obs_tens_jax_py) end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit_em( - params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter, verbose=false + dyn_params_guess, + dyn_props_guess, + $obs_tens_jax_py; + num_iters=$bw_iter, + verbose=false, ) end evals = 1 samples = 100 setup = ( - tup = build_model($rng, $implem; instance=$instance); + tup = build_model($implem, $instance, $params); hmm_guess = tup[1]; - params_guess = tup[2]; - props_guess = tup[3] + dyn_params_guess = tup[2]; + dyn_props_guess = tup[3] ) end diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl index d41582e0..12b629a7 100644 --- a/libs/HMMComparison/src/hmmbase.jl +++ b/libs/HMMComparison/src/hmmbase.jl @@ -1,10 +1,9 @@ struct HMMBaseImplem <: Implementation end +Base.string(::HMMBaseImplem) = "HMMBase.jl" -function HMMBenchmark.build_model( - rng::AbstractRNG, implem::HMMBaseImplem; instance::Instance -) +function HMMBenchmark.build_model(implem::HMMBaseImplem, instance::Instance, params::Params) (; nb_states, obs_dim) = instance - (; init, trans, means, stds) = build_params(rng; instance) + (; init, trans, means, stds) = params a = init A = trans @@ -19,12 +18,14 @@ function HMMBenchmark.build_model( end function HMMBenchmark.build_benchmarkables( - rng::AbstractRNG, implem::HMMBaseImplem; instance::Instance, algos::Vector{String} + implem::HMMBaseImplem, + instance::Instance, + params::Params, + data::AbstractArray{<:Real,3}, + algos::Vector{String}, ) (; obs_dim, seq_length, nb_seqs, bw_iter) = instance - - hmm = build_model(rng, implem; instance) - data = randn(rng, nb_seqs, seq_length, obs_dim) + hmm = build_model(implem, instance, params) if obs_dim == 1 obs_mat = reduce(vcat, data[k, :, 1] for k in 1:nb_seqs) @@ -34,12 +35,6 @@ function HMMBenchmark.build_benchmarkables( benchs = BenchmarkGroup() - if "logdensity" in algos - benchs["logdensity"] = @benchmarkable begin - HMMBase.forward($hmm, $obs_mat) - end evals = 1 samples = 100 - end - if "forward" in algos benchs["forward"] = @benchmarkable begin HMMBase.forward($hmm, $obs_mat) diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl index 1b3aeecf..610c9129 100644 --- a/libs/HMMComparison/src/hmmlearn.jl +++ b/libs/HMMComparison/src/hmmlearn.jl @@ -1,13 +1,14 @@ struct hmmlearnImplem <: Implementation end +Base.string(::hmmlearnImplem) = "hmmlearn" function HMMBenchmark.build_model( - rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance + implem::hmmlearnImplem, instance::Instance, params::Params ) np = pyimport("numpy") hmmlearn_hmm = pyimport("hmmlearn.hmm") (; bw_iter, nb_states) = instance - (; init, trans, means, stds) = build_params(rng; instance) + (; init, trans, means, stds) = params hmm = hmmlearn_hmm.GaussianHMM(; n_components=nb_states, @@ -26,26 +27,21 @@ function HMMBenchmark.build_model( end function HMMBenchmark.build_benchmarkables( - rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance, algos::Vector{String} + implem::hmmlearnImplem, + instance::Instance, + params::Params, + data::AbstractArray{<:Real,3}, + algos::Vector{String}, ) - np = pyimport("numpy") (; obs_dim, seq_length, nb_seqs) = instance - - hmm = build_model(rng, implem; instance) - data = randn(rng, nb_seqs, seq_length, obs_dim) + hmm = build_model(implem, instance, params) obs_mat_concat = reduce(vcat, data[k, 
diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl
index 1b3aeecf..610c9129 100644
--- a/libs/HMMComparison/src/hmmlearn.jl
+++ b/libs/HMMComparison/src/hmmlearn.jl
@@ -1,13 +1,14 @@
 struct hmmlearnImplem <: Implementation end
+Base.string(::hmmlearnImplem) = "hmmlearn"
 
 function HMMBenchmark.build_model(
-    rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance
+    implem::hmmlearnImplem, instance::Instance, params::Params
 )
     np = pyimport("numpy")
     hmmlearn_hmm = pyimport("hmmlearn.hmm")
 
     (; bw_iter, nb_states) = instance
-    (; init, trans, means, stds) = build_params(rng; instance)
+    (; init, trans, means, stds) = params
 
     hmm = hmmlearn_hmm.GaussianHMM(;
         n_components=nb_states,
@@ -26,26 +27,21 @@
 end
 
 function HMMBenchmark.build_benchmarkables(
-    rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance, algos::Vector{String}
+    implem::hmmlearnImplem,
+    instance::Instance,
+    params::Params,
+    data::AbstractArray{<:Real,3},
+    algos::Vector{String},
 )
-    np = pyimport("numpy")
     (; obs_dim, seq_length, nb_seqs) = instance
-
-    hmm = build_model(rng, implem; instance)
-    data = randn(rng, nb_seqs, seq_length, obs_dim)
+    hmm = build_model(implem, instance, params)
 
     obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
     obs_mat_concat_py = Py(obs_mat_concat).to_numpy()
-    obs_mat_len_py = np.full(nb_seqs, seq_length)
+    obs_mat_len_py = Py(fill(seq_length, nb_seqs)).to_numpy()
 
     benchs = Dict()
 
-    if "logdensity" in algos
-        benchs["logdensity"] = @benchmarkable begin
-            $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
-        end evals = 1 samples = 100
-    end
-
     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
             $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
@@ -68,7 +64,7 @@
         benchs["baum_welch"] = @benchmarkable begin
             hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py)
         end evals = 1 samples = 100 setup = (
-            hmm_guess = build_model($rng, $implem; instance=$instance)
+            hmm_guess = build_model($implem, $instance, $params)
         )
     end
diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl
index 6f984746..d5939919 100644
--- a/libs/HMMComparison/src/pomegranate.jl
+++ b/libs/HMMComparison/src/pomegranate.jl
@@ -1,15 +1,16 @@
 struct pomegranateImplem <: Implementation end
+Base.string(::pomegranateImplem) = "pomegranate"
 
 function HMMBenchmark.build_model(
-    rng::AbstractRNG, implem::pomegranateImplem; instance::Instance
+    implem::pomegranateImplem, instance::Instance, params::Params
 )
     torch = pyimport("torch")
-    torch.set_default_dtype(torch.float64)
     pomegranate_distributions = pyimport("pomegranate.distributions")
     pomegranate_hmm = pyimport("pomegranate.hmm")
+    torch.set_default_dtype(torch.float64)
 
     (; nb_states, bw_iter) = instance
-    (; init, trans, means, stds) = build_params(rng; instance)
+    (; init, trans, means, stds) = params
 
     starts = torch.tensor(Py(init).to_numpy())
     ends = torch.ones(nb_states) * 1e-10
@@ -37,43 +38,39 @@
 end
 
 function HMMBenchmark.build_benchmarkables(
-    rng::AbstractRNG, implem::pomegranateImplem; instance::Instance, algos::Vector{String}
+    implem::pomegranateImplem,
+    instance::Instance,
+    params::Params,
+    data::AbstractArray{<:Real,3},
+    algos::Vector{String},
 )
-    np = pyimport("numpy")
     torch = pyimport("torch")
     torch.set_default_dtype(torch.float64)
 
-    (; obs_dim, seq_length, nb_seqs) = instance
-    hmm = build_model(rng, implem; instance)
-    data = randn(rng, nb_seqs, seq_length, obs_dim)
+    (; obs_dim, seq_length, nb_seqs) = instance
+    hmm = build_model(implem, instance, params)
 
-    obs_tens_py = torch.tensor(Py(data).to_numpy())
+    obs_tens_torch_py = torch.tensor(Py(data).to_numpy())
 
     benchs = Dict()
 
-    if "logdensity" in algos
-        benchs["logdensity"] = @benchmarkable begin
-            $(hmm.forward)($obs_tens_py)
-        end evals = 1 samples = 100
-    end
-
     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
-            $(hmm.forward)($obs_tens_py)
+            $(hmm.forward)($obs_tens_torch_py)
         end evals = 1 samples = 100
     end
 
     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
-            $(hmm.forward_backward)($obs_tens_py)
+            $(hmm.forward_backward)($obs_tens_torch_py)
         end evals = 1 samples = 100
     end
 
     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
-            hmm_guess.fit($obs_tens_py)
+            hmm_guess.fit($obs_tens_torch_py)
         end evals = 1 samples = 100 setup = (
-            hmm_guess = build_model($rng, $implem; instance=$instance)
+            hmm_guess = build_model($implem, $instance, $params)
         )
     end
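A pattern worth noting in the hunks above: every Python backend now receives its observations through the same `Py(x).to_numpy()` conversion (which also replaced the `np.full` call in hmmlearn), and torch is pinned to `float64` before any model is built, since its default is `float32` and a mixed-precision comparison against the Float64 Julia backends would be skewed. A minimal sketch of the conversion chain, with made-up array sizes:

```julia
using PythonCall

torch = pyimport("torch")
torch.set_default_dtype(torch.float64)  # torch otherwise defaults to float32

data = randn(2, 5, 3)
obs_np = Py(data).to_numpy()              # convert the Julia array to NumPy
obs_tens_torch_py = torch.tensor(obs_np)  # float64 torch tensor
obs_len_py = Py(fill(5, 2)).to_numpy()    # replaces np.full(nb_seqs, seq_length)
```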
= "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/libs/HMMComparison/test/comparison.jl b/libs/HMMComparison/test/comparison.jl deleted file mode 100644 index e7f63fe3..00000000 --- a/libs/HMMComparison/test/comparison.jl +++ /dev/null @@ -1,28 +0,0 @@ -using BenchmarkTools -using HMMComparison -using HMMBenchmark -using LinearAlgebra -using StableRNGs - -BLAS.set_num_threads(1) - -rng = StableRNG(63) - -implems = [ - HiddenMarkovModelsImplem(), # - HMMBaseImplem(), # - hmmlearnImplem(), # - pomegranateImplem(), # - dynamaxImplem(), # -] -algos = ["logdensity", "forward", "viterbi", "forward_backward", "baum_welch"] -instances = [ - Instance(; - sparse=false, nb_states=5, obs_dim=10, seq_length=1000, nb_seqs=10, bw_iter=10 - ), -] - -SUITE = define_suite(rng, implems; instances, algos) - -results = BenchmarkTools.run(SUITE; verbose=true) -data = parse_results(results) diff --git a/libs/HMMComparison/test/performance.jl b/libs/HMMComparison/test/performance.jl new file mode 100644 index 00000000..cd6643cf --- /dev/null +++ b/libs/HMMComparison/test/performance.jl @@ -0,0 +1,40 @@ +using BenchmarkTools +using HMMComparison +using LinearAlgebra +using StableRNGs + +BLAS.set_num_threads(1) + +rng = StableRNG(63) + +implems = [ + HiddenMarkovModelsImplem(), # + HMMBaseImplem(), # + hmmlearnImplem(), # + pomegranateImplem(), # + dynamaxImplem(), # +] + +algos = ["forward", "viterbi", "forward_backward", "baum_welch"] + +instances = Instance[] + +for nb_states in 2:3:24 + push!( + instances, + Instance(; + custom_dist=true, + sparse=false, + nb_states=nb_states, + obs_dim=5, + seq_length=100, + nb_seqs=10, + bw_iter=10, + ), + ) +end + +SUITE = define_suite(rng, implems; instances, algos) + +results = BenchmarkTools.run(SUITE; verbose=true) +data = parse_results(results; path=joinpath(@__DIR__, "results.csv")) diff --git a/libs/HMMComparison/test/plots.jl b/libs/HMMComparison/test/plots.jl index a218e4c8..58371848 100644 --- a/libs/HMMComparison/test/plots.jl +++ b/libs/HMMComparison/test/plots.jl @@ -1,11 +1,50 @@ using DataFrames using Plots +using HMMComparison -include("comparison.jl") +data = read_results(joinpath(@__DIR__, "results.csv")) -data +sort!(data, [:algo, :implem, :nb_states]) -algo = "forward" -metric = :time_minimum -data_algo = data[data[!, :algo] .== algo, :] -bar(data_algo[!, :implem], data_algo[!, metric]; title=algo, label=string(metric)) +implems = [ + "HiddenMarkovModels.jl", # + "HMMBase.jl", # + "hmmlearn", # + "pomegranate", # + "dynamax", # +] +algos = ["forward", "baum_welch"] + +markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle] + +for algo in algos + pl = plot(; + title=algo, + size=(800, 400), + yscale=:log, + xlabel="nb states", + ylabel="runtime (s)", + legend=:outerright, + margin=5Plots.mm, + ) + for (i, implem) in enumerate(implems) + subdata = data[(data[!, :algo] .== algo) .& (data[!, :implem] .== implem), :] + plot!( + pl, + subdata[!, :nb_states], + subdata[!, :time_median] ./ 1e9; + yerror=( + (subdata[!, :time_median] .- subdata[!, :time_quantile25]) ./ 1e9, + (subdata[!, :time_quantile75] .- subdata[!, :time_median]) ./ 1e9, + ), + label=implem, + markershape=markershapes[i], + markerstrokecolor=:auto, + markersize=5, + linestyle=:auto, + linewidth=2, + ) + end + display(pl) + savefig(pl, joinpath(@__DIR__, "$(algo).png")) +end diff --git a/libs/HMMComparison/test/runtests.jl b/libs/HMMComparison/test/runtests.jl new file mode 100644 index 00000000..6799f710 --- /dev/null +++ b/libs/HMMComparison/test/runtests.jl @@ -0,0 
+using HMMComparison
+using Random
+using Test
+
+rng = Random.default_rng()
+
+@testset "HMMComparison" begin
+    instance = Instance(;
+        custom_dist=false, sparse=false, nb_states=5, obs_dim=10, seq_length=25, nb_seqs=10
+    )
+    params = build_params(rng, instance)
+    data = build_data(rng, instance)
+    logLs = compare_loglikelihoods(instance, params, data)
+    for (key, val) in pairs(logLs)
+        if key != "HMMBase.jl"
+            if val isa AbstractVector
+                @test all(val .≈ logLs["HiddenMarkovModels.jl"])
+            elseif val isa Number
+                @test val ≈ sum(logLs["HiddenMarkovModels.jl"])
+            end
+        end
+    end
+end
diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl
index f0646e24..f95004dd 100644
--- a/libs/HMMTest/src/coherence.jl
+++ b/libs/HMMTest/src/coherence.jl
@@ -79,14 +79,14 @@
     logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends)
 
     q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends)
-    @test logL_viterbi > logL_joint
-    @test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends)
+    @test sum(logL_viterbi) > logL_joint
+    @test sum(logL_viterbi) ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends)
 
     α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends)
-    @test logL_forward ≈ logL
+    @test sum(logL_forward) ≈ logL
 
     γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends)
-    @test logL_forward_backward ≈ logL
+    @test sum(logL_forward_backward) ≈ logL
 
     @test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends))
 
     if !isnothing(hmm_guess)
diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl
index 76ba0731..687b72fb 100644
--- a/libs/HMMTest/src/hmmbase.jl
+++ b/libs/HMMTest/src/hmmbase.jl
@@ -17,12 +17,12 @@
 
     logL_base = HMMBase.forward(hmm_base, obs_mat)[2]
     logL = logdensityof(hmm, obs_seq; seq_ends)
-    @test logL ≈ 2logL_base
+    @test sum(logL) ≈ 2logL_base
 
     α_base, logL_forward_base = HMMBase.forward(hmm_base, obs_mat)
     α, logL_forward = forward(hmm, obs_seq; seq_ends)
     @test isapprox(α[:, 1:T], α_base') && isapprox(α[:, (T + 1):(2T)], α_base')
-    @test logL_forward ≈ 2logL_forward_base
+    @test sum(logL_forward) ≈ 2logL_forward_base
 
     q_base = HMMBase.viterbi(hmm_base, obs_mat)
     q, logL_viterbi = viterbi(hmm, obs_seq; seq_ends)
diff --git a/src/inference/forward.jl b/src/inference/forward.jl
index 82522bfb..0a49d745 100644
--- a/src/inference/forward.jl
+++ b/src/inference/forward.jl
@@ -107,7 +107,7 @@ $(SIGNATURES)
 
 Apply the forward algorithm to infer the current state after sequence `obs_seq` for `hmm`.
 
-Return a tuple `(storage.α, sum(storage.logL))` where `storage` is of type [`ForwardStorage`](@ref).
+Return a tuple `(storage.α, storage.logL)` where `storage` is of type [`ForwardStorage`](@ref).
 """
 function forward(
     hmm::AbstractHMM,
@@ -117,5 +117,5 @@
 )
     storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
     forward!(storage, hmm, obs_seq, control_seq; seq_ends)
-    return storage.α, sum(storage.logL)
+    return storage.α, storage.logL
 end
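As the `forward` wrapper above shows, the exported function is a thin layer over a preallocate-then-mutate pair, which is also how the Baum-Welch benchmark reuses `fb_storage` in its setup. A sketch of the same pattern at the call site, assuming `hmm`, `obs_seq`, `control_seq`, and `seq_ends` as in the surrounding tests (`initialize_forward` and `forward!` are the internal API from the hunk above):

```julia
storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
for _ in 1:10
    # Reuses the same α matrix and logL vector on every pass, no reallocation.
    forward!(storage, hmm, obs_seq, control_seq; seq_ends)
end
storage.α     # filtered state marginals (states × time steps)
storage.logL  # one loglikelihood per sequence, as returned by `forward` after this diff
```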
diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl
index a01ebc82..eb6ec9d3 100644
--- a/src/inference/forward_backward.jl
+++ b/src/inference/forward_backward.jl
@@ -123,7 +123,7 @@ $(SIGNATURES)
 
 Apply the forward-backward algorithm to infer the posterior state and transition marginals during sequence `obs_seq` for `hmm`.
 
-Return a tuple `(storage.γ, sum(storage.logL))` where `storage` is of type [`ForwardBackwardStorage`](@ref).
+Return a tuple `(storage.γ, storage.logL)` where `storage` is of type [`ForwardBackwardStorage`](@ref).
 """
 function forward_backward(
     hmm::AbstractHMM,
@@ -136,5 +136,5 @@
         hmm, obs_seq, control_seq; seq_ends, transition_marginals
     )
     forward_backward!(storage, hmm, obs_seq, control_seq; seq_ends, transition_marginals)
-    return storage.γ, sum(storage.logL)
+    return storage.γ, storage.logL
 end
diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl
index 4d82e152..f43fb25c 100644
--- a/src/inference/logdensity.jl
+++ b/src/inference/logdensity.jl
@@ -10,7 +10,7 @@ function DensityInterface.logdensityof(
     seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
 )
     _, logL = forward(hmm, obs_seq, control_seq; seq_ends)
-    return logL
+    return sum(logL)
 end
 
"""
diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl
index 2afc983f..4948261b 100644
--- a/src/inference/viterbi.jl
+++ b/src/inference/viterbi.jl
@@ -104,7 +104,7 @@ $(SIGNATURES)
 
 Apply the Viterbi algorithm to infer the most likely state sequence corresponding to `obs_seq` for `hmm`.
 
-Return a tuple `(storage.q, sum(storage.logL))` where `storage` is of type [`ViterbiStorage`](@ref).
+Return a tuple `(storage.q, storage.logL)` where `storage` is of type [`ViterbiStorage`](@ref).
 """
 function viterbi(
     hmm::AbstractHMM,
@@ -114,5 +114,5 @@
 )
     storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
     viterbi!(storage, hmm, obs_seq, control_seq; seq_ends)
-    return storage.q, sum(storage.logL)
+    return storage.q, storage.logL
 end
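Taken together, the three wrappers above now share a single contract: a tuple of the per-state output and a vector of per-sequence loglikelihoods, with `logdensityof` performing the reduction that the wrappers used to do internally. A closing sketch of the invariants the HMMTest changes earlier in this diff rely on, assuming `hmm`, `obs_seq`, `control_seq`, and `seq_ends` as in those tests:

```julia
q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends)
α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends)
γ, logL_fb = forward_backward(hmm, obs_seq, control_seq; seq_ends)

# Each logL vector has one entry per sequence; scalar consumers reduce explicitly:
sum(logL_forward) ≈ logdensityof(hmm, obs_seq, control_seq; seq_ends)
sum(logL_forward) ≈ sum(logL_fb)
```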