From c7134e1b9a80e166d2d3231c48db121feac5a6ef Mon Sep 17 00:00:00 2001
From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date: Fri, 2 Feb 2024 18:30:48 +0100
Subject: [PATCH] Clean up benchmarks (#77)

* Rephrase benchmarks with dispatch

* Dynamax benchmarks

* Update Manifest
---
 .gitignore                                        |  6 +-
 benchmark/Manifest.toml                           | 52 ++++++-----
 benchmark/Project.toml                            |  1 +
 benchmark/benchmarks.jl                           | 26 +++---
 benchmark/run.jl                                  |  2 +-
 libs/HMMBenchmark/Project.toml                    |  2 +
 libs/HMMBenchmark/src/HMMBenchmark.jl             | 15 +++-
 .../src/{algos.jl => hiddenmarkovmodels.jl}       | 55 ++++++------
 .../src/{configuration.jl => instance.jl}         | 12 +--
 libs/HMMBenchmark/src/params.jl                   | 42 +++++++++
 libs/HMMBenchmark/src/setup.jl                    | 12 +++
 libs/HMMBenchmark/src/suite.jl                    | 62 ++++++-------
 libs/HMMComparison/Project.toml                   |  5 ++
 libs/HMMComparison/src/HMMComparison.jl           | 69 ++++----------
 libs/HMMComparison/src/dynamax.jl                 | 89 ++++++++++++++++++-
 libs/HMMComparison/src/hmmbase.jl                 | 49 ++++++----
 libs/HMMComparison/src/hmmlearn.jl                | 66 ++++++++------
 libs/HMMComparison/src/pomegranate.jl             | 67 +++++++++-----
 libs/HMMComparison/src/setup.jl                   | 12 +++
 libs/HMMComparison/test/comparison.jl             | 26 +++---
 libs/HMMComparison/test/plots.jl                  | 11 +++
 21 files changed, 448 insertions(+), 233 deletions(-)
 rename libs/HMMBenchmark/src/{algos.jl => hiddenmarkovmodels.jl} (71%)
 rename libs/HMMBenchmark/src/{configuration.jl => instance.jl} (54%)
 create mode 100644 libs/HMMBenchmark/src/params.jl
 create mode 100644 libs/HMMBenchmark/src/setup.jl
 create mode 100644 libs/HMMComparison/src/setup.jl
 create mode 100644 libs/HMMComparison/test/plots.jl
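
Usage sketch of the reworked HMMBenchmark API (assembled from the diffs below; the snippet itself is not a file in this patch):

    using BenchmarkTools
    using HMMBenchmark
    using StableRNGs

    rng = StableRNG(63)
    instances = [Instance(; nb_states=4, obs_dim=1)]
    algos = ["logdensity", "forward", "baum_welch"]

    # define_suite defaults to the HiddenMarkovModels.jl backend; pass a vector
    # of Implementation subtypes (see HMMComparison below) to compare libraries
    SUITE = define_suite(rng; instances, algos)
    results = run(SUITE; verbose=true)
    data = parse_results(results; path="results.csv")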
"944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.3" +version = "0.7.4" [[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d" +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.1" +version = "4.12.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -69,9 +69,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" [[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" +version = "1.16.0" [[deps.DataFrames]] deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] @@ -81,9 +81,9 @@ version = "1.6.1" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" +git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" +version = "0.18.16" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -102,9 +102,9 @@ version = "0.4.0" [[deps.Distributions]] deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9242eec9b7e2e14f9952e8ea1c7e31a50501d587" +git-tree-sha1 = "7c302d7a5fec5214eb8a5a4c466dcf7a51fcf169" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.104" +version = "0.25.107" weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] @@ -155,7 +155,7 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.HMMBenchmark]] -deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays"] +deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"] path = "../libs/HMMBenchmark" uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" version = "0.1.0" @@ -378,9 +378,9 @@ uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9ebcd48c498668c7fa0e97a9cae873fbee7bfee1" +git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.1" +version = "2.9.4" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -425,9 +425,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "5165dfb9fd131cf0c6957a3a7605dede376e7b63" +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.0" +version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] @@ -444,6 +444,12 @@ weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" +[[deps.StableRNGs]] +deps = ["Random", "Test"] +git-tree-sha1 = 
"ddc1a7b85e760b5285b50b882fa91e40c603be47" +uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" +version = "1.0.1" + [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -514,9 +520,9 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TranscodingStreams]] -git-tree-sha1 = "1fbeaaca45801b4ba17c251dd8603ef24801dd84" +git-tree-sha1 = "54194d92959d8ebaa8e26227dbe3cdefcdcd594f" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.2" +version = "0.10.3" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 390c4830..e3641866 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -4,3 +4,4 @@ HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 0a472de0..8239c69a 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,23 +1,23 @@ using BenchmarkTools using HMMBenchmark using Random +using StableRNGs -rng = Random.default_rng() -Random.seed!(rng, 63) +rng = StableRNG(63) -algos = ("rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch") -configurations = [ +algos = ["rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch"] +instances = [ # compare state numbers - Configuration(; nb_states=4, obs_dim=1), - Configuration(; nb_states=8, obs_dim=1), - Configuration(; nb_states=16, obs_dim=1), - Configuration(; nb_states=32, obs_dim=1), + Instance(; nb_states=4, obs_dim=1), + Instance(; nb_states=8, obs_dim=1), + Instance(; nb_states=16, obs_dim=1), + Instance(; nb_states=32, obs_dim=1), # compare sparse - Configuration(; nb_states=64, obs_dim=1, sparse=false), - Configuration(; nb_states=64, obs_dim=1, sparse=true), + Instance(; nb_states=64, obs_dim=1, sparse=false), + Instance(; nb_states=64, obs_dim=1, sparse=true), # compare dists - Configuration(; nb_states=4, obs_dim=10, custom_dist=true), - Configuration(; nb_states=4, obs_dim=10, custom_dist=false), + Instance(; nb_states=4, obs_dim=10, custom_dist=true), + Instance(; nb_states=4, obs_dim=10, custom_dist=false), ] -SUITE = define_suite(rng; configurations, algos) +SUITE = define_suite(rng; instances, algos) diff --git a/benchmark/run.jl b/benchmark/run.jl index 114f13e6..8f3a8c4f 100644 --- a/benchmark/run.jl +++ b/benchmark/run.jl @@ -1,4 +1,4 @@ include("benchmarks.jl") results = run(SUITE; verbose=true) -data = parse_results(minimum(results); path=joinpath(@__DIR__, "results.csv")) +data = parse_results(results; path=joinpath(@__DIR__, "results.csv")) diff --git a/libs/HMMBenchmark/Project.toml b/libs/HMMBenchmark/Project.toml index 876babae..c9a9a128 100644 --- a/libs/HMMBenchmark/Project.toml +++ b/libs/HMMBenchmark/Project.toml @@ -13,6 +13,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.9" diff --git a/libs/HMMBenchmark/src/HMMBenchmark.jl b/libs/HMMBenchmark/src/HMMBenchmark.jl index 
diff --git a/libs/HMMBenchmark/src/HMMBenchmark.jl b/libs/HMMBenchmark/src/HMMBenchmark.jl
index b0e6b609..b2b543fa 100644
--- a/libs/HMMBenchmark/src/HMMBenchmark.jl
+++ b/libs/HMMBenchmark/src/HMMBenchmark.jl
@@ -1,5 +1,6 @@
 module HMMBenchmark

+using Base.Threads
 using BenchmarkTools: @benchmarkable, BenchmarkGroup
 using CSV: CSV
 using DataFrames: DataFrame
@@ -16,15 +17,21 @@ using HiddenMarkovModels:
     initialize_forward_backward,
     forward_backward!,
     baum_welch!
-using LinearAlgebra: Diagonal, SymTridiagonal
+using LinearAlgebra: BLAS, Diagonal, SymTridiagonal
 using Pkg: Pkg
 using Random: AbstractRNG
 using SparseArrays: spdiagm
+using Statistics: mean, median, std

-export Configuration, define_suite, parse_results, print_julia_setup
+export Implementation, Instance
+export define_suite, parse_results, print_julia_setup

-include("configuration.jl")
-include("algos.jl")
+abstract type Implementation end
+Base.string(implem::Implementation) = string(typeof(implem))[begin:(end - length("Implem"))]
+
+include("instance.jl")
+include("params.jl")
+include("hiddenmarkovmodels.jl")
 include("suite.jl")
+include("setup.jl")

 end
diff --git a/libs/HMMBenchmark/src/algos.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl
similarity index 71%
rename from libs/HMMBenchmark/src/algos.jl
rename to libs/HMMBenchmark/src/hiddenmarkovmodels.jl
index 0cf85fa4..346d384d 100644
--- a/libs/HMMBenchmark/src/algos.jl
+++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl
@@ -1,44 +1,48 @@
-function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algos)
-    (; sparse, custom_dist, nb_states, obs_dim, seq_length, nb_seqs, bw_iter) =
-        configuration
+struct HiddenMarkovModelsImplem <: Implementation end

-    # Model
-    init = ones(nb_states) / nb_states
-    if sparse
-        trans = spdiagm(
-            0 => ones(nb_states) / 2,
-            +1 => ones(nb_states - 1) / 2,
-            -(nb_states - 1) => ones(1) / 2,
-        )
-    else
-        trans = ones(nb_states, nb_states) / nb_states
-    end
+function build_model(rng::AbstractRNG, ::HiddenMarkovModelsImplem; instance::Instance)
+    (; custom_dist, nb_states, obs_dim) = instance
+    (; init, trans, means, stds) = build_params(rng; instance)
     if custom_dist
-        dists = [LightDiagNormal(i .* ones(obs_dim), ones(obs_dim)) for i in 1:nb_states]
+        dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
     else
         if obs_dim == 1
-            dists = [Normal(i, 1.0) for i in 1:nb_states]
+            dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
         else
-            dists = [
-                MvNormal(i .* ones(obs_dim), Diagonal(ones(obs_dim))) for i in 1:nb_states
-            ]
+            dists = [MvNormal(means[:, i], Diagonal(stds[:, i])) for i in 1:nb_states]
         end
     end
     hmm = HiddenMarkovModels.HMM(init, trans, dists)
+    return hmm
+end
+
+function build_benchmarkables(
+    rng::AbstractRNG,
+    implem::HiddenMarkovModelsImplem;
+    instance::Instance,
+    algos::Vector{String},
+)
+    (; obs_dim, seq_length, nb_seqs, bw_iter) = instance

-    # Data
-    obs_seqs = [rand(rng, hmm, seq_length).obs_seq for _ in 1:nb_seqs]
+    hmm = build_model(rng, implem; instance)
+    data = randn(rng, nb_seqs, seq_length, obs_dim)
+
+    if obs_dim == 1
+        obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
+    else
+        obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs]
+    end
     obs_seq = reduce(vcat, obs_seqs)
     control_seq = fill(nothing, length(obs_seq))
     seq_ends = cumsum(length.(obs_seqs))

-    # Benchmarks
     benchs = Dict()

     if "rand" in algos
         benchs["rand"] = @benchmarkable begin
-            [rand($hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
+            [rand($rng, $hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
         end evals = 1 samples = 100
     end
@@ -99,7 +103,7 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo
             baum_welch!(
                 fb_storage,
                 logL_evolution,
-                $hmm,
+                hmm_guess,
                 $obs_seq,
                 $control_seq;
                 seq_ends=$seq_ends,
@@ -108,8 +112,9 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo
                 loglikelihood_increasing=false,
             )
         end evals = 1 samples = 100 setup = (
+            hmm_guess = build_model($rng, $implem; instance=$instance);
             fb_storage = initialize_forward_backward(
-                $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
+                hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
             );
             logL_evolution = Float64[];
             sizehint!(logL_evolution, $bw_iter)
diff --git a/libs/HMMBenchmark/src/configuration.jl b/libs/HMMBenchmark/src/instance.jl
similarity index 54%
rename from libs/HMMBenchmark/src/configuration.jl
rename to libs/HMMBenchmark/src/instance.jl
index ae36198c..c49596d9 100644
--- a/libs/HMMBenchmark/src/configuration.jl
+++ b/libs/HMMBenchmark/src/instance.jl
@@ -1,4 +1,4 @@
-Base.@kwdef struct Configuration
+Base.@kwdef struct Instance
     nb_states::Int
     obs_dim::Int
     seq_length::Int = 100
@@ -8,15 +8,11 @@ Base.@kwdef struct Configuration
     custom_dist::Bool = false
 end

-function Base.string(c::Configuration)
+function Base.string(c::Instance)
     return reduce(*, "$n $(Int(getfield(c, n))) " for n in fieldnames(typeof(c)))[1:(end - 1)]
 end

-function Configuration(s::String)
+function Instance(s::String)
     vals = parse.(Int, split(s, " ")[2:2:end])
-    return Configuration(vals...)
-end
-
-function to_namedtuple(c::Configuration)
-    return NamedTuple(n => getfield(c, n) for n in fieldnames(typeof(c)))
+    return Instance(vals...)
 end
diff --git a/libs/HMMBenchmark/src/params.jl b/libs/HMMBenchmark/src/params.jl
new file mode 100644
index 00000000..21f59703
--- /dev/null
+++ b/libs/HMMBenchmark/src/params.jl
@@ -0,0 +1,42 @@
+function build_initialization(rng::AbstractRNG; instance::Instance)
+    (; nb_states) = instance
+    init = ones(nb_states) / nb_states
+    return init
+end
+
+function build_transition_matrix(rng::AbstractRNG; instance::Instance)
+    (; sparse, nb_states) = instance
+    if sparse
+        trans = spdiagm(
+            0 => rand(rng, nb_states) / 2,
+            +1 => rand(rng, nb_states - 1) / 2,
+            -(nb_states - 1) => rand(rng, 1) / 2,
+        )
+    else
+        trans = rand(rng, nb_states, nb_states)
+    end
+    for row in eachrow(trans)
+        row ./= sum(row)
+    end
+    return trans
+end
+
+function build_means(rng::AbstractRNG; instance::Instance)
+    (; obs_dim, nb_states) = instance
+    means = randn(rng, obs_dim, nb_states)
+    return means
+end
+
+function build_stds(rng::AbstractRNG; instance::Instance)
+    (; obs_dim, nb_states) = instance
+    stds = ones(obs_dim, nb_states)
+    return stds
+end
+
+function build_params(rng::AbstractRNG; instance::Instance)
+    init = build_initialization(rng; instance)
+    trans = build_transition_matrix(rng; instance)
+    means = build_means(rng; instance)
+    stds = build_stds(rng; instance)
+    return (; init, trans, means, stds)
+end
diff --git a/libs/HMMBenchmark/src/setup.jl b/libs/HMMBenchmark/src/setup.jl
new file mode 100644
index 00000000..83d76fe4
--- /dev/null
+++ b/libs/HMMBenchmark/src/setup.jl
@@ -0,0 +1,12 @@
+function print_julia_setup(; path)
+    open(path, "w") do file
+        redirect_stdout(file) do
+            versioninfo()
+            println("\n# Multithreading\n")
+            println("Julia threads = $(Threads.nthreads())")
+            println("OpenBLAS threads = $(BLAS.get_num_threads())")
+            println("\n# Julia packages\n")
+            Pkg.status()
+        end
+    end
+end
diff --git a/libs/HMMBenchmark/src/suite.jl b/libs/HMMBenchmark/src/suite.jl
index 1308d96a..4adc3879 100644
--- a/libs/HMMBenchmark/src/suite.jl
+++ b/libs/HMMBenchmark/src/suite.jl
@@ -1,28 +1,41 @@
-function define_suite(rng::AbstractRNG; configurations, algos)
+function to_namedtuple(x)
+    return NamedTuple(n => getfield(x, n) for n in fieldnames(typeof(x)))
+end
+
+function define_suite(
+    rng::AbstractRNG,
+    implems::Vector{<:Implementation}=[HiddenMarkovModelsImplem()];
+    instances::Vector{<:Instance},
+    algos::Vector{String},
+)
     SUITE = BenchmarkGroup()
-    implem = "HiddenMarkovModels.jl"
-    for configuration in configurations
-        bench_tup = benchmarkables_hiddenmarkovmodels(rng; configuration, algos)
-        for (algo, bench) in pairs(bench_tup)
-            SUITE[implem][string(configuration)][algo] = bench
+    for implem in implems
+        for instance in instances
+            bench_tup = build_benchmarkables(rng, implem; instance, algos)
+            for (algo, bench) in pairs(bench_tup)
+                SUITE[string(implem)][string(instance)][algo] = bench
+            end
         end
     end
     return SUITE
 end

-function parse_results(results; path=nothing)
+function parse_results(
+    results; path=nothing, aggregators=[minimum, median, maximum, mean, std]
+)
     data = DataFrame()
-    for implem in identity.(keys(results))
-        for configuration_str in identity.(keys(results[implem]))
-            configuration = Configuration(configuration_str)
-            for algo in identity.(keys(results[implem][configuration_str]))
-                perf = results[implem][configuration_str][algo]
-                (; time, gctime, memory, allocs) = perf
-                row = merge(
-                    (; implem, algo),
-                    to_namedtuple(configuration),
-                    (; time, gctime, memory, allocs),
-                )
+    for implem_str in identity.(keys(results))
+        for instance_str in identity.(keys(results[implem_str]))
+            instance = Instance(instance_str)
+            for algo in identity.(keys(results[implem_str][instance_str]))
+                perf = results[implem_str][instance_str][algo]
+                perf_dict = Dict{Symbol,Number}()
+                perf_dict[:samples] = length(perf.times)
+                for agg in aggregators
+                    perf_dict[Symbol("time_$agg")] = agg(perf.times)
+                end
+                row = merge((; implem=implem_str, algo), to_namedtuple(instance))
+                row = merge(row, perf_dict)
                 push!(data, row)
             end
         end
@@ -35,16 +48,3 @@
     end
     return data
 end
-
-function print_julia_setup(; path)
-    open(path, "w") do file
-        redirect_stdout(file) do
-            versioninfo()
-            println("\n# Multithreading\n")
-            println("Julia threads = $(Threads.nthreads())")
-            println("OpenBLAS threads = $(BLAS.get_num_threads())")
-            println("\n# Julia packages\n")
-            Pkg.status()
-        end
-    end
-end
diff --git a/libs/HMMComparison/Project.toml b/libs/HMMComparison/Project.toml
index 714d5c7c..587be330 100644
--- a/libs/HMMComparison/Project.toml
+++ b/libs/HMMComparison/Project.toml
@@ -5,10 +5,15 @@ version = "0.1.0"

 [deps]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"
 HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
 HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
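
Note on the new parse_results: instead of receiving a pre-aggregated `minimum(results)`, it now consumes the raw trial results and emits one time column per aggregator, plus a sample count. A sketch of the intended call (assembled from suite.jl above; not a file in this patch):

    results = run(SUITE; verbose=true)
    data = parse_results(results; aggregators=[minimum, median])
    # columns: :implem, :algo, the Instance fields, :samples,
    # :time_minimum, :time_median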
diff --git a/libs/HMMComparison/src/HMMComparison.jl b/libs/HMMComparison/src/HMMComparison.jl
index 01df0a44..5f24837f 100644
--- a/libs/HMMComparison/src/HMMComparison.jl
+++ b/libs/HMMComparison/src/HMMComparison.jl
@@ -1,61 +1,30 @@
 module HMMComparison

-using BenchmarkTools
-using Distributions
-using HMMBase
-using HMMBenchmark: benchmarkables_hiddenmarkovmodels, to_tuple
-using PythonCall
+using BenchmarkTools: BenchmarkGroup, @benchmarkable
+using CondaPkg: CondaPkg
+using Distributions: Normal, MvNormal
+using HMMBase: HMMBase
+using HMMBenchmark:
+    HMMBenchmark,
+    Instance,
+    Implementation,
+    HiddenMarkovModelsImplem,
+    build_params,
+    build_model,
+    build_benchmarkables
+using LinearAlgebra: Diagonal
+using PythonCall: pybuiltins, pyimport, pylist
 using Random: AbstractRNG
-using SparseArrays
+using SparseArrays: spdiagm

-export define_full_suite
-
-function print_python_setup(; path)
-    open(path, "w") do file
-        redirect_stdout(file) do
-            println("Pytorch threads = $(torch.get_num_threads())")
-            println("\n# Python packages\n")
-        end
-        redirect_stderr(file) do
-            PythonCall.CondaPkg.status()
-        end
-    end
-end
-
-function define_full_suite(rng::AbstractRNG; implems, configurations, algos)
-    SUITE = BenchmarkGroup()
-    for implem in implems
-        SUITE[implem] = BenchmarkGroup()
-        for configuration in configurations
-            SUITE[implem][to_tuple(configuration)] = BenchmarkGroup()
-            bench_tup = benchmarkables_by_implem(rng; implem, configuration, algos)
-            for (algo, bench) in pairs(bench_tup)
-                SUITE[implem][to_tuple(configuration)][algo] = bench
-            end
-        end
-    end
-    return SUITE
-end
-
-function benchmarkables_by_implem(rng::AbstractRNG; implem, configuration, algos)
-    if implem == "HiddenMarkovModels.jl"
-        return benchmarkables_hiddenmarkovmodels(rng; configuration, algos)
-    elseif implem == "HMMBase.jl"
-        return benchmarkables_hmmbase(rng; configuration, algos)
-    elseif implem == "hmmlearn"
-        return benchmarkables_hmmlearn(rng; configuration, algos)
-    elseif implem == "pomegranate"
-        return benchmarkables_pomegranate(rng; configuration, algos)
-    elseif implem == "dynamax"
-        return benchmarkables_dynamax(rng; configuration, algos)
-    else
-        throw(ArgumentError("Unknown implementation"))
-    end
-end
+export HiddenMarkovModelsImplem,
+    HMMBaseImplem, hmmlearnImplem, pomegranateImplem, dynamaxImplem
+export build_model, build_benchmarkables

 include("hmmbase.jl")
 include("hmmlearn.jl")
 include("pomegranate.jl")
 include("dynamax.jl")
+include("setup.jl")

 end # module HMMComparison
diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl
index 2b728060..daffd7ca 100644
--- a/libs/HMMComparison/src/dynamax.jl
+++ b/libs/HMMComparison/src/dynamax.jl
@@ -1,4 +1,89 @@
-function benchmarkables_dynamax(rng::AbstractRNG; configuration, algos)
-    benchs = BenchmarkGroup()
+struct dynamaxImplem <: Implementation end
+
+function HMMBenchmark.build_model(
+    rng::AbstractRNG, implem::dynamaxImplem; instance::Instance
+)
+    np = pyimport("numpy")
+    jnp = pyimport("jax.numpy")
+    dynamax_hmm = pyimport("dynamax.hidden_markov_model")
+
+    (; nb_states, obs_dim) = instance
+    (; init, trans, means, stds) = build_params(rng; instance)
+
+    initial_probs = jnp.array(np.array(init))
+    transition_matrix = jnp.array(np.array(trans))
+    emission_means = jnp.array(np.array(transpose(means)))
+    emission_scale_diags = jnp.array(np.array(transpose(stds)))
+
+    hmm = dynamax_hmm.DiagonalGaussianHMM(nb_states, obs_dim)
+    params, props = hmm.initialize(;
+        initial_probs=initial_probs,
+        transition_matrix=transition_matrix,
+        emission_means=emission_means,
+        emission_scale_diags=emission_scale_diags,
+    )
+
+    return hmm, params, props
+end
+
+function HMMBenchmark.build_benchmarkables(
+    rng::AbstractRNG, implem::dynamaxImplem; instance::Instance, algos::Vector{String}
+)
+    np = pyimport("numpy")
+    jax = pyimport("jax")
+    jnp = pyimport("jax.numpy")
+    (; obs_dim, seq_length, nb_seqs, bw_iter) = instance
+
+    hmm, params, _ = build_model(rng, implem; instance)
+    data = randn(rng, nb_seqs, seq_length, obs_dim)
+
+    obs_tens_py = jnp.array(np.array(data))
+
+    benchs = Dict()
+
+    if "logdensity" in algos
+        filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0]))
+        benchs["logdensity"] = @benchmarkable begin
+            $(filter_vmap)($params, $obs_tens_py)
+        end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py))
+    end
+
+    if "forward" in algos
+        filter_vmap = jax.vmap(hmm.filter; in_axes=pylist([pybuiltins.None, 0]))
+        benchs["forward"] = @benchmarkable begin
+            $(filter_vmap)($params, $obs_tens_py)
+        end evals = 1 samples = 100 setup = ($(filter_vmap)($params, $obs_tens_py))
+    end
+
+    if "viterbi" in algos
+        most_likely_states_vmap = jax.vmap(
+            hmm.most_likely_states; in_axes=pylist([pybuiltins.None, 0])
+        )
+        benchs["viterbi"] = @benchmarkable begin
+            $(most_likely_states_vmap)($params, $obs_tens_py)
+        end evals = 1 samples = 100 setup = ($(most_likely_states_vmap)(
+            $params, $obs_tens_py
+        ))
+    end
+
+    if "forward_backward" in algos
+        smoother_vmap = jax.vmap(hmm.smoother; in_axes=pylist([pybuiltins.None, 0]))
+        benchs["forward_backward"] = @benchmarkable begin
+            $(smoother_vmap)($params, $obs_tens_py)
+        end evals = 1 samples = 100 setup = ($(smoother_vmap)($params, $obs_tens_py))
+    end
+
+    if "baum_welch" in algos
+        benchs["baum_welch"] = @benchmarkable begin
+            hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter)
+        end evals = 1 samples = 100 setup = (
+            tup = build_model($rng, $implem; instance=$instance);
+            hmm_guess = tup[1];
+            params_guess = tup[2];
+            props_guess = tup[3];
+            hmm_guess.fit_em(params_guess, props_guess, $obs_tens_py; num_iters=$bw_iter)
+        )
+    end
+
     return benchs
 end
diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl
index c8f883fd..d41582e0 100644
--- a/libs/HMMComparison/src/hmmbase.jl
+++ b/libs/HMMComparison/src/hmmbase.jl
@@ -1,50 +1,67 @@
-function benchmarkables_hmmbase(rng::AbstractRNG; configuration, algos)
-    (; sparse, nb_states, obs_dim, seq_length, nb_seqs, bw_iter) = configuration
+struct HMMBaseImplem <: Implementation end

-    # Model
-    a = ones(nb_states) / nb_states
-    A = ones(nb_states, nb_states) / nb_states
+function HMMBenchmark.build_model(
+    rng::AbstractRNG, implem::HMMBaseImplem; instance::Instance
+)
+    (; nb_states, obs_dim) = instance
+    (; init, trans, means, stds) = build_params(rng; instance)
+
+    a = init
+    A = trans
     if obs_dim == 1
-        B = [Normal(i, 1.0) for i in 1:nb_states]
+        B = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
     else
-        B = [MvNormal(i .* ones(obs_dim), Diagonal(ones(obs_dim))) for i in 1:nb_states]
+        B = [MvNormal(means[:, i], Diagonal(stds[:, i])) for i in 1:nb_states]
     end
     hmm = HMMBase.HMM(a, A, B)
+    return hmm
+end
+
+function HMMBenchmark.build_benchmarkables(
+    rng::AbstractRNG, implem::HMMBaseImplem; instance::Instance, algos::Vector{String}
+)
+    (; obs_dim, seq_length, nb_seqs, bw_iter) = instance

-    # Data
-    obs_mat = rand(rng, hmm, seq_length * nb_seqs) # concat insted of multiple sequences
+    hmm = build_model(rng, implem; instance)
+    data = randn(rng, nb_seqs, seq_length, obs_dim)
+
+    if obs_dim == 1
+        obs_mat = reduce(vcat, data[k, :, 1] for k in 1:nb_seqs)
+    else
+        obs_mat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
+    end

-    # Benchmarks
-    benchs = Dict()
+    benchs = BenchmarkGroup()

     if "logdensity" in algos
         benchs["logdensity"] = @benchmarkable begin
             HMMBase.forward($hmm, $obs_mat)
-        end
+        end evals = 1 samples = 100
     end

     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
             HMMBase.forward($hmm, $obs_mat)
-        end
+        end evals = 1 samples = 100
     end

     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
             HMMBase.viterbi($hmm, $obs_mat)
-        end
+        end evals = 1 samples = 100
     end

     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
             HMMBase.posteriors($hmm, $obs_mat)
-        end
+        end evals = 1 samples = 100
     end

     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
             HMMBase.fit_mle($hmm, $obs_mat; maxiter=$bw_iter, tol=-Inf)
-        end
+        end evals = 1 samples = 100
     end

     return benchs
 end
diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl
index 01deaf27..e2f3fd91 100644
--- a/libs/HMMComparison/src/hmmlearn.jl
+++ b/libs/HMMComparison/src/hmmlearn.jl
@@ -1,9 +1,15 @@
-function benchmarkables_hmmlearn(rng::AbstractRNG; configuration, algos)
+struct hmmlearnImplem <: Implementation end
+
+function HMMBenchmark.build_model(
+    rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance
+)
     np = pyimport("numpy")
-    (; sparse, nb_states, obs_dim, seq_length, nb_seqs, bw_iter) = configuration
+    hmmlearn_hmm = pyimport("hmmlearn.hmm")
+
+    (; bw_iter, nb_states) = instance
+    (; init, trans, means, stds) = build_params(rng; instance)

-    # Model
-    hmm = pyimport("hmmlearn.hmm").GaussianHMM(;
+    hmm = hmmlearn_hmm.GaussianHMM(;
         n_components=nb_states,
         covariance_type="diag",
         n_iter=bw_iter,
@@ -11,49 +17,59 @@ function benchmarkables_hmmlearn(rng::AbstractRNG; configuration, algos)
         implementation="scaling",
         init_params="",
     )
-    hmm.startprob_ = np.ones(nb_states) / nb_states
-    hmm.transmat_ = np.ones((nb_states, nb_states)) / nb_states
-    hmm.means_ =
-        np.ones((nb_states, obs_dim)) *
-        np.arange(1, nb_states + 1)[0:(nb_states - 1), np.newaxis]
-    hmm.covars_ = np.ones((nb_states, obs_dim))
-
-    # Data
-    obs_mats_list_py = pylist([hmm.sample(seq_length)[0] for _ in 1:nb_seqs])
-    obs_mat_concat_py = np.concatenate(obs_mats_list_py)
+
+    hmm.startprob_ = np.array(init)
+    hmm.transmat_ = np.array(trans)
+    hmm.means_ = np.array(transpose(means))
+    hmm.covars_ = np.array(transpose(stds .^ 2))
+    return hmm
+end
+
+function HMMBenchmark.build_benchmarkables(
+    rng::AbstractRNG, implem::hmmlearnImplem; instance::Instance, algos::Vector{String}
+)
+    np = pyimport("numpy")
+    (; obs_dim, seq_length, nb_seqs) = instance
+
+    hmm = build_model(rng, implem; instance)
+    data = randn(rng, nb_seqs, seq_length, obs_dim)
+
+    obs_mat_concat = reduce(vcat, data[k, :, :] for k in 1:nb_seqs)
+    obs_mat_concat_py = np.array(obs_mat_concat)
     obs_mat_len_py = np.full(nb_seqs, seq_length)

-    # Benchmarks
     benchs = Dict()

     if "logdensity" in algos
         benchs["logdensity"] = @benchmarkable begin
-            pycall($(hmm.score), $obs_mat_concat_py, $obs_mat_len_py)
-        end
+            $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
+        end evals = 1 samples = 100
     end

     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
-            pycall($(hmm.score), $obs_mat_concat_py, $obs_mat_len_py)
-        end
+            $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
+        end evals = 1 samples = 100
     end

     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
-            pycall($(hmm.decode), $obs_mat_concat_py, $obs_mat_len_py)
-        end
+            $(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py)
+        end evals = 1 samples = 100
     end

     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
-            pycall($(hmm.predict_proba), $obs_mat_concat_py, $obs_mat_len_py)
-        end
+            $(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py)
+        end evals = 1 samples = 100
     end

     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
-            pycall($(hmm.fit), $obs_mat_concat_py, $obs_mat_len_py)
-        end
+            hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py)
+        end evals = 1 samples = 100 setup = (
+            hmm_guess = build_model($rng, $implem; instance=$instance)
+        )
     end

     return benchs
 end
diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl
index 0d3c2fbd..c21bc393 100644
--- a/libs/HMMComparison/src/pomegranate.jl
+++ b/libs/HMMComparison/src/pomegranate.jl
@@ -1,54 +1,81 @@
-function benchmarkables_pomegranate(rng::AbstractRNG; configuration, algos)
+struct pomegranateImplem <: Implementation end
+
+function HMMBenchmark.build_model(
+    rng::AbstractRNG, implem::pomegranateImplem; instance::Instance
+)
+    np = pyimport("numpy")
     torch = pyimport("torch")
-    (; sparse, nb_states, obs_dim, seq_length, nb_seqs, bw_iter) = configuration
+    torch.set_default_dtype(torch.float64)
+    pomegranate_distributions = pyimport("pomegranate.distributions")
+    pomegranate_hmm = pyimport("pomegranate.hmm")
+
+    (; nb_states, bw_iter) = instance
+    (; init, trans, means, stds) = build_params(rng; instance)
+
+    starts = torch.tensor(np.array(init))
+    ends = torch.ones(nb_states) * 1e-10
+    edges = torch.tensor(np.array(trans))

-    # Model
-    starts = torch.ones(nb_states) / nb_states
-    edges = torch.ones(nb_states, nb_states) / nb_states
     distributions = pylist([
-        pyimport("pomegranate.distributions").Normal(;
-            means=i * torch.ones(obs_dim),
-            covs=torch.square(torch.ones(obs_dim)),
+        pomegranate_distributions.Normal(;
+            means=torch.tensor(np.array(means[:, i])),
+            covs=torch.tensor(np.array(stds[:, i] .^ 2)),
             covariance_type="diag",
         ) for i in 1:nb_states
     ])
-    hmm = pyimport("pomegranate.hmm").DenseHMM(;
+
+    hmm = pomegranate_hmm.DenseHMM(;
         distributions=distributions,
         edges=edges,
         starts=starts,
+        ends=ends,
         max_iter=bw_iter,
         tol=1e-10,
         verbose=false,
     )

-    # Data
-    obs_tens_py = torch.randn(nb_seqs, seq_length, obs_dim)
+    return hmm
+end
+
+function HMMBenchmark.build_benchmarkables(
+    rng::AbstractRNG, implem::pomegranateImplem; instance::Instance, algos::Vector{String}
+)
+    np = pyimport("numpy")
+    torch = pyimport("torch")
+    torch.set_default_dtype(torch.float64)
+    (; obs_dim, seq_length, nb_seqs) = instance
+
+    hmm = build_model(rng, implem; instance)
+    data = randn(rng, nb_seqs, seq_length, obs_dim)
+
+    obs_tens_py = torch.tensor(np.array(data))

-    # Benchmarks
     benchs = Dict()

     if "logdensity" in algos
         benchs["logdensity"] = @benchmarkable begin
-            pycall($(hmm.forward), $obs_tens_py)
-        end
+            $(hmm.forward)($obs_tens_py)
+        end evals = 1 samples = 100
     end

     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
-            pycall($(hmm.forward), $obs_tens_py)
-        end
+            $(hmm.forward)($obs_tens_py)
+        end evals = 1 samples = 100
     end

     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
-            pycall($(hmm.forward_backward), $obs_tens_py)
-        end
+            $(hmm.forward_backward)($obs_tens_py)
+        end evals = 1 samples = 100
     end

     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
-            pycall($(hmm.fit), $obs_tens_py)
-        end
+            hmm_guess.fit($obs_tens_py)
+        end evals = 1 samples = 100 setup = (
+            hmm_guess = build_model($rng, $implem; instance=$instance)
+        )
     end

     return benchs
 end
diff --git a/libs/HMMComparison/src/setup.jl b/libs/HMMComparison/src/setup.jl
new file mode 100644
index 00000000..c311e101
--- /dev/null
+++ b/libs/HMMComparison/src/setup.jl
@@ -0,0 +1,13 @@
+
+function print_python_setup(; path)
+    torch = pyimport("torch")
+    open(path, "w") do file
+        redirect_stdout(file) do
+            println("Pytorch threads = $(torch.get_num_threads())")
+            println("\n# Python packages\n")
+        end
+        redirect_stderr(file) do
+            CondaPkg.status()
+        end
+    end
+end
diff --git a/libs/HMMComparison/test/comparison.jl b/libs/HMMComparison/test/comparison.jl
index 7e279280..0e373468 100644
--- a/libs/HMMComparison/test/comparison.jl
+++ b/libs/HMMComparison/test/comparison.jl
@@ -1,20 +1,24 @@
 using BenchmarkTools
 using HMMComparison
 using HMMBenchmark
-using Random
+using StableRNGs

-rng = Random.default_rng()
-Random.seed!(rng, 63)
+rng = StableRNG(63)

-implems = ("HiddenMarkovModels.jl", "HMMBase.jl", "dynamax", "hmmlearn", "pomegranate")
-algos = ("logdensity", "baum_welch")
-configurations = [
-    Configuration(;
-        sparse=false, nb_states=4, obs_dim=1, seq_length=100, nb_seqs=100, bw_iter=10
+implems = [
+    HiddenMarkovModelsImplem(), #
+    HMMBaseImplem(), #
+    hmmlearnImplem(), #
+    pomegranateImplem(), #
+    dynamaxImplem(), #
+]
+algos = ["logdensity", "forward", "viterbi", "forward_backward", "baum_welch"]
+instances = [
+    Instance(;
+        sparse=false, nb_states=5, obs_dim=10, seq_length=100, nb_seqs=50, bw_iter=10
     ),
 ]

-SUITE = define_full_suite(rng; implems, configurations, algos)
-# BenchmarkTools.save(joinpath(@__DIR__, "tune.json"), BenchmarkTools.params(SUITE));
+SUITE = define_suite(rng, implems; instances, algos)
+
 results = BenchmarkTools.run(SUITE; verbose=true)
-data = parse_results(minimum(results); path=joinpath(@__DIR__, "results.csv"))
+data = parse_results(results)
diff --git a/libs/HMMComparison/test/plots.jl b/libs/HMMComparison/test/plots.jl
new file mode 100644
index 00000000..a218e4c8
--- /dev/null
+++ b/libs/HMMComparison/test/plots.jl
@@ -0,0 +1,11 @@
+using DataFrames
+using Plots
+
+include("comparison.jl")
+
+data
+
+algo = "forward"
+metric = :time_minimum
+data_algo = data[data[!, :algo] .== algo, :]
+bar(data_algo[!, :implem], data_algo[!, metric]; title=algo, label=string(metric))
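
The dispatch rework also makes the comparison extensible: a new backend only needs to subtype Implementation and add methods to HMMBenchmark.build_model and HMMBenchmark.build_benchmarkables, exactly as the four wrappers above do. A minimal sketch (MyLibImplem is a hypothetical backend, not part of this patch):

    using BenchmarkTools: @benchmarkable
    using HMMBenchmark: HMMBenchmark, Implementation, Instance, build_model, build_params
    using Random: AbstractRNG

    struct MyLibImplem <: Implementation end  # hypothetical backend

    function HMMBenchmark.build_model(rng::AbstractRNG, ::MyLibImplem; instance::Instance)
        (; init, trans, means, stds) = build_params(rng; instance)
        # build and return the backend's HMM object from these shared parameters
    end

    function HMMBenchmark.build_benchmarkables(
        rng::AbstractRNG, implem::MyLibImplem; instance::Instance, algos::Vector{String}
    )
        hmm = build_model(rng, implem; instance)
        benchs = Dict()
        # one @benchmarkable per requested algo, as in hmmbase.jl above
        return benchs
    end

Since Base.string on an Implementation strips the trailing "Implem", this backend would show up as "MyLib" in the suite keys and in results.csv.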