From 7924c976a7df0478f5b5c6c7b8be291d522bd718 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 14 Aug 2023 22:25:07 +0200 Subject: [PATCH] Partial 1.10 enablement (#330) Co-authored-by: Tim Besard --- .buildkite/pipeline.yml | 6 +++++- lib/mkl/linalg.jl | 36 ++++++++++++++++++++++++++++-------- lib/mkl/wrappers.jl | 2 +- test/onemkl.jl | 4 ++-- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 666d4d25..ee19692b 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -31,12 +31,16 @@ steps: - "1.6" - "1.7" - "1.8" - - "1.9-nightly" + - "1.9" + - "1.10-nightly" - "nightly" adjustments: - with: julia: "nightly" soft_fail: true + - with: + julia: "1.10-nightly" + soft_fail: true # Special tests - group: ":eyes: Special" diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index f41e34b4..66db0123 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -137,6 +137,14 @@ if VERSION < v"1.10.0-DEV.1365" end # triangular +if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ" +# multiplication +LinearAlgebra.generic_trimatmul!(c::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, b::AbstractVector{T}) where {T<:onemklFloat} = + trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b)) +# division +LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractVector{T}) where {T<:onemklFloat} = + trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) +else ## direct multiplication/division for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'), (:UnitLowerTriangular, 'L', 'U'), @@ -183,6 +191,7 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'), trsv!($uploc, 'C', $isunitc, parent(parent(A)), B) end end +end # VERSION # @@ -254,6 +263,16 @@ end end # VERSION # triangular +if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ" +LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) +LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) +LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) +LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) +else ## direct multiplication/division for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'), (:UnitLowerTriangular, 'L', 'U'), @@ -261,16 +280,17 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'), (:UnitUpperTriangular, 'U', 'U')) @eval begin # Multiplication - LinearAlgebra.lmul!(A::$t{T,<:oneStridedVecOrMat}, - B::oneStridedVecOrMat{T}) where {T<:onemklFloat} = - trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, B) - LinearAlgebra.rmul!(A::oneStridedVecOrMat{T}, - B::$t{T,<:oneStridedVecOrMat}) where {T<:onemklFloat} = - trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, A) + LinearAlgebra.lmul!(A::$t{T,<:oneStridedMatrix}, + B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B) + LinearAlgebra.rmul!(A::oneStridedMatrix{T}, + B::$t{T,<:oneStridedMatrix}) where {T<:onemklFloat} = + trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A) # Left division - LinearAlgebra.ldiv!(A::$t{T,<:oneStridedVecOrMat}, - B::oneStridedVecOrMat{T}) where {T<:onemklFloat} = + LinearAlgebra.ldiv!(A::$t{T,<:oneStridedMatrix}, + B::oneStridedMatrix{T}) where {T<:onemklFloat} = trsm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B) end end +end # VERSION diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 10cabc15..c5e5388e 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -1100,7 +1100,7 @@ function trmm(side::Char, alpha::Number, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where T - trmm!(side, uplo, transa, diag, alpha, A, B) + trmm!(side, uplo, transa, diag, alpha, A, copy(B)) end function trsm(side::Char, uplo::Char, diff --git a/test/onemkl.jl b/test/onemkl.jl index 0f757ae6..cfe46af9 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -655,9 +655,9 @@ end dA = oneArray(A) dB = oneArray(B) C = alpha*A*B - oneMKL.trmm('L','U','N','N',alpha,dA,dB) + dC = oneMKL.trmm('L','U','N','N',alpha,dA,dB) # move to host and compare - h_C = Array(dB) + h_C = Array(dC) @test C ≈ h_C end