From d9e4a5d5cd776724ee1c40c825f88d3e555adfca Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 29 Sep 2024 10:47:50 +0200 Subject: [PATCH] Minor optimizations and simd (#108) * Minor optimizations and simd * Fix JET --- src/types/hmm.jl | 6 +++--- src/utils/lightcategorical.jl | 2 +- src/utils/lightdiagnormal.jl | 4 ++-- src/utils/linalg.jl | 20 ++++++++++---------- test/runtests.jl | 4 +++- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 9de3cf2e..ca6d33c3 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -69,13 +69,13 @@ function StatsAPI.fit!( t1, t2 = seq_limits(seq_ends, k) # use ξ[t2] as scratch space since it is zero anyway scratch = ξ[t2] - scratch .= zero(eltype(scratch)) + fill!(scratch, zero(eltype(scratch))) for t in t1:(t2 - 1) scratch .+= ξ[t] end end - hmm.init .= zero(eltype(hmm.init)) - hmm.trans .= zero(eltype(hmm.trans)) + fill!(hmm.init, zero(eltype(hmm.init))) + fill!(hmm.trans, zero(eltype(hmm.trans))) for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) hmm.init .+= view(γ, :, t1) diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index f16f46dc..fd96dd29 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -52,7 +52,7 @@ function StatsAPI.fit!( ) where {T1} @argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p) w_tot = sum(w) - dist.p .= zero(T1) + fill!(dist.p, zero(T1)) @inbounds @simd for i in eachindex(x, w) dist.p[x[i]] += w[i] end diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 8b0748d6..6c84ac44 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -56,8 +56,8 @@ function StatsAPI.fit!( dist::LightDiagNormal{T1,T2}, x::AbstractVector{<:AbstractVector}, w::AbstractVector ) where {T1,T2} w_tot = sum(w) - dist.μ .= zero(T1) - dist.σ .= zero(T2) + fill!(dist.μ, zero(T1)) + fill!(dist.σ, zero(T2)) @inbounds @simd for i in eachindex(x, w) dist.μ .+= x[i] .* w[i] end diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 506b29e4..0c18596d 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -33,9 +33,9 @@ function mul_rows_cols!( Brv = rowvals(B) Bnz = nonzeros(B) Anz = nonzeros(A) - for j in axes(B, 2) + @simd for j in axes(B, 2) @argcheck nzrange(B, j) == nzrange(A, j) - for k in nzrange(B, j) + @simd for k in nzrange(B, j) i = Brv[k] Bnz[k] = l[i] * Anz[k] * r[j] end @@ -56,10 +56,10 @@ function argmaxplus_transmul!( ) where {R} @argcheck axes(A, 1) == eachindex(x) @argcheck axes(A, 2) == eachindex(y) - y .= typemin(R) - ind .= 0 - for j in axes(A, 2) - for i in axes(A, 1) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for i in axes(A, 1) z = A[i, j] + x[i] if z > y[j] y[j] = z @@ -80,10 +80,10 @@ function argmaxplus_transmul!( @argcheck axes(A, 2) == eachindex(y) Anz = nonzeros(A) Arv = rowvals(A) - y .= typemin(R) - ind .= 0 - for j in axes(A, 2) - for k in nzrange(A, j) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for k in nzrange(A, j) i = Arv[k] z = Anz[k] + x[i] if z > y[j] diff --git a/test/runtests.jl b/test/runtests.jl index 96b2f0fa..0d2bac03 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,9 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) @testset "Code linting" begin using Distributions using Zygote - JET.test_package(HiddenMarkovModels; target_defined_modules=true) + if VERSION >= v"1.10" + JET.test_package(HiddenMarkovModels; target_defined_modules=true) + end end @testset "Distributions" begin