From 1ed97cf25a803020641cdee7f31a1440d1c57332 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 3 Feb 2024 13:15:11 +0100 Subject: [PATCH] Fairer Python benchmarks * Fairer Python benchmarks --- libs/HMMComparison/src/HMMComparison.jl | 2 +- libs/HMMComparison/src/dynamax.jl | 39 +++++++++++++------------ libs/HMMComparison/src/hmmlearn.jl | 10 +++---- libs/HMMComparison/src/pomegranate.jl | 11 ++++--- libs/HMMComparison/test/comparison.jl | 8 +++-- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/libs/HMMComparison/src/HMMComparison.jl b/libs/HMMComparison/src/HMMComparison.jl index 5f24837f..e8da5901 100644 --- a/libs/HMMComparison/src/HMMComparison.jl +++ b/libs/HMMComparison/src/HMMComparison.jl @@ -13,7 +13,7 @@ using HMMBenchmark: build_model, build_benchmarkables using LinearAlgebra: Diagonal -using PythonCall: pyimport, pybuiltins +using PythonCall: Py, pyimport, pybuiltins, pylist using Random: AbstractRNG using SparseArrays: spdiagm diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl index daffd7ca..946811bf 100644 --- a/libs/HMMComparison/src/dynamax.jl +++ b/libs/HMMComparison/src/dynamax.jl @@ -10,10 +10,10 @@ function HMMBenchmark.build_model( (; nb_states, obs_dim) = instance (; init, trans, means, stds) = build_params(rng; instance) - initial_probs = jnp.array(np.array(init)) - transition_matrix = jnp.array(np.array(trans)) - emission_means = jnp.array(np.array(transpose(means))) - emission_scale_diags = jnp.array(np.array(transpose(stds))) + initial_probs = jnp.array(Py(init).to_numpy()) + transition_matrix = jnp.array(Py(trans).to_numpy()) + emission_means = jnp.array(Py(transpose(means)).to_numpy()) + emission_scale_diags = jnp.array(Py(transpose(stds)).to_numpy()) hmm = dynamax_hmm.DiagonalGaussianHMM(nb_states, obs_dim) params, props = hmm.initialize(; @@ -37,51 +37,52 @@ function HMMBenchmark.build_benchmarkables( hmm, params, _ = build_model(rng, implem; instance) data = randn(rng, nb_seqs, seq_length, obs_dim) - obs_tens_py = jnp.array(np.array(data)) + obs_tens_py = jnp.array(Py(data).to_numpy()) benchs = Dict() if "logdensity" in algos - filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0])) + filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["logdensity"] = @benchmarkable begin $(filter_vmap)($params, $obs_tens_py) - end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py)) + end evals = 1 samples = 100 end if "forward" in algos - filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0])) + filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["forward"] = @benchmarkable begin $(filter_vmap)($params, $obs_tens_py) - end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py)) + end evals = 1 samples = 100 end if "viterbi" in algos - most_likely_states_vmap = jax.vmap( - hmm.most_likely_states; in_axes=pylist([pybuiltins.None, 0]) + most_likely_states_vmap = jax.jit( + jax.vmap(hmm.most_likely_states; in_axes=pylist((pybuiltins.None, 0))) ) benchs["viterbi"] = @benchmarkable begin $(most_likely_states_vmap)($params, $obs_tens_py) - end evals = 1 samples = 100 setup = ($(most_likely_states_vmap)( - $params, $obs_tens_py - )) + end evals = 1 samples = 100 end if "forward_backward" in algos - smoother_vmap = jax.vmap(hmm.smoother; in_axes=pylist([pybuiltins.None, 0])) + smoother_vmap = jax.jit( + jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0))) + ) benchs["forward_backward"] = @benchmarkable begin $(smoother_vmap)($params, $obs_tens_py) - end evals = 1 samples = 100 setup = ($(smoother_vmap)($params, $obs_tens_py)) + end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin - hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter) + hmm_guess.fit_em( + params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter, verbose=false + ) end evals = 1 samples = 100 setup = ( tup = build_model($rng, $implem; instance=$instance); hmm_guess = tup[1]; params_guess = tup[2]; - props_guess = tup[3]; - hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter) + props_guess = tup[3] ) end diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl index e2f3fd91..1b3aeecf 100644 --- a/libs/HMMComparison/src/hmmlearn.jl +++ b/libs/HMMComparison/src/hmmlearn.jl @@ -18,10 +18,10 @@ function HMMBenchmark.build_model( init_params="", ) - hmm.startprob_ = np.array(init) - hmm.transmat_ = np.array(trans) - hmm.means_ = np.array(transpose(means)) - hmm.covars_ = np.array(transpose(stds .^ 2)) + hmm.startprob_ = Py(init).to_numpy() + hmm.transmat_ = Py(trans).to_numpy() + hmm.means_ = Py(transpose(means)).to_numpy() + hmm.covars_ = Py(transpose(stds .^ 2)).to_numpy() return hmm end @@ -35,7 +35,7 @@ function HMMBenchmark.build_benchmarkables( data = randn(rng, nb_seqs, seq_length, obs_dim) obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs) - obs_mat_concat_py = np.array(obs_mat_concat) + obs_mat_concat_py = Py(obs_mat_concat).to_numpy() obs_mat_len_py = np.full(nb_seqs, seq_length) benchs = Dict() diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl index c21bc393..6f984746 100644 --- a/libs/HMMComparison/src/pomegranate.jl +++ b/libs/HMMComparison/src/pomegranate.jl @@ -3,7 +3,6 @@ struct pomegranateImplem <: Implementation end function HMMBenchmark.build_model( rng::AbstractRNG, implem::pomegranateImplem; instance::Instance ) - np = pyimport("numpy") torch = pyimport("torch") torch.set_default_dtype(torch.float64) pomegranate_distributions = pyimport("pomegranate.distributions") @@ -12,14 +11,14 @@ function HMMBenchmark.build_model( (; nb_states, bw_iter) = instance (; init, trans, means, stds) = build_params(rng; instance) - starts = torch.tensor(np.array(init)) + starts = torch.tensor(Py(init).to_numpy()) ends = torch.ones(nb_states) * 1e-10 - edges = torch.tensor(np.array(trans)) + edges = torch.tensor(Py(trans).to_numpy()) distributions = pylist([ pomegranate_distributions.Normal(; - means=torch.tensor(np.array(means[:, i])), - covs=torch.square(torch.tensor(np.array(stds[:, i] .^ 2))), + means=torch.tensor(Py(means[:, i]).to_numpy()), + covs=torch.square(torch.tensor(Py(stds[:, i] .^ 2).to_numpy())), covariance_type="diag", ) for i in 1:nb_states ]) @@ -48,7 +47,7 @@ function HMMBenchmark.build_benchmarkables( hmm = build_model(rng, implem; instance) data = randn(rng, nb_seqs, seq_length, obs_dim) - obs_tens_py = torch.tensor(np.array(data)) + obs_tens_py = torch.tensor(Py(data).to_numpy()) benchs = Dict() diff --git a/libs/HMMComparison/test/comparison.jl b/libs/HMMComparison/test/comparison.jl index 0e373468..f6e0e1f9 100644 --- a/libs/HMMComparison/test/comparison.jl +++ b/libs/HMMComparison/test/comparison.jl @@ -1,6 +1,10 @@ using BenchmarkTools using HMMComparison using HMMBenchmark +using LinearAlgebra +using StableRNGs + +BLAS.set_num_threads(1) rng = StableRNG(63) @@ -13,9 +17,7 @@ implems = [ ] algos = ["logdensity", "forward", "viterbi", "forward_backward", "baum_welch"] instances = [ - Instance(; - sparse=false, nb_states=5, obs_dim=10, seq_length=100, nb_seqs=50, bw_iter=10 - ), + Instance(; sparse=false, nb_states=5, obs_dim=10, seq_length=1000, nb_seqs=10, bw_iter=10) ] SUITE = define_suite(rng, implems; instances, algos)