Speed up viterbi (gdalle#93)
gdalle authored Feb 24, 2024
1 parent 5ad5d81 commit 2f975b6
Showing 26 changed files with 258 additions and 110 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/benchmark.yml
@@ -8,6 +8,9 @@ on:
 jobs:
   Benchmark:
     runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      pull-requests: write
     if: contains(github.event.pull_request.labels.*.name, 'run benchmark')
     steps:
       - uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "HiddenMarkovModels"
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 authors = ["Guillaume Dalle"]
-version = "0.4.1"
+version = "0.5.0"
 
 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
4 changes: 2 additions & 2 deletions README.md
@@ -24,8 +24,8 @@ Then, you can create your first model as follows:
 
 ```julia
 using Distributions, HiddenMarkovModels
-init = [0.4, 0.6]
-trans = [0.9 0.1; 0.2 0.8]
+init = [0.6, 0.4]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [Normal(-1.0), Normal(1.0)]
 hmm = HMM(init, trans, dists)
 ```
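Editorial aside, not part of the commit: the README snippet stops at construction. A minimal usage sketch, reusing calls that appear verbatim in `examples/types.jl` later in this diff:

```julia
state_seq, obs_seq = rand(hmm, 100)  # simulate a length-100 trajectory
logdensityof(hmm, obs_seq)           # loglikelihood of the observations
```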
2 changes: 1 addition & 1 deletion benchmark/Manifest.toml
@@ -169,7 +169,7 @@ version = "0.1.0"
 deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
 path = ".."
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
-version = "0.4.0"
+version = "0.5.0"
 weakdeps = ["Distributions"]
 
 [deps.HiddenMarkovModels.extensions]
14 changes: 12 additions & 2 deletions docs/src/api.md
@@ -94,14 +94,24 @@ HiddenMarkovModels.forward_backward!
 HiddenMarkovModels.baum_welch!
 ```
 
-## Misc
+## Miscellaneous
 
 ```@docs
+HiddenMarkovModels.valid_hmm
 HiddenMarkovModels.rand_prob_vec
 HiddenMarkovModels.rand_trans_mat
-HiddenMarkovModels.fit_in_sequence!
 ```
+
+## Internals
+
+```@docs
+HiddenMarkovModels.LightDiagNormal
+HiddenMarkovModels.LightCategorical
+HiddenMarkovModels.fit_in_sequence!
+HiddenMarkovModels.log_initialization
+HiddenMarkovModels.log_transition_matrix
+HiddenMarkovModels.mul_rows_cols!
+HiddenMarkovModels.argmaxplus_transmul!
+```
 
 ## Index
12 changes: 6 additions & 6 deletions examples/basics.jl
@@ -19,16 +19,16 @@ rng = StableRNG(63);
 # ## Model
 
 #=
-The package provides a versatile [`HMM`](@ref) type with three attributes:
-- a vector of state initialization probabilities
-- a matrix of state transition probabilities
-- a vector of observation distributions, one for each state
+The package provides a versatile [`HMM`](@ref) type with three main attributes:
+- a vector `init` of state initialization probabilities
+- a matrix `trans` of state transition probabilities
+- a vector `dists` of observation distributions, one for each state
 Any scalar- or vector-valued distribution from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) can be used for the last part, as well as [Custom distributions](@ref).
 =#
 
 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)]
 hmm = HMM(init, trans, dists)
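Editorial aside: the renamed attributes line up with the package's accessor functions, both of which appear elsewhere in this commit:

```julia
initialization(hmm)     # the vector `init`
transition_matrix(hmm)  # the matrix `trans`
obs_distributions(hmm)  # the vector `dists`
```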

@@ -142,7 +142,7 @@ Since it is a local optimization procedure, it requires a starting point that is
 =#
 
 init_guess = [0.5, 0.5]
-trans_guess = [0.6 0.4; 0.4 0.6]
+trans_guess = [0.6 0.4; 0.3 0.7]
 dists_guess = [MvNormal([-0.4, -0.7], I), MvNormal([0.4, 0.7], I)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);
17 changes: 9 additions & 8 deletions examples/controlled.jl
@@ -32,14 +32,15 @@ struct ControlledGaussianHMM{T} <: AbstractHMM
 end
 
 #=
-In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
+In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
+Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one.
 =#
 
 function HMMs.initialization(hmm::ControlledGaussianHMM)
     return hmm.init
 end
 
-function HMMs.transition_matrix(hmm::ControlledGaussianHMM)
+function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
     return hmm.trans
 end
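Editorial aside: the matching `obs_distributions` method is elided from this diff. A minimal sketch consistent with the linear model stated above, assuming the `dist_coeffs` field name from the constructor call in the next hunk. Note also that since `transition_matrix` now takes a control argument, a later hunk switches the comparison from `transition_matrix(hmm_est)` to the raw `trans` fields.

```julia
using Distributions: Normal
using LinearAlgebra: dot

function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
    # One Normal per state i: y ~ N(dot(dist_coeffs[i], control), 1)
    return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in eachindex(hmm.dist_coeffs)]
end
```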

@@ -54,8 +55,8 @@ In this case, the transition matrix does not depend on the control.
 # ## Simulation
 
 d = 3
-init = [0.8, 0.2]
-trans = [0.7 0.3; 0.3 0.7]
+init = [0.6, 0.4]
+trans = [0.7 0.3; 0.2 0.8]
 dist_coeffs = [-ones(d), ones(d)]
 hmm = ControlledGaussianHMM(init, trans, dist_coeffs);

@@ -122,9 +123,9 @@ end
 Now we put it to the test.
 =#
 
-init_guess = [0.7, 0.3]
-trans_guess = [0.6 0.4; 0.4 0.6]
-dist_coeffs_guess = [-0.7 * ones(d), 0.7 * ones(d)]
+init_guess = [0.5, 0.5]
+trans_guess = [0.6 0.4; 0.3 0.7]
+dist_coeffs_guess = [-1.1 * ones(d), 1.1 * ones(d)]
 hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);
 
 #-
@@ -136,7 +137,7 @@ first(loglikelihood_evolution), last(loglikelihood_evolution)
 How did we perform?
 =#
 
-cat(transition_matrix(hmm_est), transition_matrix(hmm); dims=3)
+cat(hmm_est.trans, hmm.trans; dims=3)
 
 #-
6 changes: 3 additions & 3 deletions examples/interfaces.jl
@@ -82,7 +82,7 @@ Let's put it to the test.
 =#
 
 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [StuffDist(-1.0), StuffDist(+1.0)]
 hmm = HMM(init, trans, dists);

@@ -104,8 +104,8 @@ If we implement `fit!`, Baum-Welch also works seamlessly.
 =#
 
 init_guess = [0.5, 0.5]
-trans_guess = [0.6 0.4; 0.4 0.6]
-dists_guess = [StuffDist(-0.7), StuffDist(+0.7)]
+trans_guess = [0.6 0.4; 0.3 0.7]
+dists_guess = [StuffDist(-1.1), StuffDist(+1.1)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);
 
 #-
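Editorial aside: the `fit!` method named in the hunk header is elided from this diff, as is the definition of `StuffDist`. A minimal sketch of what such a method can look like, assuming (hypothetically) that `StuffDist` is a mutable wrapper around one location parameter and that observations are plain numbers:

```julia
using StatsAPI

mutable struct StuffDist  # hypothetical layout; the real definition is elided above
    loc::Float64
end

function StatsAPI.fit!(
    dist::StuffDist, obs_seq::AbstractVector, weight_seq::AbstractVector{<:Real}
)
    # Baum-Welch re-estimation: weighted mean of the observations.
    dist.loc = sum(w * x for (w, x) in zip(weight_seq, obs_seq)) / sum(weight_seq)
    return nothing
end
```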
6 changes: 3 additions & 3 deletions examples/temporal.jl
@@ -55,7 +55,7 @@ end
 # ## Simulation
 
 init = [0.6, 0.4]
-trans_per = ([0.7 0.3; 0.3 0.7], [0.3 0.7; 0.7 0.3])
+trans_per = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
 dists_per = ([Normal(-1.0), Normal(-2.0)], [Normal(+1.0), Normal(+2.0)])
 hmm = PeriodicHMM(init, trans_per, dists_per);

@@ -152,8 +152,8 @@ Now let's test our procedure with a reasonable guess.
 =#
 
 init_guess = [0.7, 0.3]
-trans_per_guess = ([0.6 0.4; 0.4 0.6], [0.4 0.6; 0.6 0.4])
-dists_per_guess = ([Normal(-0.7), Normal(-1.7)], [Normal(+0.7), Normal(+1.7)])
+trans_per_guess = ([0.6 0.4; 0.3 0.7], [0.4 0.6; 0.7 0.3])
+dists_per_guess = ([Normal(-1.1), Normal(-2.1)], [Normal(+1.1), Normal(+2.1)])
 hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess);
 
 #=
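Editorial aside: the `PeriodicHMM` type and its time-dependent methods are elided from this diff (only a closing `end` shows in the first hunk header). A plausible sketch of the period-2 dispatch implied by the two-element tuples above, with field names assumed from the constructor call:

```julia
function HMMs.transition_matrix(hmm::PeriodicHMM, t::Integer)
    return hmm.trans_per[(t - 1) % 2 + 1]  # alternate between the two matrices
end

function HMMs.obs_distributions(hmm::PeriodicHMM, t::Integer)
    return hmm.dists_per[(t - 1) % 2 + 1]  # alternate between the two vectors
end
```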
15 changes: 11 additions & 4 deletions examples/types.jl
@@ -6,6 +6,7 @@ Here we explain why playing with different number and array types can be useful
 
 using Distributions
 using HiddenMarkovModels
+using HiddenMarkovModels: log_transition_matrix #src
 using HMMTest #src
 using LinearAlgebra
 using LogarithmicNumbers
@@ -40,7 +41,7 @@ To give an example, let us first generate some data from a vanilla HMM.
 =#
 
 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [Normal(-1.0), Normal(1.0)]
 hmm = HMM(init, trans, dists)
 state_seq, obs_seq = rand(rng, hmm, 100);
@@ -57,6 +58,10 @@ hmm_uncertain = HMM(init, trans, dists_guess)
 Every quantity we compute with this new HMM will have propagated uncertainties around it.
 =#
 
+logdensityof(hmm, obs_seq)
+
+#-
+
 logdensityof(hmm_uncertain, obs_seq)
 
 #=
@@ -129,7 +134,7 @@ trans_guess = sparse([
     0 0.6 0.4
     0.4 0 0.6
 ])
-dists_guess = [Normal(1.2), Normal(2.2), Normal(3.2)]
+dists_guess = [Normal(1.1), Normal(2.1), Normal(3.1)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);
 
 #-
@@ -149,10 +154,12 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St
 
 # ## Tests #src
 
+@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src
+
 seq_ends = cumsum(rand(rng, 100:200, 100)); #src
 control_seqs = fill(nothing, length(seq_ends)); #src
 control_seq = fill(nothing, last(seq_ends)); #src
 test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
-test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
+test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src
 test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
 # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
 @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
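Editorial aside: the hunk header above cuts off mid-link to StaticArrays.jl. A minimal sketch of a statically sized model, assuming the `HMM` constructor accepts generic array types (as it does with the sparse matrix above):

```julia
using Distributions, HiddenMarkovModels, StaticArrays

init = @SVector [0.6, 0.4]
trans = @SMatrix [0.7 0.3; 0.2 0.8]
dists = @SVector [Normal(-1.0), Normal(1.0)]
hmm = HMM(init, trans, dists)  # parameters are fixed-size, stack-allocated arrays
```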
16 changes: 8 additions & 8 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
@@ -43,38 +43,38 @@ function build_benchmarkables(
     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
             forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "forward!" in algos
         benchs["forward!"] = @benchmarkable begin
             forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
         )
     end
 
     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
             viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "viterbi!" in algos
         benchs["viterbi!"] = @benchmarkable begin
             viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
         )
     end
 
     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
             forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "forward_backward!" in algos
         benchs["forward_backward!"] = @benchmarkable begin
             forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             fb_storage = initialize_forward_backward(
                 $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
             )
@@ -92,7 +92,7 @@ function build_benchmarkables(
                 atol=-Inf,
                 loglikelihood_increasing=false,
             )
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "baum_welch!" in algos
         benchs["baum_welch!"] = @benchmarkable begin
@@ -107,7 +107,7 @@ function build_benchmarkables(
                 atol=-Inf,
                 loglikelihood_increasing=false,
             )
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
            hmm_guess = build_model($implem, $instance, $params);
            fb_storage = initialize_forward_backward(
                hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
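Editorial aside: the repeated `samples = 10` → `samples = 100` edits in this file and the two below raise the number of measurements BenchmarkTools may collect per benchmark (with `evals = 1`, each sample is a single evaluation). A self-contained illustration of those two knobs, not taken from the diff:

```julia
using BenchmarkTools

# Up to 100 samples, one evaluation each, subject to the default time budget.
b = @benchmarkable sum(x) setup = (x = rand(1000)) evals = 1 samples = 100
t = run(b)
minimum(t)  # report the fastest sample, the usual low-noise estimate
```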
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/dynamax.jl
@@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables(
         filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
         benchs["forward"] = @benchmarkable begin
             $(filter_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "viterbi" in algos
@@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables(
         )
         benchs["viterbi"] = @benchmarkable begin
             $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "forward_backward" in algos
@@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables(
         )
         benchs["forward_backward"] = @benchmarkable begin
             $(smoother_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "baum_welch" in algos
@@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables(
                 num_iters=$bw_iter,
                 verbose=false,
             )
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
            tup = build_model($implem, $instance, $params);
            hmm_guess = tup[1];
            dyn_params_guess = tup[2];
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmbase.jl
@@ -41,29 +41,29 @@ function HMMBenchmark.build_benchmarkables(
             @threads for k in eachindex($obs_mats)
                 HMMBase.forward($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
             @threads for k in eachindex($obs_mats)
                 HMMBase.viterbi($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
             @threads for k in eachindex($obs_mats)
                 HMMBase.posteriors($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
             HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
 
     return benchs