Correct benchmarks (#84)
* Correct benchmarks

* Plots
gdalle authored Feb 3, 2024
1 parent a86c3cd commit 0b7f2ac
Showing 26 changed files with 324 additions and 192 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -21,4 +21,7 @@ scratchpad.jl
/benchmark/*.json

/docs/src/index.md
/docs/src/examples/*.md
/docs/src/examples/*.md

*.pdf
*.png
2 changes: 1 addition & 1 deletion benchmark/benchmarks.jl
@@ -5,7 +5,7 @@ using StableRNGs

rng = StableRNG(63)

algos = ["rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch"]
algos = ["forward", "viterbi", "forward_backward", "baum_welch"]
instances = [
# compare state numbers
Instance(; nb_states=4, obs_dim=1),
8 changes: 4 additions & 4 deletions examples/basics.jl
@@ -54,7 +54,7 @@ In practical applications, the state sequence is not known, which is why we need
# ## Inference

#=
The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\log \mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$.
The Viterbi algorithm ([`viterbi`](@ref)) returns the most likely state sequence $\hat{X}_{1:T} = \underset{X_{1:T}}{\mathrm{argmax}}~\mathbb{P}(X_{1:T} \vert Y_{1:T})$, along with the joint loglikelihood $\log \mathbb{P}(\hat{X}_{1:T}, Y_{1:T})$ (in a vector of size 1).
=#

best_state_seq, best_joint_loglikelihood = viterbi(hmm, obs_seq);
@@ -66,7 +66,7 @@ As we can see, it is very close to the true state sequence, but not necessarily
vcat(state_seq', best_state_seq')

#=
The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
The forward algorithm ([`forward`](@ref)) returns a matrix of filtered state marginals $\alpha[i, t] = \mathbb{P}(X_t = i | Y_{1:t})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1).
=#

filtered_state_marginals, obs_seq_loglikelihood1 = forward(hmm, obs_seq);
@@ -79,7 +79,7 @@ This is particularly useful to infer the marginal distribution of the last state
filtered_state_marginals[:, end]

#=
Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence.
Conversely, the forward-backward algorithm ([`forward_backward`](@ref)) returns a matrix of smoothed state marginals $\gamma[i, t] = \mathbb{P}(X_t = i | Y_{1:T})$, along with the loglikelihood $\log \mathbb{P}(Y_{1:T})$ of the observation sequence (in a vector of size 1).
=#

smoothed_state_marginals, obs_seq_loglikelihood2 = forward_backward(hmm, obs_seq);
@@ -179,7 +179,7 @@ long_obs_seq_concat = reduce(vcat, long_obs_seqs)
seq_ends = cumsum(length.(long_obs_seqs))

#=
The outputs of inference algorithms are then concatenated, and the associated loglikelihoods are summed over all sequences.
The outputs of inference algorithms are then concatenated, and the associated loglikelihoods are split by sequence (in a vector of size `length(seq_ends)`).
=#

best_state_seq_concat, _ = viterbi(hmm, long_obs_seq_concat; seq_ends);
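To make the new convention concrete, here is a minimal sketch of how the per-sequence loglikelihoods and the concatenated outputs line up with `seq_ends` (assuming the `hmm`, `long_obs_seq_concat`, and `seq_ends` defined above; `split_by_sequence` is a hypothetical helper, not part of the package):

```julia
# one loglikelihood per sequence, in a vector of size length(seq_ends)
best_state_seq_concat, best_logLs = viterbi(hmm, long_obs_seq_concat; seq_ends)
@assert length(best_logLs) == length(seq_ends)

# hypothetical helper: recover each sequence's slice of the concatenated output
function split_by_sequence(x::AbstractVector, seq_ends::AbstractVector{Int})
    starts = [1; seq_ends[1:(end - 1)] .+ 1]
    return [x[starts[k]:seq_ends[k]] for k in eachindex(seq_ends)]
end

best_state_seqs = split_by_sequence(best_state_seq_concat, seq_ends)
```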
8 changes: 4 additions & 4 deletions libs/HMMBenchmark/src/HMMBenchmark.jl
@@ -21,13 +21,13 @@ using LinearAlgebra: BLAS, Diagonal, SymTridiagonal
using Pkg: Pkg
using Random: AbstractRNG
using SparseArrays: spdiagm
using Statistics: mean, median, std
using Statistics: mean, median, std, quantile

export AbstractImplementation, Instance
export define_suite, parse_results, print_julia_setup
export Implementation, Instance, Params, HiddenMarkovModelsImplem
export build_params, build_data, build_model, build_benchmarkables
export define_suite, parse_results, read_results, print_julia_setup

abstract type Implementation end
Base.string(implem::Implementation) = string(typeof(implem))[begin:(end - length("Implem"))]

include("instance.jl")
include("params.jl")
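For orientation, a sketch of how the refactored API fits together — parameters and data are now built once and passed explicitly, so every implementation is benchmarked on the same model and the same data (the `Instance` keywords are illustrative; the call order mirrors `suite.jl` below):

```julia
using StableRNGs: StableRNG

rng = StableRNG(63)
instance = Instance(; nb_states=4, obs_dim=1)

# build once, share across implementations
params = build_params(rng, instance)
data = build_data(rng, instance)

# each implementation turns the shared params into its own model type
implem = HiddenMarkovModelsImplem()
hmm = build_model(implem, instance, params)

# benchmarkables no longer draw their own rng-dependent model or data
benchs = build_benchmarkables(implem, instance, params, data, ["forward", "viterbi"])
```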
30 changes: 9 additions & 21 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
@@ -1,8 +1,9 @@
struct HiddenMarkovModelsImplem <: Implementation end
Base.string(::HiddenMarkovModelsImplem) = "HiddenMarkovModels.jl"

function build_model(rng::AbstractRNG, ::HiddenMarkovModelsImplem; instance::Instance)
function build_model(::HiddenMarkovModelsImplem, instance::Instance, params::Params)
(; custom_dist, nb_states, obs_dim) = instance
(; init, trans, means, stds) = build_params(rng; instance)
(; init, trans, means, stds) = params

if custom_dist
dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
@@ -19,15 +20,14 @@ function build_model(rng::AbstractRNG, ::HiddenMarkovModelsImplem; instance::Ins
end

function build_benchmarkables(
rng::AbstractRNG,
implem::HiddenMarkovModelsImplem;
implem::HiddenMarkovModelsImplem,
instance::Instance,
algos::Vector{String},
params::Params,
data::AbstractArray{<:Real,3},
algos::Vector{String};
)
(; obs_dim, seq_length, nb_seqs, bw_iter) = instance

hmm = build_model(rng, implem; instance)
data = randn(rng, nb_seqs, seq_length, obs_dim)
hmm = build_model(implem, instance, params)

if obs_dim == 1
obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
@@ -40,18 +40,6 @@

benchs = Dict()

if "rand" in algos
benchs["rand"] = @benchmarkable begin
[rand($rng, $hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
end evals = 1 samples = 100
end

if "logdensity" in algos
benchs["logdensity"] = @benchmarkable begin
logdensityof($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end

if "forward" in algos
benchs["forward"] = @benchmarkable begin
forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
@@ -112,7 +100,7 @@ function build_benchmarkables(
loglikelihood_increasing=false,
)
end evals = 1 samples = 100 setup = (
hmm_guess = build_model($rng, $implem; instance=$instance);
hmm_guess = build_model($implem, $instance, $params);
fb_storage = initialize_forward_backward(
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
);
5 changes: 5 additions & 0 deletions libs/HMMBenchmark/src/instance.jl
@@ -16,3 +16,8 @@ function Instance(s::String)
vals = parse.(Int, split(s, " ")[2:2:end])
return Instance(vals...)
end

function build_data(rng::AbstractRNG, instance::Instance)
(; nb_seqs, seq_length, obs_dim) = instance
return randn(rng, nb_seqs, seq_length, obs_dim)
end
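A quick sanity check of the shape contract this helper establishes (a sketch; the extra `Instance` keywords are assumptions taken from the fields destructured elsewhere in this diff):

```julia
using StableRNGs: StableRNG

instance = Instance(; nb_states=4, obs_dim=3, seq_length=100, nb_seqs=10)
data = build_data(StableRNG(63), instance)
size(data)  # (10, 100, 3), i.e. (nb_seqs, seq_length, obs_dim)
```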
34 changes: 9 additions & 25 deletions libs/HMMBenchmark/src/params.jl
@@ -1,11 +1,13 @@
function build_initialization(rng::AbstractRNG; instance::Instance)
(; nb_states) = instance
init = ones(nb_states) / nb_states
return init
Base.@kwdef struct Params{T,M<:AbstractMatrix{T}}
init::Vector{T}
trans::M
means::Matrix{T}
stds::Matrix{T}
end

function build_transition_matrix(rng::AbstractRNG; instance::Instance)
(; sparse, nb_states) = instance
function build_params(rng::AbstractRNG, instance::Instance)
(; sparse, nb_states, obs_dim) = instance
init = ones(nb_states) / nb_states
if sparse
trans = spdiagm(
0 => rand(rng, nb_states) / 2,
@@ -18,25 +20,7 @@ function build_transition_matrix(rng::AbstractRNG; instance::Instance)
for row in eachrow(trans)
row ./= sum(row)
end
return trans
end

function build_means(rng::AbstractRNG; instance::Instance)
(; obs_dim, nb_states) = instance
means = randn(rng, obs_dim, nb_states)
return means
end

function build_stds(rng::AbstractRNG; instance::Instance)
(; obs_dim, nb_states) = instance
stds = ones(obs_dim, nb_states)
return stds
end

function build_params(rng::AbstractRNG; instance::Instance)
init = build_initialization(rng; instance)
trans = build_transition_matrix(rng; instance)
means = build_means(rng; instance)
stds = build_stds(rng; instance)
return (; init, trans, means, stds)
return Params(; init, trans, means, stds)
end
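As an illustration, building a dense `Params` by hand (a sketch with arbitrary values; `build_params` does the same thing, plus the sparse branch):

```julia
nb_states, obs_dim = 4, 1

init = ones(nb_states) / nb_states
trans = rand(nb_states, nb_states)
foreach(row -> row ./= sum(row), eachrow(trans))  # normalize rows to sum to 1
means = randn(obs_dim, nb_states)
stds = ones(obs_dim, nb_states)

# typed container replacing the previous NamedTuple return value
params = Params(; init, trans, means, stds)
```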
15 changes: 13 additions & 2 deletions libs/HMMBenchmark/src/suite.jl
@@ -11,7 +11,9 @@ function define_suite(
SUITE = BenchmarkGroup()
for implem in implems
for instance in instances
bench_tup = build_benchmarkables(rng, implem; instance, algos)
params = build_params(rng, instance)
data = build_data(rng, instance)
bench_tup = build_benchmarkables(implem, instance, params, data, algos)
for (algo, bench) in pairs(bench_tup)
SUITE[string(implem)][string(instance)][algo] = bench
end
@@ -20,8 +22,13 @@
return SUITE
end

quantile75(x) = quantile(x, 0.75)
quantile25(x) = quantile(x, 0.25)

function parse_results(
results; path=nothing, aggregators=[minimum, median, maximum, mean, std]
results;
path=nothing,
aggregators=[minimum, median, maximum, mean, std, quantile25, quantile75],
)
data = DataFrame()
for implem_str in identity.(keys(results))
@@ -48,3 +55,7 @@
end
return data
end

function read_results(path)
return CSV.read(path, DataFrame)
end
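A hedged sketch of the intended round trip — define, run, persist, reload (the keyword signature of `define_suite` is inferred from this diff rather than shown in full, and the file name is illustrative):

```julia
using BenchmarkTools
using StableRNGs: StableRNG

rng = StableRNG(63)
instances = [Instance(; nb_states=4, obs_dim=1)]
algos = ["forward", "viterbi", "forward_backward", "baum_welch"]

SUITE = define_suite(rng; implems=[HiddenMarkovModelsImplem()], instances, algos)
results = run(SUITE; verbose=true)

# aggregates now include quantile25 and quantile75 alongside mean/std
df = parse_results(results; path="results.csv")
df_later = read_results("results.csv")  # reload the CSV later for plotting
```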
2 changes: 2 additions & 0 deletions libs/HMMComparison/Project.toml
@@ -12,8 +12,10 @@ HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"
HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
20 changes: 8 additions & 12 deletions libs/HMMComparison/src/HMMComparison.jl
@@ -3,28 +3,24 @@ module HMMComparison
using BenchmarkTools: BenchmarkGroup, @benchmarkable
using CondaPkg: CondaPkg
using Distributions: Normal, MvNormal
using HiddenMarkovModels: HiddenMarkovModels
using HMMBase: HMMBase
using HMMBenchmark:
HMMBenchmark,
Instance,
Implementation,
HiddenMarkovModelsImplem,
build_params,
build_model,
build_benchmarkables
using LinearAlgebra: Diagonal
using PythonCall: Py, pyimport, pybuiltins, pylist
using LogExpFunctions: logsumexp
using PythonCall: Py, PyArray, pyimport, pyconvert, pybuiltins, pylist
using Random: AbstractRNG
using Reexport: @reexport
using SparseArrays: spdiagm
@reexport using HMMBenchmark

export HiddenMarkovModelsImplem,
HMMBaseImplem, hmmlearnImplem, pomegranateImplem, dynamaxImplem
export build_model, build_benchmarkables
export HMMBaseImplem, hmmlearnImplem, pomegranateImplem, dynamaxImplem
export print_python_setup, compare_loglikelihoods

include("hmmbase.jl")
include("hmmlearn.jl")
include("pomegranate.jl")
include("dynamax.jl")
include("setup.jl")
include("correctness.jl")

end # module HMMComparison
78 changes: 78 additions & 0 deletions libs/HMMComparison/src/correctness.jl
@@ -0,0 +1,78 @@
function compare_loglikelihoods(
instance::Instance, params::Params, data::AbstractArray{<:Real,3}
)
torch = pyimport("torch")
jax = pyimport("jax")
jnp = pyimport("jax.numpy")

torch.set_default_dtype(torch.float64)
jax.config.update("jax_enable_x64", true)

(; obs_dim, seq_length, nb_seqs) = instance

results = Dict{String,Any}()

## Data formats

if obs_dim == 1
obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
else
obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs]
end
obs_seq = reduce(vcat, obs_seqs)
control_seq = fill(nothing, length(obs_seq))
seq_ends = cumsum(length.(obs_seqs))

if obs_dim == 1
obs_mat = reduce(vcat, data[k, :, 1] for k in 1:nb_seqs)
else
obs_mat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
end

obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
obs_mat_concat_py = Py(obs_mat_concat).to_numpy()
obs_mat_len_py = Py(fill(seq_length, nb_seqs)).to_numpy()

obs_tens_torch_py = torch.tensor(Py(data).to_numpy())
obs_tens_jax_py = jnp.array(Py(data).to_numpy())

# HiddenMarkovModels.jl

implem1 = HiddenMarkovModelsImplem()
hmm1 = build_model(implem1, instance, params)
_, logLs1 = HiddenMarkovModels.forward(hmm1, obs_seq, control_seq; seq_ends)
results[string(implem1)] = logLs1

## HMMBase.jl

implem2 = HMMBaseImplem()
hmm2 = build_model(implem2, instance, params)
_, logL2 = HMMBase.forward(hmm2, obs_mat)
results[string(implem2)] = logL2

## hmmlearn

implem3 = hmmlearnImplem()
hmm3 = build_model(implem3, instance, params)
logL3 = hmm3.score(obs_mat_concat_py, obs_mat_len_py)
results[string(implem3)] = pyconvert(Number, logL3)

## pomegranate

implem4 = pomegranateImplem()
hmm4 = build_model(implem4, instance, params)
logαs4 = PyArray(hmm4.forward(obs_tens_torch_py))
logLs4 = [logsumexp(logαs4[k, end, :]) for k in 1:nb_seqs]
results[string(implem4)] = logLs4

## dynamax

implem5 = dynamaxImplem()
hmm5, dyn_params5 = build_model(implem5, instance, params)
filter_vmap = jax.jit(jax.vmap(hmm5.filter; in_axes=pylist((pybuiltins.None, 0))))
posterior5 = filter_vmap(dyn_params5, obs_tens_jax_py)
logLs5 = PyArray(posterior5.marginal_loglik)
results[string(implem5)] = logLs5

return results
end
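Finally, a minimal usage sketch for the new correctness check (assumes the Python backends are installed via CondaPkg; the printed values should agree across implementations up to numerical tolerance):

```julia
using StableRNGs: StableRNG

rng = StableRNG(63)
instance = Instance(; nb_states=4, obs_dim=1)
params = build_params(rng, instance)
data = build_data(rng, instance)

# maps implementation name => forward loglikelihood(s) on the same data
results = compare_loglikelihoods(instance, params, data)
for (name, logL) in sort(collect(results); by=first)
    println(name, " => ", logL)
end
```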