Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge #692
Browse files Browse the repository at this point in the history
692: Fix sparse mul! r=amontoison a=amontoison

close #629 
#630 , #637 
@haampie 

Co-authored-by: Alexis Montoison <[email protected]>
  • Loading branch information
bors[bot] and amontoison authored Apr 27, 2020
2 parents ee16e77 + 2ec414f commit 16c0080
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/sparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:Blas

LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
Expand Down
16 changes: 10 additions & 6 deletions src/sparse/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,17 +567,21 @@ for (fname,elty) in ((:cusparseScsrmv, :Float32),
ctransa = 'T'
end
cutransa = cusparseop(ctransa)
cuind = cusparseindex(index)
cudesc = getDescr(A,index)
n,m = Mat.dims
if ctransa == 'N'
chkmvdims(X,n,Y,m)
chkmvdims(X, n, Y, m)
end
if ctransa == 'T' || ctransa == 'C'
chkmvdims(X,m,Y,n)
chkmvdims(X, m, Y, n)
end
cudesc = getDescr(A,index)
nzVal = Mat.nzVal
if transa == 'C' && $elty <: Complex
nzVal = conj(Mat.nzVal)
end
$fname(handle(), cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc),
Mat.nzVal, Mat.colPtr, Mat.rowVal, X, [beta], Y)
$fname(handle(),
cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc), nzVal,
Mat.colPtr, Mat.rowVal, X, [beta], Y)
Y
end
end
Expand Down
46 changes: 46 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2006,4 +2006,50 @@ end
end
end

@testset "mul!" begin
for elty in [Float32,Float64,ComplexF32,ComplexF64]
A = sparse(rand(elty,m,m))
x = rand(elty,m)
y = rand(elty,m)
@testset "csr -- $elty" begin
d_x = CuArray(x)
d_y = CuArray(y)
d_A = CuSparseMatrixCSR(A)
d_Aᵀ = transpose(d_A)
d_Aᴴ = adjoint(d_A)
CUSPARSE.mul!(d_y, d_A, d_x)
h_y = collect(d_y)
z = A * x
@test z h_y
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
h_y = collect(d_y)
z = transpose(A) * x
@test z h_y
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
h_y = collect(d_y)
z = adjoint(A) * x
@test z h_y
end
@testset "csc -- $elty" begin
d_x = CuArray(x)
d_y = CuArray(y)
d_A = CuSparseMatrixCSC(A)
d_Aᵀ = transpose(d_A)
d_Aᴴ = adjoint(d_A)
CUSPARSE.mul!(d_y, d_A, d_x)
h_y = collect(d_y)
z = A * x
@test z h_y
CUSPARSE.mul!(d_y, d_Aᵀ, d_x)
h_y = collect(d_y)
z = transpose(A) * x
@test z h_y
CUSPARSE.mul!(d_y, d_Aᴴ, d_x)
h_y = collect(d_y)
z = adjoint(A) * x
@test z h_y
end
end
end

end

0 comments on commit 16c0080

Please sign in to comment.