Skip to content

Commit

Permalink
Clean up benchmarks (gdalle#77)
Browse files Browse the repository at this point in the history
* Rephrase benchmarks with dispatch

* Dynamax benchmarks

* Update Manifest
  • Loading branch information
gdalle authored Feb 2, 2024
1 parent e71a056 commit c7134e1
Show file tree
Hide file tree
Showing 21 changed files with 448 additions and 233 deletions.
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ scratchpad.jl
**/.CondaPkg
**/__pycache__
*.ipynb
*.ipynb_checkpoints

*.csv
*.txt
.benchmarkci/
/benchmark/*.json

/docs/src/index.md
/docs/src/examples/*.md

.vscode/
*.ipynb
/docs/src/examples/*.md
52 changes: 29 additions & 23 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "a1b4318401476bf26277ef3565ca4043fa58d314"
project_hash = "e05ed926575e94b72904ad898b09f017dc14d96a"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand All @@ -22,9 +22,9 @@ version = "1.4.0"

[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
git-tree-sha1 = "44dbf560808d49041989b8a96cae4cffbeb7966a"
git-tree-sha1 = "679e69c611fff422038e9e21e270c4197d49d918"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.10.11"
version = "0.10.12"

[[deps.Calculus]]
deps = ["LinearAlgebra"]
Expand All @@ -34,25 +34,25 @@ version = "0.5.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "2118cb2765f8197b08e5958cdd17c165427425ee"
git-tree-sha1 = "1287e3872d646eed95198457873249bd9f0caed2"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.19.0"
version = "1.20.1"
weakdeps = ["SparseArrays"]

[deps.ChainRulesCore.extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "cd67fc487743b2f0fd4380d4cbd3a24660d0eec8"
git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.3"
version = "0.7.4"

[[deps.Compat]]
deps = ["UUIDs"]
git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d"
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.10.1"
version = "4.12.0"
weakdeps = ["Dates", "LinearAlgebra"]

[deps.Compat.extensions]
Expand All @@ -69,9 +69,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.1.1"

[[deps.DataAPI]]
git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.15.0"
version = "1.16.0"

[[deps.DataFrames]]
deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
Expand All @@ -81,9 +81,9 @@ version = "1.6.1"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.15"
version = "0.18.16"

[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand All @@ -102,9 +102,9 @@ version = "0.4.0"

[[deps.Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "9242eec9b7e2e14f9952e8ea1c7e31a50501d587"
git-tree-sha1 = "7c302d7a5fec5214eb8a5a4c466dcf7a51fcf169"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.104"
version = "0.25.107"
weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]

[deps.Distributions.extensions]
Expand Down Expand Up @@ -155,7 +155,7 @@ deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.HMMBenchmark]]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays"]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
path = "../libs/HMMBenchmark"
uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
version = "0.1.0"
Expand Down Expand Up @@ -378,9 +378,9 @@ uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "9ebcd48c498668c7fa0e97a9cae873fbee7bfee1"
git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.9.1"
version = "2.9.4"

[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
Expand Down Expand Up @@ -425,9 +425,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "5165dfb9fd131cf0c6957a3a7605dede376e7b63"
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.2.0"
version = "1.2.1"

[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
Expand All @@ -444,6 +444,12 @@ weakdeps = ["ChainRulesCore"]
[deps.SpecialFunctions.extensions]
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"

[[deps.StableRNGs]]
deps = ["Random", "Test"]
git-tree-sha1 = "ddc1a7b85e760b5285b50b882fa91e40c603be47"
uuid = "860ef19b-820b-49d6-a774-d7a799459cd3"
version = "1.0.1"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -514,9 +520,9 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.TranscodingStreams]]
git-tree-sha1 = "1fbeaaca45801b4ba17c251dd8603ef24801dd84"
git-tree-sha1 = "54194d92959d8ebaa8e26227dbe3cdefcdcd594f"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.10.2"
version = "0.10.3"
weakdeps = ["Random", "Test"]

[deps.TranscodingStreams.extensions]
Expand Down
1 change: 1 addition & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ HMMBenchmark = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
26 changes: 13 additions & 13 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using BenchmarkTools
using HMMBenchmark
using Random
using StableRNGs

rng = Random.default_rng()
Random.seed!(rng, 63)
rng = StableRNG(63)

algos = ("rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch")
configurations = [
algos = ["rand", "logdensity", "forward", "viterbi", "forward_backward", "baum_welch"]
instances = [
# compare state numbers
Configuration(; nb_states=4, obs_dim=1),
Configuration(; nb_states=8, obs_dim=1),
Configuration(; nb_states=16, obs_dim=1),
Configuration(; nb_states=32, obs_dim=1),
Instance(; nb_states=4, obs_dim=1),
Instance(; nb_states=8, obs_dim=1),
Instance(; nb_states=16, obs_dim=1),
Instance(; nb_states=32, obs_dim=1),
# compare sparse
Configuration(; nb_states=64, obs_dim=1, sparse=false),
Configuration(; nb_states=64, obs_dim=1, sparse=true),
Instance(; nb_states=64, obs_dim=1, sparse=false),
Instance(; nb_states=64, obs_dim=1, sparse=true),
# compare dists
Configuration(; nb_states=4, obs_dim=10, custom_dist=true),
Configuration(; nb_states=4, obs_dim=10, custom_dist=false),
Instance(; nb_states=4, obs_dim=10, custom_dist=true),
Instance(; nb_states=4, obs_dim=10, custom_dist=false),
]

SUITE = define_suite(rng; configurations, algos)
SUITE = define_suite(rng; instances, algos)
2 changes: 1 addition & 1 deletion benchmark/run.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include("benchmarks.jl")

results = run(SUITE; verbose=true)
data = parse_results(minimum(results); path=joinpath(@__DIR__, "results.csv"))
data = parse_results(results; path=joinpath(@__DIR__, "results.csv"))
2 changes: 2 additions & 0 deletions libs/HMMBenchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
julia = "1.9"
15 changes: 11 additions & 4 deletions libs/HMMBenchmark/src/HMMBenchmark.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module HMMBenchmark

using Base.Threads
using BenchmarkTools: @benchmarkable, BenchmarkGroup
using CSV: CSV
using DataFrames: DataFrame
Expand All @@ -16,15 +17,21 @@ using HiddenMarkovModels:
initialize_forward_backward,
forward_backward!,
baum_welch!
using LinearAlgebra: Diagonal, SymTridiagonal
using LinearAlgebra: BLAS, Diagonal, SymTridiagonal
using Pkg: Pkg
using Random: AbstractRNG
using SparseArrays: spdiagm
using Statistics: mean, median, std

export Configuration, define_suite, parse_results, print_julia_setup
# Public API. The abstract implementation type declared in this module is named
# `Implementation`, so that is the name exported (there is no `AbstractImplementation`).
export Implementation, Instance
export define_suite, parse_results, print_julia_setup

include("configuration.jl")
include("algos.jl")
# Supertype of the tag types that identify which HMM implementation is benchmarked.
abstract type Implementation end

# Short name of an implementation: the printed name of its type with the last
# `length("Implem")` characters dropped, so a type named `FooImplem` yields "Foo".
function Base.string(implem::Implementation)
    type_name = string(typeof(implem))
    suffix_length = length("Implem")
    return type_name[begin:(end - suffix_length)]
end

include("instance.jl")
include("params.jl")
include("hiddenmarkovmodels.jl")
include("suite.jl")

end
Original file line number Diff line number Diff line change
@@ -1,44 +1,48 @@
function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algos)
(; sparse, custom_dist, nb_states, obs_dim, seq_length, nb_seqs, bw_iter) =
configuration
# Field-less tag type for the HiddenMarkovModels.jl implementation; `build_model` and
# `build_benchmarkables` dispatch on it.
struct HiddenMarkovModelsImplem <: Implementation
end

# Model
init = ones(nb_states) / nb_states
if sparse
trans = spdiagm(
0 => ones(nb_states) / 2,
+1 => ones(nb_states - 1) / 2,
-(nb_states - 1) => ones(1) / 2,
)
else
trans = ones(nb_states, nb_states) / nb_states
end
# Build the HiddenMarkovModels.jl HMM for `instance`, using the parameters returned by
# `build_params(rng; instance)`.
#
# Emission distributions, one per state:
# - `custom_dist == true`: `LightDiagNormal`
# - `custom_dist == false` and `obs_dim == 1`: `Normal`
# - `custom_dist == false` and `obs_dim > 1`: `MvNormal` with diagonal covariance
function build_model(rng::AbstractRNG, ::HiddenMarkovModelsImplem; instance::Instance)
    (; custom_dist, nb_states, obs_dim) = instance
    (; init, trans, means, stds) = build_params(rng; instance)

    # Column `i` of `means` / `stds` holds the parameters of state `i`.
    dists = if custom_dist
        [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
    elseif obs_dim == 1
        [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
    else
        # `MvNormal` takes a covariance matrix, so the standard deviations are squared to
        # describe the same distribution as the two branches above.
        [MvNormal(means[:, i], Diagonal(abs2.(stds[:, i]))) for i in 1:nb_states]
    end

    return HiddenMarkovModels.HMM(init, trans, dists)
end

function build_benchmarkables(
rng::AbstractRNG,
implem::HiddenMarkovModelsImplem;
instance::Instance,
algos::Vector{String},
)
(; obs_dim, seq_length, nb_seqs, bw_iter) = instance

# Data
obs_seqs = [rand(rng, hmm, seq_length).obs_seq for _ in 1:nb_seqs]
hmm = build_model(rng, implem; instance)
data = randn(rng, nb_seqs, seq_length, obs_dim)

if obs_dim == 1
obs_seqs = [[data[k, t, 1] for t in 1:seq_length] for k in 1:nb_seqs]
else
obs_seqs = [[data[k, t, :] for t in 1:seq_length] for k in 1:nb_seqs]
end
obs_seq = reduce(vcat, obs_seqs)
control_seq = fill(nothing, length(obs_seq))
seq_ends = cumsum(length.(obs_seqs))

# Benchmarks
benchs = Dict()

if "rand" in algos
benchs["rand"] = @benchmarkable begin
[rand($hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
[rand($rng, $hmm, $seq_length).obs_seq for _ in 1:($nb_seqs)]
end evals = 1 samples = 100
end

Expand Down Expand Up @@ -99,7 +103,7 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo
baum_welch!(
fb_storage,
logL_evolution,
$hmm,
hmm_guess,
$obs_seq,
$control_seq;
seq_ends=$seq_ends,
Expand All @@ -108,8 +112,9 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo
loglikelihood_increasing=false,
)
end evals = 1 samples = 100 setup = (
hmm_guess = build_model($rng, $implem; instance=$instance);
fb_storage = initialize_forward_backward(
$hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
);
logL_evolution = Float64[];
sizehint!(logL_evolution, $bw_iter)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Base.@kwdef struct Configuration
Base.@kwdef struct Instance
nb_states::Int
obs_dim::Int
seq_length::Int = 100
Expand All @@ -8,15 +8,11 @@ Base.@kwdef struct Configuration
custom_dist::Bool = false
end

function Base.string(c::Configuration)
# Serialize an instance as space-separated "name value" pairs in field order, with every
# field value converted to `Int` (for example "nb_states 4 obs_dim 1 ...").
function Base.string(c::Instance)
    # `join` inserts the separators itself, so no trailing space has to be trimmed.
    return join(("$n $(Int(getfield(c, n)))" for n in fieldnames(typeof(c))), " ")
end

function Configuration(s::String)
function Instance(s::String)
vals = parse.(Int, split(s, " ")[2:2:end])
return Configuration(vals...)
end

function to_namedtuple(c::Configuration)
return NamedTuple(n => getfield(c, n) for n in fieldnames(typeof(c)))
return Instance(vals...)
end
Loading

0 comments on commit c7134e1

Please sign in to comment.