# Patch summary. This is a unified git diff whose hunks are collapsed onto two
# physical lines; the text below this preamble is left untouched.
#
# src/sparse/interfaces.jl (hunk at line 11)
#   The `mul!` method whose matrix argument is named `adjA`
#   (`Adjoint{<:Any,<:CuSparseMatrix}`) passed `parent(transA)` to `mv!`, but
#   `transA` is not a parameter of that method. The replacement line passes
#   `parent(adjA)`, matching the neighbouring HermOrSym adjoint method.
#
# src/sparse/wrappers.jl (hunk at line 567; enclosing function header not shown)
#   - removes the `cuind = cusparseindex(index)` assignment
#   - moves `cudesc = getDescr(A,index)` below the `chkmvdims` checks
#   - adds spaces after the commas in the two `chkmvdims` calls (formatting only)
#   - introduces `nzVal`: `conj(Mat.nzVal)` when `transa == 'C'` and
#     `$elty <: Complex`, otherwise `Mat.nzVal`; it replaces `Mat.nzVal` in the
#     `$fname` call
#   NOTE(review): the condition tests `transa`, while the operation handed to
#   `$fname` is derived from `ctransa`. Presumably this is the CSC `mv!` method,
#   where the op is remapped before the call - confirm against the full file.
#
# test/sparse.jl (hunk at line 2006)
#   Adds a "mul!" testset. For each of Float32, Float64, ComplexF32 and
#   ComplexF64 it builds `A = sparse(rand(elty,m,m))` and, for both
#   `CuSparseMatrixCSR(A)` and `CuSparseMatrixCSC(A)`, checks with `≈` that
#   `CUSPARSE.mul!` on the matrix, its `transpose` and its `adjoint` matches
#   the host results `A * x`, `transpose(A) * x` and `adjoint(A) * x`.
#   NOTE(review): `m` is not defined inside this hunk; presumably it is set
#   earlier in test/sparse.jl - verify.
diff --git a/src/sparse/interfaces.jl b/src/sparse/interfaces.jl index 5cf3b1b7..352feb30 100644 --- a/src/sparse/interfaces.jl +++ b/src/sparse/interfaces.jl @@ -11,7 +11,7 @@ Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:Blas LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O') LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O') -LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(transA),B,zero(T),C,'O') +LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O') LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O') LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O') LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O') diff --git a/src/sparse/wrappers.jl b/src/sparse/wrappers.jl index 8769dc01..60218c21 100644 --- a/src/sparse/wrappers.jl +++ b/src/sparse/wrappers.jl @@ -567,17 +567,21 @@ for (fname,elty) in ((:cusparseScsrmv, :Float32), ctransa = 'T' end cutransa = cusparseop(ctransa) - cuind = cusparseindex(index) - cudesc = getDescr(A,index) n,m = Mat.dims if ctransa == 'N' - chkmvdims(X,n,Y,m) + chkmvdims(X, n, Y, m) end if ctransa == 'T' || ctransa == 'C' - chkmvdims(X,m,Y,n) + chkmvdims(X, m, Y, n) + end + cudesc = getDescr(A,index) + nzVal = Mat.nzVal + if transa == 'C' && $elty <: Complex + nzVal = conj(Mat.nzVal) end - $fname(handle(), cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc), - Mat.nzVal, Mat.colPtr, Mat.rowVal, X, 
[beta], Y) + $fname(handle(), + cutransa, m, n, Mat.nnz, [alpha], Ref(cudesc), nzVal, + Mat.colPtr, Mat.rowVal, X, [beta], Y) Y end end diff --git a/test/sparse.jl b/test/sparse.jl index d4b946d8..a6128011 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -2006,4 +2006,50 @@ end end end +@testset "mul!" begin + for elty in [Float32,Float64,ComplexF32,ComplexF64] + A = sparse(rand(elty,m,m)) + x = rand(elty,m) + y = rand(elty,m) + @testset "csr -- $elty" begin + d_x = CuArray(x) + d_y = CuArray(y) + d_A = CuSparseMatrixCSR(A) + d_Aᵀ = transpose(d_A) + d_Aᴴ = adjoint(d_A) + CUSPARSE.mul!(d_y, d_A, d_x) + h_y = collect(d_y) + z = A * x + @test z ≈ h_y + CUSPARSE.mul!(d_y, d_Aᵀ, d_x) + h_y = collect(d_y) + z = transpose(A) * x + @test z ≈ h_y + CUSPARSE.mul!(d_y, d_Aᴴ, d_x) + h_y = collect(d_y) + z = adjoint(A) * x + @test z ≈ h_y + end + @testset "csc -- $elty" begin + d_x = CuArray(x) + d_y = CuArray(y) + d_A = CuSparseMatrixCSC(A) + d_Aᵀ = transpose(d_A) + d_Aᴴ = adjoint(d_A) + CUSPARSE.mul!(d_y, d_A, d_x) + h_y = collect(d_y) + z = A * x + @test z ≈ h_y + CUSPARSE.mul!(d_y, d_Aᵀ, d_x) + h_y = collect(d_y) + z = transpose(A) * x + @test z ≈ h_y + CUSPARSE.mul!(d_y, d_Aᴴ, d_x) + h_y = collect(d_y) + z = adjoint(A) * x + @test z ≈ h_y + end + end +end + end