Skip to content

Commit

Permalink
Fairer Python benchmarks
Browse files Browse the repository at this point in the history
* Fairer Python benchmarks
  • Loading branch information
gdalle authored Feb 3, 2024
1 parent 8bf278a commit 1ed97cf
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 34 deletions.
2 changes: 1 addition & 1 deletion libs/HMMComparison/src/HMMComparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using HMMBenchmark:
build_model,
build_benchmarkables
using LinearAlgebra: Diagonal
using PythonCall: pyimport, pybuiltins
using PythonCall: Py, pyimport, pybuiltins, pylist
using Random: AbstractRNG
using SparseArrays: spdiagm

Expand Down
39 changes: 20 additions & 19 deletions libs/HMMComparison/src/dynamax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ function HMMBenchmark.build_model(
(; nb_states, obs_dim) = instance
(; init, trans, means, stds) = build_params(rng; instance)

initial_probs = jnp.array(np.array(init))
transition_matrix = jnp.array(np.array(trans))
emission_means = jnp.array(np.array(transpose(means)))
emission_scale_diags = jnp.array(np.array(transpose(stds)))
initial_probs = jnp.array(Py(init).to_numpy())
transition_matrix = jnp.array(Py(trans).to_numpy())
emission_means = jnp.array(Py(transpose(means)).to_numpy())
emission_scale_diags = jnp.array(Py(transpose(stds)).to_numpy())

hmm = dynamax_hmm.DiagonalGaussianHMM(nb_states, obs_dim)
params, props = hmm.initialize(;
Expand All @@ -37,51 +37,52 @@ function HMMBenchmark.build_benchmarkables(
hmm, params, _ = build_model(rng, implem; instance)
data = randn(rng, nb_seqs, seq_length, obs_dim)

obs_tens_py = jnp.array(np.array(data))
obs_tens_py = jnp.array(Py(data).to_numpy())

benchs = Dict()

if "logdensity" in algos
filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0]))
filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
benchs["logdensity"] = @benchmarkable begin
$(filter_vmap)($params, $obs_tens_py)
end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py))
end evals = 1 samples = 100
end

if "forward" in algos
filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0]))
filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
benchs["forward"] = @benchmarkable begin
$(filter_vmap)($params, $obs_tens_py)
end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py))
end evals = 1 samples = 100
end

if "viterbi" in algos
most_likely_states_vmap = jax.vmap(
hmm.most_likely_states; in_axes=pylist([pybuiltins.None, 0])
most_likely_states_vmap = jax.jit(
jax.vmap(hmm.most_likely_states; in_axes=pylist((pybuiltins.None, 0)))
)
benchs["viterbi"] = @benchmarkable begin
$(most_likely_states_vmap)($params, $obs_tens_py)
end evals = 1 samples = 100 setup = ($(most_likely_states_vmap)(
$params, $obs_tens_py
))
end evals = 1 samples = 100
end

if "forward_backward" in algos
smoother_vmap = jax.vmap(hmm.smoother; in_axes=pylist([pybuiltins.None, 0]))
smoother_vmap = jax.jit(
jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0)))
)
benchs["forward_backward"] = @benchmarkable begin
$(smoother_vmap)($params, $obs_tens_py)
end evals = 1 samples = 100 setup = ($(smoother_vmap)($params, $obs_tens_py))
end evals = 1 samples = 100
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter)
hmm_guess.fit_em(
params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter, verbose=false
)
end evals = 1 samples = 100 setup = (
tup = build_model($rng, $implem; instance=$instance);
hmm_guess = tup[1];
params_guess = tup[2];
props_guess = tup[3];
hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter)
props_guess = tup[3]
)
end

Expand Down
10 changes: 5 additions & 5 deletions libs/HMMComparison/src/hmmlearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ function HMMBenchmark.build_model(
init_params="",
)

hmm.startprob_ = np.array(init)
hmm.transmat_ = np.array(trans)
hmm.means_ = np.array(transpose(means))
hmm.covars_ = np.array(transpose(stds .^ 2))
hmm.startprob_ = Py(init).to_numpy()
hmm.transmat_ = Py(trans).to_numpy()
hmm.means_ = Py(transpose(means)).to_numpy()
hmm.covars_ = Py(transpose(stds .^ 2)).to_numpy()
return hmm
end

Expand All @@ -35,7 +35,7 @@ function HMMBenchmark.build_benchmarkables(
data = randn(rng, nb_seqs, seq_length, obs_dim)

obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
obs_mat_concat_py = np.array(obs_mat_concat)
obs_mat_concat_py = Py(obs_mat_concat).to_numpy()
obs_mat_len_py = np.full(nb_seqs, seq_length)

benchs = Dict()
Expand Down
11 changes: 5 additions & 6 deletions libs/HMMComparison/src/pomegranate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ struct pomegranateImplem <: Implementation end
function HMMBenchmark.build_model(
rng::AbstractRNG, implem::pomegranateImplem; instance::Instance
)
np = pyimport("numpy")
torch = pyimport("torch")
torch.set_default_dtype(torch.float64)
pomegranate_distributions = pyimport("pomegranate.distributions")
Expand All @@ -12,14 +11,14 @@ function HMMBenchmark.build_model(
(; nb_states, bw_iter) = instance
(; init, trans, means, stds) = build_params(rng; instance)

starts = torch.tensor(np.array(init))
starts = torch.tensor(Py(init).to_numpy())
ends = torch.ones(nb_states) * 1e-10
edges = torch.tensor(np.array(trans))
edges = torch.tensor(Py(trans).to_numpy())

distributions = pylist([
pomegranate_distributions.Normal(;
means=torch.tensor(np.array(means[:, i])),
covs=torch.square(torch.tensor(np.array(stds[:, i] .^ 2))),
means=torch.tensor(Py(means[:, i]).to_numpy()),
covs=torch.square(torch.tensor(Py(stds[:, i] .^ 2).to_numpy())),
covariance_type="diag",
) for i in 1:nb_states
])
Expand Down Expand Up @@ -48,7 +47,7 @@ function HMMBenchmark.build_benchmarkables(
hmm = build_model(rng, implem; instance)
data = randn(rng, nb_seqs, seq_length, obs_dim)

obs_tens_py = torch.tensor(np.array(data))
obs_tens_py = torch.tensor(Py(data).to_numpy())

benchs = Dict()

Expand Down
8 changes: 5 additions & 3 deletions libs/HMMComparison/test/comparison.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using BenchmarkTools
using HMMComparison
using HMMBenchmark
using LinearAlgebra
using StableRNGs

BLAS.set_num_threads(1)

rng = StableRNG(63)

Expand All @@ -13,9 +17,7 @@ implems = [
]
algos = ["logdensity", "forward", "viterbi", "forward_backward", "baum_welch"]
instances = [
Instance(;
sparse=false, nb_states=5, obs_dim=10, seq_length=100, nb_seqs=50, bw_iter=10
),
Instance(; sparse=false, nb_states=5, obs_dim=10, seq_length=1000, nb_seqs=10, bw_iter=10)
]

SUITE = define_suite(rng, implems; instances, algos)
Expand Down

0 comments on commit 1ed97cf

Please sign in to comment.