From 36b3313c97db601c191084d3fb188f7748a0de17 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 12 Jan 2024 10:59:39 +0100 Subject: [PATCH] Remove deps and fix Enzyme (#73) * Remove deps and fix Enzyme * Skip allocation test on sparse matrices * Move control_seq to positional argument --- .gitignore | 3 +- Project.toml | 5 --- benchmark/Manifest.toml | 46 ++++++++++++------------ docs/Project.toml | 1 - docs/src/api.md | 1 + examples/autodiff.jl | 30 ++++++++-------- examples/basics.jl | 6 ++-- examples/controlled.jl | 11 +++--- examples/interfaces.jl | 4 +-- examples/temporal.jl | 11 +++--- examples/types.jl | 3 +- ext/HiddenMarkovModelsSparseArraysExt.jl | 22 ------------ libs/HMMBenchmark/src/algos.jl | 40 ++++++++------------- libs/HMMTest/src/allocations.jl | 26 +++++++------- libs/HMMTest/src/coherence.jl | 14 ++++---- libs/HMMTest/src/jet.jl | 20 +++++------ src/HiddenMarkovModels.jl | 8 +++-- src/inference/baum_welch.jl | 18 +++++----- src/inference/chainrules.jl | 20 +++++------ src/inference/forward.jl | 24 ++++++------- src/inference/forward_backward.jl | 26 +++++++------- src/inference/logdensity.jl | 12 +++---- src/inference/viterbi.jl | 24 ++++++------- src/types/hmm.jl | 6 ++-- src/utils/linalg.jl | 16 +++++++++ test/Project.toml | 1 - test/correctness.jl | 1 - 27 files changed, 190 insertions(+), 209 deletions(-) delete mode 100644 ext/HiddenMarkovModelsSparseArraysExt.jl diff --git a/.gitignore b/.gitignore index 51b491e1..9624171c 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ scratchpad.jl /docs/src/index.md /docs/src/examples/*.md -.vscode/ \ No newline at end of file +.vscode/ +*.ipynb \ No newline at end of file diff --git a/Project.toml b/Project.toml index 00efafdc..7689d2bf 100644 --- a/Project.toml +++ b/Project.toml @@ -6,23 +6,19 @@ version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] HiddenMarkovModelsDistributionsExt = "Distributions" -HiddenMarkovModelsSparseArraysExt = "SparseArrays" [compat] ChainRulesCore = "1.16" @@ -33,7 +29,6 @@ FillArrays = "1" LinearAlgebra = "1" PrecompileTools = "1.1" Random = "1" -SimpleUnPack = "1.1" SparseArrays = "1" StatsAPI = "1.6" julia = "1.9" diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 9d476843..d17f442a 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.4" +julia_version = "1.10.0" manifest_format = "2.0" project_hash = "a1b4318401476bf26277ef3565ca4043fa58d314" @@ -61,7 +61,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.0.5+1" [[deps.Crayons]] git-tree-sha1 = 
"249fe38abf76d48563e2f4556bebd215aa317e15" @@ -161,15 +161,14 @@ uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" version = "0.1.0" [[deps.HiddenMarkovModels]] -deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SimpleUnPack", "StatsAPI"] +deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI"] path = ".." uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" version = "0.4.0" -weakdeps = ["Distributions", "SparseArrays"] +weakdeps = ["Distributions"] [deps.HiddenMarkovModels.extensions] HiddenMarkovModelsDistributionsExt = "Distributions" - HiddenMarkovModelsSparseArraysExt = "SparseArrays" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] @@ -236,9 +235,14 @@ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" @@ -277,7 +281,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.2+1" [[deps.Missings]] deps = ["DataAPI"] @@ -290,7 +294,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2023.1.10" [[deps.NaNMath]] deps = ["OpenLibm_jll"] @@ -305,12 +309,12 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.23+2" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" +version = "0.8.1+2" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -338,7 +342,7 @@ version = "2.8.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.10.0" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -383,7 +387,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.Reexport]] @@ -416,11 +420,6 @@ version = "1.4.1" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[deps.SimpleUnPack]] -git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437" -uuid = "ce78b400-467f-4804-87d8-8f486da07d0a" -version = "1.1.0" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -433,6 +432,7 @@ version = "1.2.0" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" [[deps.SpecialFunctions]] deps = 
["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] @@ -447,7 +447,7 @@ weakdeps = ["ChainRulesCore"] [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" +version = "1.10.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -483,9 +483,9 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" +version = "7.2.1+1" [[deps.TOML]] deps = ["Dates"] @@ -543,12 +543,12 @@ version = "1.6.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.8.0+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] @@ -558,4 +558,4 @@ version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/docs/Project.toml b/docs/Project.toml index ad75722e..73bd16cd 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,7 +11,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" diff --git a/docs/src/api.md b/docs/src/api.md index 8720fa3c..4e93cef2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -45,6 +45,7 @@ seq_limits ```@docs logdensityof +joint_logdensityof forward viterbi forward_backward diff --git a/examples/autodiff.jl b/examples/autodiff.jl index aabd0690..d652c667 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -72,36 +72,36 @@ grad_z = Zygote.gradient(f, params)[1] grad_f ≈ grad_z #= -For increased efficiency, one can use Enzyme.jl and provide temporary storage. -This requires going one level deeper, to the mutating [`HiddenMarkovModels.forward!`](@ref) function. +Enzyme.jl also works natively but we have to avoid the type instability of global variables by providing more information. 
=# -control_seq = fill(nothing, length(obs_seq)) - -function f!(storage::HMMs.ForwardStorage, params::ComponentVector) +function f_extended(params::ComponentVector, obs_seq, seq_ends) new_hmm = HMM(params.init, params.trans, Normal.(params.means)) - HMMs.forward!(storage, new_hmm, obs_seq; control_seq, seq_ends) - return sum(storage.logL) -end + return logdensityof(new_hmm, obs_seq; seq_ends) +end; -storage = HMMs.initialize_forward(hmm, obs_seq; control_seq, seq_ends) -storage_shadow = HMMs.initialize_forward(hmm, obs_seq; control_seq, seq_ends) -params_shadow = zero(params) +shadow_params = Enzyme.make_zero(params) Enzyme.autodiff( Enzyme.Reverse, - f!, + f_extended, Enzyme.Active, - Enzyme.Duplicated(storage, storage_shadow), - Enzyme.Duplicated(params, params_shadow), + Enzyme.Duplicated(params, shadow_params), + Enzyme.Const(obs_seq), + Enzyme.Const(seq_ends), ) -grad_e = params_shadow +grad_e = shadow_params #- grad_e ≈ grad_f +#= +For increased efficiency, one can provide temporary storage to Enzyme.jl in order to avoid allocations. +This requires going one level deeper, by leveraging the in-place [`HiddenMarkovModels.forward!`](@ref) function. +=# + # ## Gradient methods #= diff --git a/examples/basics.jl b/examples/basics.jl index 112a1d86..6af235ee 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -99,16 +99,16 @@ Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward a logdensityof(hmm, obs_seq) #= -The same function can also compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ that take the states into account. +Another function can compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ which take the states into account. =# -logdensityof(hmm, obs_seq, state_seq) +joint_logdensityof(hmm, obs_seq, state_seq) #= For instance, we can check that the output of Viterbi is at least as likely as the true state sequence. =# -logdensityof(hmm, obs_seq, best_state_seq) +joint_logdensityof(hmm, obs_seq, best_state_seq) # ## Learning diff --git a/examples/controlled.jl b/examples/controlled.jl index 53d1fae7..a547f483 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -10,7 +10,6 @@ import HiddenMarkovModels as HMMs using HMMTest #src using LinearAlgebra using Random -using SimpleUnPack using StatsAPI using Test #src @@ -79,7 +78,7 @@ seq_ends = cumsum(length.(obs_seqs)); Not much changes from the case with simple time dependency. 
=# -best_state_seq, _ = viterbi(hmm, obs_seq; control_seq, seq_ends) +best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends) # ## Learning @@ -92,11 +91,11 @@ Meanwhile, the observation coefficients are given by the formula for [weighted l function StatsAPI.fit!( hmm::ControlledGaussianHMM{T}, fb_storage::HMMs.ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) where {T} - @unpack γ, ξ = fb_storage + (; γ, ξ) = fb_storage N = length(hmm) hmm.init .= 0 @@ -130,7 +129,7 @@ hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess); #- -hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq; control_seq, seq_ends) +hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) first(loglikelihood_evolution), last(loglikelihood_evolution) #= diff --git a/examples/interfaces.jl b/examples/interfaces.jl index f292fe8d..e20889b7 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -179,8 +179,8 @@ We will make use of the fields `fb_storage.γ` and `fb_storage.ξ`, which contai function StatsAPI.fit!( hmm::PriorHMM, fb_storage::HiddenMarkovModels.ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) ## initialize to defaults without observations diff --git a/examples/temporal.jl b/examples/temporal.jl index c40b232f..de46a0d9 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -10,7 +10,6 @@ using HiddenMarkovModels import HiddenMarkovModels as HMMs using HMMTest #src using Random -using SimpleUnPack using StatsAPI using Test #src @@ -88,7 +87,7 @@ seq_ends = cumsum(length.(obs_seqs)); All three inference algorithms work in the same way, except that we need to provide the control sequence as a keyword argument. =# -best_state_seq, _ = viterbi(hmm, obs_seq; control_seq, seq_ends) +best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends) #= For Viterbi, unsurprisingly, the most likely state sequence aligns with the sign of the observations. @@ -106,11 +105,11 @@ The key is to split the observations according to which periodic parameter they function StatsAPI.fit!( hmm::PeriodicHMM{T}, fb_storage::HMMs.ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) where {T} - @unpack γ, ξ = fb_storage + (; γ, ξ) = fb_storage L, N = period(hmm), length(hmm) hmm.init .= zero(T) @@ -159,7 +158,7 @@ hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess); Naturally, Baum-Welch also requires knowing `control_seq`. 
=# -hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq; control_seq, seq_ends); +hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends); first(loglikelihood_evolution), last(loglikelihood_evolution) #= diff --git a/examples/types.jl b/examples/types.jl index 7dd8865b..e001ae3a 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -97,4 +97,5 @@ seq_ends = cumsum(length.(control_seqs)); #src test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src -test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src +# https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src +@test_skip test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src diff --git a/ext/HiddenMarkovModelsSparseArraysExt.jl b/ext/HiddenMarkovModelsSparseArraysExt.jl deleted file mode 100644 index 0c9890d3..00000000 --- a/ext/HiddenMarkovModelsSparseArraysExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module HiddenMarkovModelsSparseArraysExt - -using HiddenMarkovModels: HiddenMarkovModels -using SparseArrays - -HiddenMarkovModels.mynonzeros(x::AbstractSparseArray) = nonzeros(x) - -function HiddenMarkovModels.mul_rows_cols!( - B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector -) - @assert size(B) == size(A) == (length(l), length(r)) - @assert nnz(B) == nnz(A) - for j in axes(B, 2) - for k in nzrange(B, j) - i = B.rowval[k] - B.nzval[k] = l[i] * A.nzval[k] * r[j] - end - end - return nothing -end - -end diff --git a/libs/HMMBenchmark/src/algos.jl b/libs/HMMBenchmark/src/algos.jl index e155ceed..0cf85fa4 100644 --- a/libs/HMMBenchmark/src/algos.jl +++ b/libs/HMMBenchmark/src/algos.jl @@ -44,51 +44,41 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo if "logdensity" in algos benchs["logdensity"] = @benchmarkable begin - logdensityof($hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends) + logdensityof($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 end if "forward" in algos benchs["forward"] = @benchmarkable begin - forward($hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends) + forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 benchs["forward!"] = @benchmarkable begin - forward!( - f_storage, $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends - ) + forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 setup = ( - f_storage = initialize_forward( - $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends - ) + f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin - viterbi($hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends) + viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 benchs["viterbi!"] = @benchmarkable begin - viterbi!( - v_storage, $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends - ) + viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 setup = ( - v_storage = initialize_viterbi( - $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends - ) + v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable 
begin - forward_backward($hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends) + forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 benchs["forward_backward!"] = @benchmarkable begin - forward_backward!( - fb_storage, $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends - ) + forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) end evals = 1 samples = 100 setup = ( fb_storage = initialize_forward_backward( - $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends + $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) ) end @@ -97,8 +87,8 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo benchs["baum_welch"] = @benchmarkable begin baum_welch( $hmm, - $obs_seq; - control_seq=$control_seq, + $obs_seq, + $control_seq; seq_ends=$seq_ends, max_iterations=$bw_iter, atol=-Inf, @@ -110,8 +100,8 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo fb_storage, logL_evolution, $hmm, - $obs_seq; - control_seq=$control_seq, + $obs_seq, + $control_seq; seq_ends=$seq_ends, max_iterations=$bw_iter, atol=-Inf, @@ -119,7 +109,7 @@ function benchmarkables_hiddenmarkovmodels(rng::AbstractRNG; configuration, algo ) end evals = 1 samples = 100 setup = ( fb_storage = initialize_forward_backward( - $hmm, $obs_seq; control_seq=$control_seq, seq_ends=$seq_ends + $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ); logL_evolution = Float64[]; sizehint!(logL_evolution, $bw_iter) diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index 2d0ff79c..15f53c17 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -13,30 +13,30 @@ function test_allocations( end ## Forward - forward(hmm, obs_seq; control_seq, seq_ends) # compile - f_storage = HMMs.initialize_forward(hmm, obs_seq; control_seq, seq_ends) - allocs = @allocated HMMs.forward!(f_storage, hmm, obs_seq; control_seq, seq_ends) + forward(hmm, obs_seq, control_seq; seq_ends) # compile + f_storage = HMMs.initialize_forward(hmm, obs_seq, control_seq; seq_ends) + allocs = @allocated HMMs.forward!(f_storage, hmm, obs_seq, control_seq; seq_ends) @test allocs == 0 ## Viterbi - viterbi(hmm, obs_seq; control_seq, seq_ends) # compile - v_storage = HMMs.initialize_viterbi(hmm, obs_seq; control_seq, seq_ends) - allocs = @allocated HMMs.viterbi!(v_storage, hmm, obs_seq; control_seq, seq_ends) + viterbi(hmm, obs_seq, control_seq; seq_ends) # compile + v_storage = HMMs.initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) + allocs = @allocated HMMs.viterbi!(v_storage, hmm, obs_seq, control_seq; seq_ends) @test allocs == 0 ## Forward-backward - forward_backward(hmm, obs_seq; control_seq, seq_ends) # compile - fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq; control_seq, seq_ends) + forward_backward(hmm, obs_seq, control_seq; seq_ends) # compile + fb_storage = HMMs.initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) allocs = @allocated HMMs.forward_backward!( - fb_storage, hmm, obs_seq; control_seq, seq_ends + fb_storage, hmm, obs_seq, control_seq; seq_ends ) @test allocs == 0 if !isnothing(hmm_guess) ## Baum-Welch - baum_welch(hmm_guess, obs_seq; control_seq, seq_ends, max_iterations=1) # compile + baum_welch(hmm_guess, obs_seq, control_seq; seq_ends, max_iterations=1) # compile fb_storage = HMMs.initialize_forward_backward( - hmm_guess, obs_seq; control_seq, seq_ends + hmm_guess, obs_seq, control_seq; seq_ends ) logL_evolution = Float64[] 
sizehint!(logL_evolution, 1) @@ -45,8 +45,8 @@ function test_allocations( fb_storage, logL_evolution, hmm_guess, - obs_seq; - control_seq, + obs_seq, + control_seq; seq_ends, atol=-Inf, max_iterations=1, diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 85fe4688..f0646e24 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -75,22 +75,22 @@ function test_coherent_algorithms( state_seq = reduce(vcat, state_seqs) obs_seq = reduce(vcat, obs_seqs) - logL = logdensityof(hmm, obs_seq; control_seq, seq_ends) - logL_joint = logdensityof(hmm, obs_seq, state_seq; control_seq, seq_ends) + logL = logdensityof(hmm, obs_seq, control_seq; seq_ends) + logL_joint = joint_logdensityof(hmm, obs_seq, state_seq, control_seq; seq_ends) - q, logL_viterbi = viterbi(hmm, obs_seq; control_seq, seq_ends) + q, logL_viterbi = viterbi(hmm, obs_seq, control_seq; seq_ends) @test logL_viterbi > logL_joint - @test logL_viterbi ≈ logdensityof(hmm, obs_seq, q; control_seq, seq_ends) + @test logL_viterbi ≈ joint_logdensityof(hmm, obs_seq, q, control_seq; seq_ends) - α, logL_forward = forward(hmm, obs_seq; control_seq, seq_ends) + α, logL_forward = forward(hmm, obs_seq, control_seq; seq_ends) @test logL_forward ≈ logL - γ, logL_forward_backward = forward_backward(hmm, obs_seq; control_seq, seq_ends) + γ, logL_forward_backward = forward_backward(hmm, obs_seq, control_seq; seq_ends) @test logL_forward_backward ≈ logL @test all(α[:, seq_ends[k]] ≈ γ[:, seq_ends[k]] for k in eachindex(seq_ends)) if !isnothing(hmm_guess) - hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq; control_seq, seq_ends) + hmm_est, logL_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends) @test all(>=(0), diff(logL_evolution)) @test !check_equal_hmms( hmm, hmm_guess; control_seq=control_seq[1:2], atol, test=false diff --git a/libs/HMMTest/src/jet.jl b/libs/HMMTest/src/jet.jl index a92c5fcc..95d82789 100644 --- a/libs/HMMTest/src/jet.jl +++ b/libs/HMMTest/src/jet.jl @@ -12,9 +12,9 @@ function test_type_stability( @test_opt target_modules = (HMMs,) rand(hmm, control_seq) @test_call target_modules = (HMMs,) rand(hmm, control_seq) - @test_opt target_modules = (HMMs,) logdensityof(hmm, obs_seq; control_seq, seq_ends) + @test_opt target_modules = (HMMs,) logdensityof(hmm, obs_seq, control_seq; seq_ends) @test_call target_modules = (HMMs,) logdensityof( - hmm, obs_seq; control_seq, seq_ends + hmm, obs_seq, control_seq; seq_ends ) @test_opt target_modules = (HMMs,) logdensityof( hmm, obs_seq, state_seq; control_seq, seq_ends @@ -23,25 +23,25 @@ function test_type_stability( hmm, obs_seq, state_seq; control_seq, seq_ends ) - @test_opt target_modules = (HMMs,) forward(hmm, obs_seq; control_seq, seq_ends) - @test_call target_modules = (HMMs,) forward(hmm, obs_seq; control_seq, seq_ends) + @test_opt target_modules = (HMMs,) forward(hmm, obs_seq, control_seq; seq_ends) + @test_call target_modules = (HMMs,) forward(hmm, obs_seq, control_seq; seq_ends) - @test_opt target_modules = (HMMs,) viterbi(hmm, obs_seq; control_seq, seq_ends) - @test_call target_modules = (HMMs,) viterbi(hmm, obs_seq; control_seq, seq_ends) + @test_opt target_modules = (HMMs,) viterbi(hmm, obs_seq, control_seq; seq_ends) + @test_call target_modules = (HMMs,) viterbi(hmm, obs_seq, control_seq; seq_ends) @test_opt target_modules = (HMMs,) forward_backward( - hmm, obs_seq; control_seq, seq_ends + hmm, obs_seq, control_seq; seq_ends ) @test_call target_modules = (HMMs,) forward_backward( - hmm, obs_seq; control_seq, seq_ends 
+ hmm, obs_seq, control_seq; seq_ends ) if !isnothing(hmm_guess) @test_opt target_modules = (HMMs,) baum_welch( - hmm, obs_seq; control_seq, seq_ends, max_iterations=1 + hmm, obs_seq, control_seq; seq_ends, max_iterations=1 ) @test_call target_modules = (HMMs,) baum_welch( - hmm, obs_seq; control_seq, seq_ends, max_iterations=1 + hmm, obs_seq, control_seq; seq_ends, max_iterations=1 ) end end diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index b7c061e8..f2062bfb 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -2,6 +2,10 @@ HiddenMarkovModels A Julia package for HMM modeling, simulation, inference and learning. + +# Exports + +$(EXPORTS) """ module HiddenMarkovModels @@ -14,12 +18,12 @@ using FillArrays: Fill using LinearAlgebra: dot, ldiv!, lmul!, mul! using PrecompileTools: @compile_workload using Random: Random, AbstractRNG, default_rng -using SimpleUnPack: @unpack +using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange using StatsAPI: StatsAPI, fit, fit! export AbstractHMM, HMM export initialization, transition_matrix, obs_distributions -export fit!, logdensityof +export fit!, logdensityof, joint_logdensityof export viterbi, forward, forward_backward, baum_welch export seq_limits diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 21c024ea..924e2bf1 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -20,17 +20,17 @@ function baum_welch!( fb_storage::ForwardBackwardStorage, logL_evolution::Vector, hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, atol::Real, max_iterations::Integer, loglikelihood_increasing::Bool, ) for iteration in 1:max_iterations - forward_backward!(fb_storage, hmm, obs_seq; control_seq, seq_ends) + forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) push!(logL_evolution, logdensityof(hmm) + sum(fb_storage.logL)) - fit!(hmm, fb_storage, obs_seq; control_seq, seq_ends) + fit!(hmm, fb_storage, obs_seq, control_seq; seq_ends) if baum_welch_has_converged(logL_evolution; atol, loglikelihood_increasing) break end @@ -53,23 +53,23 @@ Return a tuple `(hmm_est, loglikelihood_evolution)` where `hmm_est` is the estim """ function baum_welch( hmm_guess::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), atol=1e-5, max_iterations=100, loglikelihood_increasing=true, ) hmm = deepcopy(hmm_guess) - fb_storage = initialize_forward_backward(hmm, obs_seq; control_seq, seq_ends) + fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) logL_evolution = eltype(fb_storage)[] sizehint!(logL_evolution, max_iterations) baum_welch!( fb_storage, logL_evolution, hmm, - obs_seq; - control_seq, + obs_seq, + control_seq; seq_ends, atol, max_iterations, diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index db3cc380..424236e1 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -2,8 +2,8 @@ _dcat(M1, M2) = cat(M1, M2; dims=3) function _params_and_loglikelihoods( hmm::AbstractHMM, - obs_seq::Vector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::Vector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 
1), ) init = initialization(hmm) @@ -20,16 +20,16 @@ function ChainRulesCore.rrule( rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) _, pullback = rrule_via_ad( - rc, _params_and_loglikelihoods, hmm, obs_seq; control_seq, seq_ends + rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends ) - fb_storage = initialize_forward_backward(hmm, obs_seq; control_seq, seq_ends) - forward_backward!(fb_storage, hmm, obs_seq; control_seq, seq_ends) - @unpack logL, α, β, γ, Bβ = fb_storage + fb_storage = initialize_forward_backward(hmm, obs_seq, control_seq; seq_ends) + forward_backward!(fb_storage, hmm, obs_seq, control_seq; seq_ends) + (; logL, α, γ, Bβ) = fb_storage N, T = length(hmm), length(obs_seq) R = eltype(α) @@ -45,11 +45,11 @@ function ChainRulesCore.rrule( ΔlogB = γ function logdensityof_hmm_pullback(ΔlogL) - _, Δhmm, Δobs_seq = pullback(( + _, Δhmm, Δobs_seq, Δcontrol_seq = pullback(( ΔlogL .* Δinit, ΔlogL .* Δtrans_by_time, ΔlogL .* ΔlogB )) Δlogdensityof = NoTangent() - return Δlogdensityof, Δhmm, Δobs_seq + return Δlogdensityof, Δhmm, Δobs_seq, Δcontrol_seq end return sum(logL), logdensityof_hmm_pullback diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 850231ef..b18a5301 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -23,8 +23,8 @@ $(SIGNATURES) """ function initialize_forward( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) @@ -43,11 +43,11 @@ function forward!( storage, hmm::AbstractHMM, obs_seq::AbstractVector, + control_seq::AbstractVector, t1::Integer, t2::Integer; - control_seq::AbstractVector, ) - @unpack α, B, c = storage + (; α, B, c) = storage # Initialization Bₜ₁ = view(B, :, t1) @@ -89,14 +89,14 @@ $(SIGNATURES) function forward!( storage, hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) - @unpack α, logL, B, c = storage + (; α, logL) = storage for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, t1, t2; control_seq) + logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) end check_finite(α) return nothing @@ -111,11 +111,11 @@ Return a tuple `(storage.α, sum(storage.logL))` where `storage` is of type [`Fo """ function forward( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) - storage = initialize_forward(hmm, obs_seq; control_seq, seq_ends) - forward!(storage, hmm, obs_seq; control_seq, seq_ends) + storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends) + forward!(storage, hmm, obs_seq, control_seq; seq_ends) return storage.α, sum(storage.logL) end diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 4da63a40..3b70b4ec 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -28,8 +28,8 @@ $(SIGNATURES) """ function initialize_forward_backward( 
hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, transition_marginals=true, ) @@ -61,15 +61,15 @@ function forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, + control_seq::AbstractVector, t1::Integer, t2::Integer; - control_seq::AbstractVector, transition_marginals::Bool=true, ) where {R} - @unpack α, β, c, γ, ξ, B, Bβ = storage + (; α, β, c, γ, ξ, B, Bβ) = storage # Forward (fill B, α, c and logL) - logL = forward!(storage, hmm, obs_seq, t1, t2; control_seq) + logL = forward!(storage, hmm, obs_seq, control_seq, t1, t2) # Backward β[:, t2] .= c[t2] @@ -102,16 +102,16 @@ $(SIGNATURES) function forward_backward!( storage::ForwardBackwardStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, transition_marginals::Bool=true, ) where {R} - @unpack logL, α, β, c, γ, ξ, B, Bβ = storage + (; logL, γ) = storage for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) logL[k] = forward_backward!( - storage, hmm, obs_seq, t1, t2; control_seq, transition_marginals + storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals ) end check_finite(γ) @@ -127,14 +127,14 @@ Return a tuple `(storage.γ, sum(storage.logL))` where `storage` is of type [`Fo """ function forward_backward( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) transition_marginals = false storage = initialize_forward_backward( - hmm, obs_seq; control_seq, seq_ends, transition_marginals + hmm, obs_seq, control_seq; seq_ends, transition_marginals ) - forward_backward!(storage, hmm, obs_seq; control_seq, seq_ends, transition_marginals) + forward_backward!(storage, hmm, obs_seq, control_seq; seq_ends, transition_marginals) return storage.γ, sum(storage.logL) end diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index d8974cd3..4d82e152 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -5,11 +5,11 @@ Run the forward algorithm to compute the loglikelihood of `obs_seq` for `hmm`, i """ function DensityInterface.logdensityof( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) - _, logL = forward(hmm, obs_seq; control_seq, seq_ends) + _, logL = forward(hmm, obs_seq, control_seq; seq_ends) return logL end @@ -18,11 +18,11 @@ $(SIGNATURES) Run the forward algorithm to compute the the joint loglikelihood of `obs_seq` and `state_seq` for `hmm`. 
""" -function DensityInterface.logdensityof( +function joint_logdensityof( hmm::AbstractHMM, obs_seq::AbstractVector, - state_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + state_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) R = eltype(hmm, obs_seq[1], control_seq[1]) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 4a23645c..9e4b8ca0 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -24,8 +24,8 @@ $(SIGNATURES) """ function initialize_viterbi( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) @@ -45,11 +45,11 @@ function viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, obs_seq::AbstractVector, + control_seq::AbstractVector, t1::Integer, t2::Integer; - control_seq::AbstractVector, ) where {R} - @unpack q, logB, ϕ, ψ = storage + (; q, logB, ϕ, ψ) = storage obs_logdensities!(view(logB, :, t1), hmm, obs_seq[t1], control_seq[t1]) init = initialization(hmm) @@ -86,14 +86,14 @@ $(SIGNATURES) function viterbi!( storage::ViterbiStorage{R}, hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) where {R} - @unpack q, logL, logB, ϕ, ψ = storage + (; logL, ϕ) = storage for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, t1, t2; control_seq) + logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) end check_right_finite(ϕ) return nothing @@ -108,11 +108,11 @@ Return a tuple `(storage.q, sum(storage.logL))` where `storage` is of type [`Vit """ function viterbi( hmm::AbstractHMM, - obs_seq::AbstractVector; - control_seq::AbstractVector=Fill(nothing, length(obs_seq)), + obs_seq::AbstractVector, + control_seq::AbstractVector=Fill(nothing, length(obs_seq)); seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), ) - storage = initialize_viterbi(hmm, obs_seq; control_seq, seq_ends) - viterbi!(storage, hmm, obs_seq; control_seq, seq_ends) + storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) + viterbi!(storage, hmm, obs_seq, control_seq; seq_ends) return storage.q, sum(storage.logL) end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 2d269253..d3aa828d 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -35,11 +35,11 @@ obs_distributions(hmm::HMM) = hmm.dists function StatsAPI.fit!( hmm::HMM, fb_storage::ForwardBackwardStorage, - obs_seq::AbstractVector; - control_seq::AbstractVector, + obs_seq::AbstractVector, + control_seq::AbstractVector; seq_ends::AbstractVector{Int}, ) - @unpack γ, ξ = fb_storage + (; γ, ξ) = fb_storage # Fit states hmm.init .= zero(eltype(hmm.init)) hmm.trans .= zero(eltype(hmm.trans)) diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 4e4f1a7f..72a7abeb 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -3,6 +3,8 @@ sum_to_one!(x) = ldiv!(sum(x), x) mysimilar_mutable(x::AbstractArray, ::Type{R}) where {R} = similar(x, R) mynonzeros(x::AbstractArray) = x +mynonzeros(x::AbstractSparseArray) = nonzeros(x) + mynnz(x) = length(mynonzeros(x)) function mul_rows_cols!( @@ -11,3 +13,17 @@ function mul_rows_cols!( B .= l .* A .* r' return nothing end + +function mul_rows_cols!( + B::SparseMatrixCSC, l::AbstractVector, 
A::SparseMatrixCSC, r::AbstractVector +) + @assert size(B) == size(A) == (length(l), length(r)) + @assert nnz(B) == nnz(A) + for j in axes(B, 2) + for k in nzrange(B, j) + i = B.rowval[k] + B.nzval[k] = l[i] * A.nzval[k] * r[j] + end + end + return nothing +end diff --git a/test/Project.toml b/test/Project.toml index 25c7c270..8b0091d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" diff --git a/test/correctness.jl b/test/correctness.jl index 8860d698..c88acfa5 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -5,7 +5,6 @@ using HiddenMarkovModels: LightDiagNormal, LightCategorical using HMMTest using LinearAlgebra using Random: Random, AbstractRNG, default_rng, seed! -using SimpleUnPack using SparseArrays using Test