Skip to content

Commit

Permalink
Check dimensions in sparse gemm and gemv (#83)
Browse files Browse the repository at this point in the history
* Check dimensions.

* Switch to DimensionMismatch

* Add some tests.

---------

Co-authored-by: Ashley Milsted <[email protected]>
  • Loading branch information
amilsted and Ashley Milsted authored Mar 21, 2023
1 parent f6b3987 commit 3ed0e84
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ end


function gemm!(alpha, M::SparseMatrixCSC, B::AbstractMatrix, beta, result::AbstractMatrix)
size(M, 2) == size(B, 1) || throw(DimensionMismatch())
size(M, 1) == size(result, 1) || throw(DimensionMismatch())
size(B, 2) == size(result, 2) || throw(DimensionMismatch())
if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand All @@ -112,6 +115,9 @@ function gemm!(alpha, M::SparseMatrixCSC, B::AbstractMatrix, beta, result::Abstr
end

function gemm!(alpha, B::AbstractMatrix, M::SparseMatrixCSC, beta, result::AbstractMatrix)
size(M, 1) == size(B, 2) || throw(DimensionMismatch())
size(M, 2) == size(result,2) || throw(DimensionMismatch())
size(B, 1) == size(result,1) || throw(DimensionMismatch())
if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand Down Expand Up @@ -146,6 +152,9 @@ function gemm!(alpha, M_::Adjoint{T,<:SparseMatrixCSC{T}}, B::AbstractMatrix, be
if nnz(M) > 550
LinearAlgebra.mul!(result, M_, B, alpha, beta)
else
size(M_, 2) == size(B, 1) || throw(DimensionMismatch())
size(M_, 1) == size(result, 1) || throw(DimensionMismatch())
size(B, 2) == size(result, 2) || throw(DimensionMismatch())
if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand All @@ -156,6 +165,9 @@ function gemm!(alpha, M_::Adjoint{T,<:SparseMatrixCSC{T}}, B::AbstractMatrix, be
end

function gemm!(alpha, B::AbstractMatrix, M::Adjoint{T,<:SparseMatrixCSC{T}}, beta, result::AbstractMatrix) where T
size(M, 1) == size(B, 2) || throw(DimensionMismatch())
size(M, 2) == size(result,2) || throw(DimensionMismatch())
size(B, 1) == size(result,1) || throw(DimensionMismatch())
if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand All @@ -178,6 +190,9 @@ function gemm!(alpha, A::SparseMatrixCSC, B::SparseMatrixCSC, beta, result::Abst
end

function gemv!(alpha, M::SparseMatrixCSC, v::AbstractVector, beta, result::AbstractVector)
size(M, 2) == size(v, 1) || throw(DimensionMismatch())
size(M, 1) == size(result, 1) || throw(DimensionMismatch())

if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand All @@ -201,6 +216,9 @@ function gemv!(alpha, M::SparseMatrixCSC, v::AbstractVector, beta, result::Abstr
end

function gemv!(alpha, v::AbstractVector, M::SparseMatrixCSC, beta, result::AbstractVector)
size(M, 1) == size(v, 1) || throw(DimensionMismatch())
size(M, 2) == size(result, 1) || throw(DimensionMismatch())

if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
Expand Down
7 changes: 7 additions & 0 deletions test/test_operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ op3 = randoperator(bf)
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
@test_throws ErrorException cos.(op1)

# Dimension mismatches
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
@test_throws DimensionMismatch mul!(randstate(b1), randoperator(b2), randstate(b3))
@test_throws DimensionMismatch mul!(randstate(b1)', randstate(b3)', randoperator(b2))
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b2), randoperator(b3))
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b3)', randoperator(b2))

end # testset

@testset "State-operator tensor products" begin
Expand Down
7 changes: 7 additions & 0 deletions test/test_operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,13 @@ op_ .+= op1
@test op_ == 2*op1
@test_throws ErrorException cos.(op_)

# Dimension mismatches
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
@test_throws DimensionMismatch mul!(randstate(b1), sparse(randoperator(b2)), randstate(b3))
@test_throws DimensionMismatch mul!(randstate(b1)', randstate(b3)', sparse(randoperator(b2)))
@test_throws DimensionMismatch mul!(randoperator(b1), sparse(randoperator(b2)), randoperator(b3))
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b3)', sparse(randoperator(b2)))

end # testset

@testset "State-operator tensor products, sparse" begin
Expand Down

0 comments on commit 3ed0e84

Please sign in to comment.