Skip to content

Commit

Permalink
Special case for sparse matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Feb 24, 2024
1 parent 2746f09 commit 083bbf9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St

# ## Tests #src

@test_broken nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src
@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src

seq_ends = cumsum(rand(rng, 100:200, 100)); #src
control_seq = fill(nothing, last(seq_ends)); #src
Expand Down
2 changes: 1 addition & 1 deletion src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function viterbi!(
logtrans = log_transition_matrix(hmm, control_seq[t - 1])
ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1)
ψₜ = view(ψ, :, t)
argmaxplus_mul!(ϕₜ, ψₜ, transpose(logtrans), ϕₜ₋₁)
argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁)
ϕₜ .+= logBₜ
end

Expand Down
2 changes: 1 addition & 1 deletion src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function StatsAPI.fit!(
end
# Update logs
hmm.loginit .= log.(hmm.init)
hmm.logtrans .= log.(hmm.trans)
mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans))
# Safety check
@argcheck valid_hmm(hmm)
return nothing
Expand Down
47 changes: 38 additions & 9 deletions src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ mynnz(x::AbstractArray) = length(mynonzeros(x))

elementwise_log(x::AbstractArray) = log.(x)

function elementwise_log(A::SparseMatrixCSC)
return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval))
end

"""
mul_rows_cols!(B, l, A, r)
Expand Down Expand Up @@ -40,26 +44,51 @@ function mul_rows_cols!(
end

"""
argmaxplus_mul!(y, ind, A, x)
argmaxplus_transmul!(y, ind, A, x)
Perform the in-place multiplication `A * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each row in `ind`.
Perform the in-place multiplication `transpose(A) * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each component of `y` in `ind`.
"""
function argmaxplus_mul!(
function argmaxplus_transmul!(
y::AbstractVector{R},
ind::AbstractVector{<:Integer},
A::AbstractMatrix,
x::AbstractVector,
) where {R}
@argcheck axes(A, 1) == eachindex(y)
@argcheck axes(A, 2) == eachindex(x)
@argcheck axes(A, 1) == eachindex(x)
@argcheck axes(A, 2) == eachindex(y)
y .= typemin(R)
ind .= 0
for j in axes(A, 2)
for i in axes(A, 1)
z = A[i, j] + x[j]
if z > y[i]
y[i] = z
ind[i] = j
z = A[i, j] + x[i]
if z > y[j]
y[j] = z
ind[j] = i
end
end
end
return y
end

function argmaxplus_transmul!(
y::AbstractVector{R},
ind::AbstractVector{<:Integer},
A::SparseMatrixCSC,
x::AbstractVector,
) where {R}
@argcheck axes(A, 1) == eachindex(x)
@argcheck axes(A, 2) == eachindex(y)
Anz = nonzeros(A)
Arv = rowvals(A)
y .= typemin(R)
ind .= 0
for j in axes(A, 2)
for k in nzrange(A, j)
i = Arv[k]
z = Anz[k] + x[i]
if z > y[j]
y[j] = z
ind[j] = i
end
end
end
Expand Down

0 comments on commit 083bbf9

Please sign in to comment.