diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index e59b552c..db6dd8a4 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -8,7 +8,7 @@ on: jobs: Benchmark: runs-on: ubuntu-latest - if: contains(github.event.pull_request.labels.*.name, 'run-benchmark') + if: contains(github.event.pull_request.labels.*.name, 'run benchmark') steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@latest diff --git a/benchmark/HMMBenchmark/src/HMMBenchmark.jl b/benchmark/HMMBenchmark/src/HMMBenchmark.jl index 9be1bd64..441ef4e2 100644 --- a/benchmark/HMMBenchmark/src/HMMBenchmark.jl +++ b/benchmark/HMMBenchmark/src/HMMBenchmark.jl @@ -7,6 +7,8 @@ using Distributions: Normal, DiagNormal, PDiagMat using HiddenMarkovModels using HiddenMarkovModels: LightDiagNormal, + rand_prob_vec, + rand_trans_mat, initialize_viterbi, viterbi!, initialize_forward, diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index b9bd4788..bc06d1a8 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -4,11 +4,6 @@ julia_version = "1.9.4" manifest_format = "2.0" project_hash = "30d194898a132aa612114340054b3f9c8df83b1c" -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -37,11 +32,15 @@ git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" version = "0.5.1" -[[deps.Clustering]] -deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "7ebbd653f74504447f1c33b91cd706a69a1b189f" -uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -version = "0.14.4" +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -101,36 +100,18 @@ git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" version = "0.4.0" -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "5225c965635d8c21168e32a12954675e7bea1151" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.10" - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - - [deps.Distances.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - [[deps.Distributions]] deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] git-tree-sha1 = "a6c00f894f24460379cb7136633cef54ac9f6f4a" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" version = "0.25.103" +weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" DistributionsDensityInterfaceExt = "DensityInterface" DistributionsTestExt = "Test" - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -172,12 +153,6 @@ weakdeps = ["SparseArrays", "Statistics"] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" -[[deps.HMMBase]] -deps = ["ArgCheck", "Clustering", "Distributions", "Hungarian", "LinearAlgebra", "Random"] -git-tree-sha1 = "47d95dcc06cafd4a1c100bfad64da3ab06ad38c7" -uuid = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" -version = "1.0.7" - [[deps.HMMBenchmark]] deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SimpleUnPack", "SparseArrays"] path = "HMMBenchmark" @@ -194,27 +169,19 @@ version = "0.1.0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" [[deps.HiddenMarkovModels]] -deps = ["DensityInterface", "DocStringExtensions", "LinearAlgebra", "PrecompileTools", "Random", "Requires", "SimpleUnPack", "SparseArrays", "StatsAPI"] +deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "LinearAlgebra", "PrecompileTools", "Random", "Requires", "SimpleUnPack", "SparseArrays", "StatsAPI"] path = ".." uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" version = "0.4.0" [deps.HiddenMarkovModels.extensions] - HiddenMarkovModelsChainRulesCoreExt = "ChainRulesCore" HiddenMarkovModelsDistributionsExt = "Distributions" HiddenMarkovModelsHMMBaseExt = "HMMBase" [deps.HiddenMarkovModels.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" -[[deps.Hungarian]] -deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "371a7df7a6cce5909d6c576f234a2da2e3fa0c98" -uuid = "e91730f6-4275-51fb-a7a0-7064cfbd3b39" -version = "0.6.0" - [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -342,12 +309,6 @@ git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "1.0.2" -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "2c3726ceb3388917602169bed973dbc97f1b51a8" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.13" - [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" @@ -495,28 +456,11 @@ deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_j git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - [deps.SpecialFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.5" -weakdeps = ["Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -539,15 +483,12 @@ deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Re git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "1.3.0" +weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] StatsFunsChainRulesCoreExt = "ChainRulesCore" StatsFunsInverseFunctionsExt = "InverseFunctions" - [deps.StatsFuns.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.StringManipulation]] deps = ["PrecompileTools"] git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 8a6d8262..e8aacf36 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -4,11 +4,11 @@ using BenchmarkTools implems = ("HiddenMarkovModels.jl",) algos = ("rand", "logdensity", "viterbi", "forward_backward", "baum_welch") configurations = [] -for sparse in (false, true), nb_states in (4, 16) +for sparse in (false, true), nb_states in (4, 16, 64) push!( configurations, Configuration(; - sparse, nb_states, obs_dim=1, seq_length=1000, nb_seqs=100, bw_iter=10 + sparse, nb_states, obs_dim=1, seq_length=100, nb_seqs=100, bw_iter=1 ), ) end diff --git a/benchmark/run.jl b/benchmark/run.jl index ff408b30..b687dcd1 100644 --- a/benchmark/run.jl +++ b/benchmark/run.jl @@ -1,4 +1,4 @@ include("benchmarks.jl") -results = run(SUITE; verbose=true, samples=10) +results = run(SUITE; verbose=true, samples=5) data = parse_results(minimum(results); path=joinpath(@__DIR__, "results.csv")) diff --git a/benchmark/tune.json b/benchmark/tune.json index b57a5637..1b55dce1 100644 --- a/benchmark/tune.json +++ b/benchmark/tune.json @@ -1 +1 @@ -[{"Julia":"1.10.0-rc1","BenchmarkTools":"1.0.0"},[["BenchmarkGroup",{"data":{"HiddenMarkovModels.jl":["BenchmarkGroup",{"data":{"(1, 16, 1, 1000, 100, 10)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(0, 16, 1, 1000, 100, 10)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(1, 4, 1, 1000, 100, 10)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(0, 4, 1, 1000, 100, 10)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}]},"tags":[]}]},"tags":[]}]]] \ No newline at end of file +[{"Julia":"1.9.4","BenchmarkTools":"1.0.0"},[["BenchmarkGroup",{"data":{"HiddenMarkovModels.jl":["BenchmarkGroup",{"data":{"(1, 16, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(0, 16, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(1, 4, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(0, 4, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(1, 64, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"(0, 64, 1, 100, 100, 1)":["BenchmarkGroup",{"data":{"viterbi_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"logdensity":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"baum_welch!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"forward_backward_init":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rand":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"viterbi!":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}]},"tags":[]}]},"tags":[]}]]] \ No newline at end of file diff --git a/docs/src/api.md b/docs/src/api.md index ad92d52d..2c0f0200 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,6 +21,7 @@ eltype initialization transition_matrix obs_distributions +fit! ``` ## Inference @@ -31,20 +32,13 @@ forward viterbi forward_backward baum_welch -fit! +MultiSeq ``` -## Misc - -```@docs -rand_prob_vec -rand_trans_mat -HiddenMarkovModels.fit_element_from_sequence! -HiddenMarkovModels.LightDiagNormal -HiddenMarkovModels.LightCategorical -``` +## Internals -## In-place algorithms (internals) +These objects are not yet stabilized and may change at any time. +Do not consider them to be part of the API subject to semantic versioning. ### Storage types @@ -73,6 +67,18 @@ HiddenMarkovModels.forward_backward! HiddenMarkovModels.baum_welch! ``` +## Misc + +```@docs +HiddenMarkovModels.rand_prob_vec +HiddenMarkovModels.rand_trans_mat +HiddenMarkovModels.project_prob_vec +HiddenMarkovModels.project_trans_mat +HiddenMarkovModels.fit_element_from_sequence! +HiddenMarkovModels.LightDiagNormal +HiddenMarkovModels.LightCategorical +``` + ## Notations ### Integers diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 6fcf848d..344591ea 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -19,7 +19,6 @@ using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nnz, nonzeros, nzrange using StatsAPI: StatsAPI, fit, fit! export AbstractHMM, HMM, PermutedHMM -export rand_prob_vec, rand_trans_mat export initialization, transition_matrix, obs_distributions export logdensityof, viterbi, forward, forward_backward, baum_welch export fit! diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index cd30bcbd..404dcafa 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -15,6 +15,9 @@ struct BaumWelchStorage{O,R,M} <: AbstractBaumWelchStorage logL_evolution::Vector{R} end +""" + initialize_baum_welch(hmm, MultiSeq(obs_seqs); max_iterations) +""" function initialize_baum_welch(hmm::AbstractHMM, obs_seqs::MultiSeq; max_iterations=0) O = typeof(obs_seqs[1][1]) R = eltype(hmm, obs_seqs[1][1]) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index a8055d6e..b09d53f8 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -60,20 +60,26 @@ Return the vector of initial state probabilities for `hmm`. function initialization end """ + transition_matrix(hmm) transition_matrix(hmm, t) -Return the matrix of state transition probabilities for `hmm` at time `t`. +Return the matrix of state transition probabilities for `hmm` (at time `t`). """ -function transition_matrix end +transition_matrix(hmm::AbstractHMM, t::Integer) = transition_matrix(hmm) """ + obs_distributions(hmm) obs_distributions(hmm, t) -Return a vector of observation distributions, one for each state of `hmm` at time `t`. +Return a vector of observation distributions, one for each state of `hmm` (at time `t`). -There objects should support `rand(rng, dist)` and `DensityInterface.logdensityof(dist, obs)`. +There objects should support + +- `rand(rng, dist)` +- `DensityInterface.logdensityof(dist, obs)` +- `StatsAPI.fit!(dist, obs_seq, weight_seq)` """ -function obs_distributions end +obs_distributions(hmm::AbstractHMM, t::Integer) = obs_distributions(hmm) function obs_logdensities!(logb::AbstractVector, hmm::AbstractHMM, t::Integer, obs) dists = obs_distributions(hmm, t) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 00a3a78a..68ebe7b7 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -22,8 +22,8 @@ end Base.length(hmm::HMM) = length(hmm.init) initialization(hmm::HMM) = hmm.init -transition_matrix(hmm::HMM, ::Integer) = hmm.trans -obs_distributions(hmm::HMM, ::Integer) = hmm.dists +transition_matrix(hmm::HMM) = hmm.trans +obs_distributions(hmm::HMM) = hmm.dists ## Fitting diff --git a/test/autodiff.jl b/test/autodiff.jl index 93514200..8ab7cd99 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -2,6 +2,7 @@ using Distributions using FiniteDifferences: FiniteDifferences, central_fdm using ForwardDiff: ForwardDiff using HiddenMarkovModels +using HiddenMarkovModels: rand_prob_vec, rand_trans_mat using SimpleUnPack using Test using Zygote: Zygote diff --git a/test/types_allocations.jl b/test/types_allocations.jl index db28526e..b467f845 100644 --- a/test/types_allocations.jl +++ b/test/types_allocations.jl @@ -1,6 +1,6 @@ using Distributions using HiddenMarkovModels -using HiddenMarkovModels: LightDiagNormal, LightCategorical +using HiddenMarkovModels: LightDiagNormal, LightCategorical, rand_prob_vec, rand_trans_mat import HiddenMarkovModels as HMMs using JET using LinearAlgebra