Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix benchmarks #90

Merged
merged 6 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions .github/workflows/draft-pdf.yml

This file was deleted.

3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,4 @@ scratchpad.jl
/docs/src/index.md
/docs/src/examples/*.md

*.pdf
*.png
*.pdf
35 changes: 20 additions & 15 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.0"
julia_version = "1.10.1"
manifest_format = "2.0"
project_hash = "e05ed926575e94b72904ad898b09f017dc14d96a"

[[deps.ArgCheck]]
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
version = "2.3.0"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
Expand Down Expand Up @@ -34,9 +39,9 @@ version = "0.5.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "1287e3872d646eed95198457873249bd9f0caed2"
git-tree-sha1 = "892b245fdec1c511906671b6a5e1bafa38a727c1"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.20.1"
version = "1.22.0"
weakdeps = ["SparseArrays"]

[deps.ChainRulesCore.extensions]
Expand All @@ -50,9 +55,9 @@ version = "0.7.4"

[[deps.Compat]]
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b"
git-tree-sha1 = "d2c021fbdde94f6cdaa799639adfeeaa17fd67f5"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.12.0"
version = "4.13.0"
weakdeps = ["Dates", "LinearAlgebra"]

[deps.Compat.extensions]
Expand All @@ -61,7 +66,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+1"
version = "1.1.0+0"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand All @@ -81,9 +86,9 @@ version = "1.6.1"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
git-tree-sha1 = "1fb174f0d48fe7d142e1109a10636bc1d14f5ac2"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.16"
version = "0.18.17"

[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand Down Expand Up @@ -155,13 +160,13 @@ deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.HMMBenchmark]]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
path = "../libs/HMMBenchmark"
uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
version = "0.1.0"

[[deps.HiddenMarkovModels]]
deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI"]
deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
path = ".."
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
version = "0.4.0"
Expand Down Expand Up @@ -257,9 +262,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.26"
version = "0.3.27"

[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -309,7 +314,7 @@ version = "1.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+2"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -469,9 +474,9 @@ version = "0.34.2"

[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a"
git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.3.0"
version = "1.3.1"
weakdeps = ["ChainRulesCore", "InverseFunctions"]

[deps.StatsFuns.extensions]
Expand Down
1 change: 1 addition & 0 deletions libs/HMMBenchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
4 changes: 3 additions & 1 deletion libs/HMMBenchmark/src/HMMBenchmark.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMMBenchmark

using Base.Threads
using Base.Threads: Threads
using InteractiveUtils: InteractiveUtils
using BenchmarkTools: @benchmarkable, BenchmarkGroup
using CSV: CSV
using DataFrames: DataFrame
Expand Down Expand Up @@ -33,5 +34,6 @@ include("instance.jl")
include("params.jl")
include("hiddenmarkovmodels.jl")
include("suite.jl")
include("setup.jl")

end
32 changes: 20 additions & 12 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ function build_model(::HiddenMarkovModelsImplem, instance::Instance, params::Par
(; custom_dist, nb_states, obs_dim) = instance
(; init, trans, means, stds) = params

if custom_dist
dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
if obs_dim == 1
dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
else
if obs_dim == 1
dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
if custom_dist
dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
else
dists = [MvNormal(means[:, i], Diagonal(stds[:, i])) for i in 1:nb_states]
end
Expand Down Expand Up @@ -43,32 +43,38 @@ function build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "forward!" in algos
benchs["forward!"] = @benchmarkable begin
forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "viterbi!" in algos
benchs["viterbi!"] = @benchmarkable begin
viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "forward_backward!" in algos
benchs["forward_backward!"] = @benchmarkable begin
forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
fb_storage = initialize_forward_backward(
$hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
)
Expand All @@ -86,7 +92,9 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "baum_welch!" in algos
benchs["baum_welch!"] = @benchmarkable begin
baum_welch!(
fb_storage,
Expand All @@ -99,7 +107,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
hmm_guess = build_model($implem, $instance, $params);
fb_storage = initialize_forward_backward(
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
Expand Down
4 changes: 2 additions & 2 deletions libs/HMMBenchmark/src/setup.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function print_julia_setup(; path)
function print_julia_setup(path)
open(path, "w") do file
redirect_stdout(file) do
versioninfo()
InteractiveUtils.versioninfo()
println("\n# Multithreading\n")
println("Julia threads = $(Threads.nthreads())")
println("OpenBLAS threads = $(BLAS.get_num_threads())")
Expand Down
62 changes: 62 additions & 0 deletions libs/HMMComparison/experiments/performance.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Benchmark script: builds a suite over several HMM implementations and
# instance sizes, runs it, and writes the setup reports and timing results
# into the `results` directory next to this file.

using Pkg
Pkg.activate(joinpath(@__DIR__, ".."))
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "HMMBenchmark"))

# The benchmark is meant to run single-threaded. Use an explicit check rather
# than `@assert`, which may be disabled at some optimization levels.
nb_julia_threads = Base.Threads.nthreads()
if nb_julia_threads != 1
    error("this benchmark requires exactly 1 Julia thread, got $nb_julia_threads")
end

# These variables are set before `using PythonCall` below, which is where the
# Python process starts.
# see https://superfastpython.com/numpy-number-blas-threads/
ENV["MKL_NUM_THREADS"] = 1
ENV["NUMEXPR_NUM_THREADS"] = 1
ENV["OMP_NUM_THREADS"] = 1
ENV["OPENBLAS_NUM_THREADS"] = 1
ENV["VECLIB_MAXIMUM_THREADS"] = 1

# see https://github.com/google/jax/issues/743
ENV["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

using BenchmarkTools
using LinearAlgebra
using PythonCall  # Python process starts now
using StableRNGs
using HMMComparison

# see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
pyimport("torch").set_num_threads(1)

rng = StableRNG(63)

# All outputs go to the same directory; create it up front so that the
# file-writing calls below do not fail on a fresh checkout.
results_dir = joinpath(@__DIR__, "results")
mkpath(results_dir)

print_julia_setup(joinpath(results_dir, "julia_setup.txt"))
print_python_setup(joinpath(results_dir, "python_setup.txt"))

implems = [
    HiddenMarkovModelsImplem(),  #
    HMMBaseImplem(),  #
    hmmlearnImplem(),  #
    pomegranateImplem(),  #
    dynamaxImplem(),  #
]

algos = ["forward", "viterbi", "forward_backward", "baum_welch"]

# One instance per number of states; every other setting is held fixed.
instances = Instance[
    Instance(;
        custom_dist=true,
        sparse=false,
        nb_states=nb_states,
        obs_dim=1,
        seq_length=200,
        nb_seqs=100,
        bw_iter=5,
    ) for nb_states in 2:2:16
]

SUITE = define_suite(rng, implems; instances, algos)

results = BenchmarkTools.run(SUITE; verbose=true)
data = parse_results(results; path=joinpath(results_dir, "results.csv"))
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using DataFrames
using Plots
using HMMComparison

data = read_results(joinpath(@__DIR__, "results.csv"))
data = read_results(joinpath(@__DIR__, "results", "results.csv"))

sort!(data, [:algo, :implem, :nb_states])

Expand All @@ -13,7 +13,7 @@ implems = [
"pomegranate", #
"dynamax", #
]
algos = ["forward", "baum_welch"]
algos = ["viterbi", "forward", "forward_backward", "baum_welch"]

markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle]

Expand All @@ -24,6 +24,7 @@ for algo in algos
yscale=:log,
xlabel="nb states",
ylabel="runtime (s)",
xticks=unique(data[!, :nb_states]),
legend=:outerright,
margin=5Plots.mm,
)
Expand All @@ -33,10 +34,6 @@ for algo in algos
pl,
subdata[!, :nb_states],
subdata[!, :time_median] ./ 1e9;
yerror=(
(subdata[!, :time_median] .- subdata[!, :time_quantile25]) ./ 1e9,
(subdata[!, :time_quantile75] .- subdata[!, :time_median]) ./ 1e9,
),
label=implem,
markershape=markershapes[i],
markerstrokecolor=:auto,
Expand All @@ -46,5 +43,5 @@ for algo in algos
)
end
display(pl)
savefig(pl, joinpath(@__DIR__, "$(algo).png"))
savefig(pl, joinpath(@__DIR__, "results", "$(algo).pdf"))
end
Empty file.
Loading
Loading