From 083bbf98ed9df6af759f3031b54f0d0a9de2f72c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:45:33 +0100 Subject: [PATCH] Special case for sparse matrices --- examples/types.jl | 2 +- src/inference/viterbi.jl | 2 +- src/types/hmm.jl | 2 +- src/utils/linalg.jl | 47 ++++++++++++++++++++++++++++++++-------- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/examples/types.jl b/examples/types.jl index d4966a5f..0945be63 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -154,7 +154,7 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St # ## Tests #src -@test_broken nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src +@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src seq_ends = cumsum(rand(rng, 100:200, 100)); #src control_seq = fill(nothing, last(seq_ends)); #src diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index abf254e8..09e18a26 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -62,7 +62,7 @@ function viterbi!( logtrans = log_transition_matrix(hmm, control_seq[t - 1]) ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) - argmaxplus_mul!(ϕₜ, ψₜ, transpose(logtrans), ϕₜ₋₁) + argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) ϕₜ .+= logBₜ end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 476eaeae..a9310798 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -79,7 +79,7 @@ function StatsAPI.fit!( end # Update logs hmm.loginit .= log.(hmm.init) - hmm.logtrans .= log.(hmm.trans) + mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans)) # Safety check @argcheck valid_hmm(hmm) return nothing diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 07f8f31a..506b29e4 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -7,6 +7,10 @@ mynnz(x::AbstractArray) = length(mynonzeros(x)) elementwise_log(x::AbstractArray) = log.(x) +function elementwise_log(A::SparseMatrixCSC) + return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval)) +end + """ mul_rows_cols!(B, l, A, r) @@ -40,26 +44,51 @@ function mul_rows_cols!( end """ - argmaxplus_mul!(y, ind, A, x) + argmaxplus_transmul!(y, ind, A, x) -Perform the in-place multiplication `A * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each row in `ind`. +Perform the in-place multiplication `transpose(A) * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each component of `y` in `ind`. """ -function argmaxplus_mul!( +function argmaxplus_transmul!( y::AbstractVector{R}, ind::AbstractVector{<:Integer}, A::AbstractMatrix, x::AbstractVector, ) where {R} - @argcheck axes(A, 1) == eachindex(y) - @argcheck axes(A, 2) == eachindex(x) + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) y .= typemin(R) ind .= 0 for j in axes(A, 2) for i in axes(A, 1) - z = A[i, j] + x[j] - if z > y[i] - y[i] = z - ind[i] = j + z = A[i, j] + x[i] + if z > y[j] + y[j] = z + ind[j] = i + end + end + end + return y +end + +function argmaxplus_transmul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::SparseMatrixCSC, + x::AbstractVector, +) where {R} + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) + Anz = nonzeros(A) + Arv = rowvals(A) + y .= typemin(R) + ind .= 0 + for j in axes(A, 2) + for k in nzrange(A, j) + i = Arv[k] + z = Anz[k] + x[i] + if z > y[j] + y[j] = z + ind[j] = i end end end