Skip to content

Commit

Permalink
Minor optimizations and simd
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 29, 2024
1 parent a8b048a commit 9567c8c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ function StatsAPI.fit!(
t1, t2 = seq_limits(seq_ends, k)
# use ξ[t2] as scratch space since it is zero anyway
scratch = ξ[t2]
scratch .= zero(eltype(scratch))
fill!(scratch, zero(eltype(scratch)))
for t in t1:(t2 - 1)
scratch .+= ξ[t]
end
end
hmm.init .= zero(eltype(hmm.init))
hmm.trans .= zero(eltype(hmm.trans))
fill!(hmm.init, zero(eltype(hmm.init)))
fill!(hmm.trans, zero(eltype(hmm.trans)))
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
hmm.init .+= view(γ, :, t1)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/lightcategorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function StatsAPI.fit!(
) where {T1}
@argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p)
w_tot = sum(w)
dist.p .= zero(T1)
fill!(dist.p, zero(T1))
@inbounds @simd for i in eachindex(x, w)
dist.p[x[i]] += w[i]
end
Expand Down
4 changes: 2 additions & 2 deletions src/utils/lightdiagnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function StatsAPI.fit!(
dist::LightDiagNormal{T1,T2}, x::AbstractVector{<:AbstractVector}, w::AbstractVector
) where {T1,T2}
w_tot = sum(w)
dist.μ .= zero(T1)
dist.σ .= zero(T2)
fill!(dist.μ, zero(T1))
fill!(dist.σ, zero(T2))
@inbounds @simd for i in eachindex(x, w)
dist.μ .+= x[i] .* w[i]
end
Expand Down
20 changes: 10 additions & 10 deletions src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ function mul_rows_cols!(
Brv = rowvals(B)
Bnz = nonzeros(B)
Anz = nonzeros(A)
for j in axes(B, 2)
@simd for j in axes(B, 2)
@argcheck nzrange(B, j) == nzrange(A, j)
for k in nzrange(B, j)
@simd for k in nzrange(B, j)
i = Brv[k]
Bnz[k] = l[i] * Anz[k] * r[j]
end
Expand All @@ -56,10 +56,10 @@ function argmaxplus_transmul!(
) where {R}
@argcheck axes(A, 1) == eachindex(x)
@argcheck axes(A, 2) == eachindex(y)
y .= typemin(R)
ind .= 0
for j in axes(A, 2)
for i in axes(A, 1)
fill!(y, typemin(R))
fill!(ind, 0)
@simd for j in axes(A, 2)
@simd for i in axes(A, 1)
z = A[i, j] + x[i]
if z > y[j]
y[j] = z
Expand All @@ -80,10 +80,10 @@ function argmaxplus_transmul!(
@argcheck axes(A, 2) == eachindex(y)
Anz = nonzeros(A)
Arv = rowvals(A)
y .= typemin(R)
ind .= 0
for j in axes(A, 2)
for k in nzrange(A, j)
fill!(y, typemin(R))
fill!(ind, 0)
@simd for j in axes(A, 2)
@simd for k in nzrange(A, j)
i = Arv[k]
z = Anz[k] + x[i]
if z > y[j]
Expand Down

0 comments on commit 9567c8c

Please sign in to comment.