Skip to content

Commit

Permalink
Remove deps and fix Enzyme (gdalle#73)
Browse files Browse the repository at this point in the history
* Remove deps and fix Enzyme

* Skip allocation test on sparse matrices

* Move control_seq to positional argument
  • Loading branch information
gdalle authored Jan 12, 2024
1 parent ee33d20 commit 36b3313
Show file tree
Hide file tree
Showing 27 changed files with 190 additions and 209 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ scratchpad.jl
/docs/src/index.md
/docs/src/examples/*.md

.vscode/
.vscode/
*.ipynb
5 changes: 0 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@ version = "0.4.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
HiddenMarkovModelsDistributionsExt = "Distributions"
HiddenMarkovModelsSparseArraysExt = "SparseArrays"

[compat]
ChainRulesCore = "1.16"
Expand All @@ -33,7 +29,6 @@ FillArrays = "1"
LinearAlgebra = "1"
PrecompileTools = "1.1"
Random = "1"
SimpleUnPack = "1.1"
SparseArrays = "1"
StatsAPI = "1.6"
julia = "1.9"
46 changes: 23 additions & 23 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.4"
julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "a1b4318401476bf26277ef3565ca4043fa58d314"

Expand Down Expand Up @@ -61,7 +61,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+0"
version = "1.0.5+1"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand Down Expand Up @@ -161,15 +161,14 @@ uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
version = "0.1.0"

[[deps.HiddenMarkovModels]]
deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SimpleUnPack", "StatsAPI"]
deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI"]
path = ".."
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
version = "0.4.0"
weakdeps = ["Distributions", "SparseArrays"]
weakdeps = ["Distributions"]

[deps.HiddenMarkovModels.extensions]
HiddenMarkovModelsDistributionsExt = "Distributions"
HiddenMarkovModelsSparseArraysExt = "SparseArrays"

[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
Expand Down Expand Up @@ -236,9 +235,14 @@ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
Expand Down Expand Up @@ -277,7 +281,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+0"
version = "2.28.2+1"

[[deps.Missings]]
deps = ["DataAPI"]
Expand All @@ -290,7 +294,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.10.11"
version = "2023.1.10"

[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
Expand All @@ -305,12 +309,12 @@ version = "1.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.21+4"
version = "0.3.23+2"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+0"
version = "0.8.1+2"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -338,7 +342,7 @@ version = "2.8.1"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.2"
version = "1.10.0"

[[deps.PooledArrays]]
deps = ["DataAPI", "Future"]
Expand Down Expand Up @@ -383,7 +387,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA", "Serialization"]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.Reexport]]
Expand Down Expand Up @@ -416,11 +420,6 @@ version = "1.4.1"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.SimpleUnPack]]
git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437"
uuid = "ce78b400-467f-4804-87d8-8f486da07d0a"
version = "1.1.0"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

Expand All @@ -433,6 +432,7 @@ version = "1.2.0"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
Expand All @@ -447,7 +447,7 @@ weakdeps = ["ChainRulesCore"]
[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.9.0"
version = "1.10.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -483,9 +483,9 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "5.10.1+6"
version = "7.2.1+1"

[[deps.TOML]]
deps = ["Dates"]
Expand Down Expand Up @@ -543,12 +543,12 @@ version = "1.6.1"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+0"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+0"
version = "5.8.0+1"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -558,4 +558,4 @@ version = "1.52.0+1"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
version = "17.4.0+2"
1 change: 0 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ seq_limits

```@docs
logdensityof
joint_logdensityof
forward
viterbi
forward_backward
Expand Down
30 changes: 15 additions & 15 deletions examples/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,36 @@ grad_z = Zygote.gradient(f, params)[1]
grad_f grad_z

#=
For increased efficiency, one can use Enzyme.jl and provide temporary storage.
This requires going one level deeper, to the mutating [`HiddenMarkovModels.forward!`](@ref) function.
Enzyme.jl also works natively but we have to avoid the type instability of global variables by providing more information.
=#

control_seq = fill(nothing, length(obs_seq))

function f!(storage::HMMs.ForwardStorage, params::ComponentVector)
function f_extended(params::ComponentVector, obs_seq, seq_ends)
new_hmm = HMM(params.init, params.trans, Normal.(params.means))
HMMs.forward!(storage, new_hmm, obs_seq; control_seq, seq_ends)
return sum(storage.logL)
end
return logdensityof(new_hmm, obs_seq; seq_ends)
end;

storage = HMMs.initialize_forward(hmm, obs_seq; control_seq, seq_ends)
storage_shadow = HMMs.initialize_forward(hmm, obs_seq; control_seq, seq_ends)
params_shadow = zero(params)
shadow_params = Enzyme.make_zero(params)

Enzyme.autodiff(
Enzyme.Reverse,
f!,
f_extended,
Enzyme.Active,
Enzyme.Duplicated(storage, storage_shadow),
Enzyme.Duplicated(params, params_shadow),
Enzyme.Duplicated(params, shadow_params),
Enzyme.Const(obs_seq),
Enzyme.Const(seq_ends),
)

grad_e = params_shadow
grad_e = shadow_params

#-

grad_e grad_f

#=
For increased efficiency, one can provide temporary storage to Enzyme.jl in order to avoid allocations.
This requires going one level deeper, by leveraging the in-place [`HiddenMarkovModels.forward!`](@ref) function.
=#

# ## Gradient methods

#=
Expand Down
6 changes: 3 additions & 3 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ Finally, we provide a thin wrapper ([`logdensityof`](@ref)) around the forward a
logdensityof(hmm, obs_seq)

#=
The same function can also compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ that take the states into account.
Another function can compute joint loglikelihoods $\mathbb{P}(X_{1:T}, Y_{1:T})$ which take the states into account.
=#

logdensityof(hmm, obs_seq, state_seq)
joint_logdensityof(hmm, obs_seq, state_seq)

#=
For instance, we can check that the output of Viterbi is at least as likely as the true state sequence.
=#

logdensityof(hmm, obs_seq, best_state_seq)
joint_logdensityof(hmm, obs_seq, best_state_seq)

# ## Learning

Expand Down
11 changes: 5 additions & 6 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import HiddenMarkovModels as HMMs
using HMMTest #src
using LinearAlgebra
using Random
using SimpleUnPack
using StatsAPI
using Test #src

Expand Down Expand Up @@ -79,7 +78,7 @@ seq_ends = cumsum(length.(obs_seqs));
Not much changes from the case with simple time dependency.
=#

best_state_seq, _ = viterbi(hmm, obs_seq; control_seq, seq_ends)
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)

# ## Learning

Expand All @@ -92,11 +91,11 @@ Meanwhile, the observation coefficients are given by the formula for [weighted l
function StatsAPI.fit!(
hmm::ControlledGaussianHMM{T},
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector;
control_seq::AbstractVector,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
) where {T}
@unpack γ, ξ = fb_storage
(; γ, ξ) = fb_storage
N = length(hmm)

hmm.init .= 0
Expand Down Expand Up @@ -130,7 +129,7 @@ hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);

#-

hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq; control_seq, seq_ends)
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
first(loglikelihood_evolution), last(loglikelihood_evolution)

#=
Expand Down
4 changes: 2 additions & 2 deletions examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ We will make use of the fields `fb_storage.γ` and `fb_storage.ξ`, which contai
function StatsAPI.fit!(
hmm::PriorHMM,
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
obs_seq::AbstractVector;
control_seq::AbstractVector,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
)
## initialize to defaults without observations
Expand Down
11 changes: 5 additions & 6 deletions examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using HMMTest #src
using Random
using SimpleUnPack
using StatsAPI
using Test #src

Expand Down Expand Up @@ -88,7 +87,7 @@ seq_ends = cumsum(length.(obs_seqs));
All three inference algorithms work in the same way, except that we need to provide the control sequence as a keyword argument.
=#

best_state_seq, _ = viterbi(hmm, obs_seq; control_seq, seq_ends)
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)

#=
For Viterbi, unsurprisingly, the most likely state sequence aligns with the sign of the observations.
Expand All @@ -106,11 +105,11 @@ The key is to split the observations according to which periodic parameter they
function StatsAPI.fit!(
hmm::PeriodicHMM{T},
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector;
control_seq::AbstractVector,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
) where {T}
@unpack γ, ξ = fb_storage
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)

hmm.init .= zero(T)
Expand Down Expand Up @@ -159,7 +158,7 @@ hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess);
Naturally, Baum-Welch also requires knowing `control_seq`.
=#

hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq; control_seq, seq_ends);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends);
first(loglikelihood_evolution), last(loglikelihood_evolution)

#=
Expand Down
3 changes: 2 additions & 1 deletion examples/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ seq_ends = cumsum(length.(control_seqs)); #src
test_identical_hmmbase(rng, hmm, hmm_guess; T=100) #src
test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=0.05, init=false) #src
test_type_stability(rng, hmm, hmm_guess; control_seq, seq_ends) #src
test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src
# https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
@test_skip test_allocations(rng, hmm, hmm_guess; control_seq, seq_ends) #src
Loading

0 comments on commit 36b3313

Please sign in to comment.