Skip to content

Commit

Permalink
fix other solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 31, 2023
1 parent ef68e4d commit 80af119
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 61 deletions.
159 changes: 98 additions & 61 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,110 +192,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}

commit!(cmdbuf)

wait_completed(cmdbuf)

return B
end


"""
    A \\ B

Solve with the LU factorization `A` without modifying `B`.

The right-hand side is copied, the copy is passed to the in-place
`LinearAlgebra.ldiv!(A, C)`, and that copy is returned.
"""
function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}},
                            B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `copy` is sufficient here: the elements are `T<:MtlFloat` scalars
    # (presumably plain floating-point values), so there is no nested mutable
    # state that would require a recursive `deepcopy`.
    C = copy(B)
    LinearAlgebra.ldiv!(A, C)
    return C
end


"""
    ldiv!(A::LU, B::MtlVecOrMat)

Solve in place with the LU factorization `A` using `MPSMatrixSolveLU`.
`B` is overwritten with the result and returned.

The factors and the right-hand side are transposed into scratch arrays before
being wrapped as `MPSMatrix` objects, and the kernel output is transposed back
into `B`.
"""
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `size(B, 2)` is 1 for a vector, so vectors and matrices share one path.
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Scratch buffers holding the transposed operands and the kernel output.
    # NOTE(review): the transposes presumably bridge Julia's column-major
    # layout and the row-major layout MPS expects -- confirm.
    At = similar(A.factors)
    Bt = similar(B, (N, M))
    # Shift the pivot indices down by one (presumably 1-based -> 0-based for
    # MPS) and present them as a 1 x M matrix.
    P = reshape((A.ipiv .- UInt32(1)), (1, M))
    X = similar(B, (N, M))

    transpose!(At, A.factors)
    transpose!(Bt, B)

    mps_a = MPSMatrix(At)
    mps_b = MPSMatrix(Bt)
    mps_p = MPSMatrix(P)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        kernel = MPSMatrixSolveLU(dev, false, M, N)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
    end

    # Write the solution back into the caller's array.
    transpose!(B, X)
    return B
end

"""
    ldiv!(A::UpperTriangular, B::MtlVecOrMat)

Solve in place with the upper-triangular matrix `A` using
`MPSMatrixSolveTriangular`. `B` is overwritten with the result and returned.

The adjoint of `A` is materialized as a dense `MtlMatrix`, `B` is transposed
into a scratch buffer, and the kernel output is copied back into `B` once the
command buffer has completed.
"""
function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    Ad = MtlMatrix(A')
    # NOTE(review): the scratch buffer is M x M rather than N x M, so this
    # presumably only lines up when `B` is square (N == M) -- confirm the
    # vector / non-square right-hand-side case.
    Br = similar(B, (M, M))
    X = similar(Br)

    transpose!(Br, B)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # The positional flags follow the MPSMatrixSolveTriangular initializer
        # (presumably right / upper / transpose / unit, then order, number of
        # right-hand sides and alpha) -- confirm against the wrapper.
        kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    # Block until the solve has finished before reading `X`.
    wait_completed(buf)

    copy!(B, X)
    return B
end

"""
    ldiv!(A::UnitUpperTriangular, B::MtlVecOrMat)

Solve in place with the unit upper-triangular matrix `A` using
`MPSMatrixSolveTriangular`. `B` is overwritten with the result and returned.
"""
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `size(B, 2)` is 1 for a vector, so `Br` below is always two-dimensional.
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a dense matrix for MPS.
    Ad = MtlMatrix(A)
    # `reshape` shares storage with `B`, so writing into `Br` updates `B`.
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # The positional flags follow the MPSMatrixSolveTriangular initializer
        # (presumably right / upper / transpose / unit, then order, number of
        # right-hand sides and alpha) -- confirm against the wrapper.
        kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    # Block until the solve has finished before reading `X`.
    wait_completed(buf)

    copy!(Br, X)
    return B
end

"""
    ldiv!(A::LowerTriangular, B::MtlVecOrMat)

Solve in place with the lower-triangular matrix `A` using
`MPSMatrixSolveTriangular`. `B` is overwritten with the result and returned.
"""
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `size(B, 2)` is 1 for a vector, so `Br` below is always two-dimensional.
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a dense matrix for MPS.
    Ad = MtlMatrix(A)
    # `reshape` shares storage with `B`, so writing into `Br` updates `B`.
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # The positional flags follow the MPSMatrixSolveTriangular initializer
        # (presumably right / upper / transpose / unit, then order, number of
        # right-hand sides and alpha) -- confirm against the wrapper.
        kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    # Block until the solve has finished before reading `X`.
    wait_completed(buf)

    copy!(Br, X)
    return B
end

# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
# require_one_based_indexing(A, B)
# m, n = size(A)
# if m == n
# if istril(A)
# if istriu(A)
# return Diagonal(A) \ B
# else
# return LowerTriangular(A) \ B
# end
# end
# if istriu(A)
# return UpperTriangular(A) \ B
# end
# return lu(A) \ B
# end
# return qr(A, ColumnNorm()) \ B
# end

"""
    ldiv!(A::UnitLowerTriangular, B::MtlVecOrMat)

Solve in place with the unit lower-triangular matrix `A` using
`MPSMatrixSolveTriangular`. `B` is overwritten with the result and returned.
"""
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # A vector reports 1 along its second dimension, so `rhs` is always 2-D.
    nrows = size(B, 1)
    ncols = size(B, 2)

    device = current_device()
    cmdqueue = global_queue(device)

    # Dense copy of the triangular wrapper, a 2-D alias of `B`, and an output
    # buffer of matching shape.
    dense = MtlMatrix(A)
    rhs = reshape(B, (nrows, ncols))
    sol = similar(rhs)

    mps_dense = MPSMatrix(dense)
    mps_rhs = MPSMatrix(rhs)
    mps_sol = MPSMatrix(sol)

    cmd = MTLCommandBuffer(cmdqueue) do cb
        # Flag order follows the MPSMatrixSolveTriangular initializer
        # (presumably right / upper / transpose / unit) -- confirm.
        solver = MPSMatrixSolveTriangular(device, true, true, false, true,
                                          nrows, ncols, 1.0)
        encode!(cb, solver, mps_dense, mps_rhs, mps_sol)
    end

    # The result is only read after the command buffer has completed.
    wait_completed(cmd)

    # `rhs` aliases `B`, so this writes the solution into the caller's array.
    copy!(rhs, sol)
    return B
end
35 changes: 35 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,39 @@ end
@test_throws SingularException lu(A)
end

# Round-trip checks for the MPS-backed solvers: solve `A \ rhs` and verify that
# multiplying back by `A` reproduces the right-hand side (approximately, since
# everything is Float32).
# NOTE(review): dense random triangular systems of this size may be badly
# conditioned in Float32 -- confirm the default `≈` tolerance is adequate.
@testset "solves" begin
    b = MtlVector(rand(Float32, 1024))
    B = MtlMatrix(rand(Float32, 1024, 1024))

    # The system must be square: `x` has the same length as `b` (1024), so
    # `A * x` is only defined when `A` has 1024 columns.
    A = MtlMatrix(rand(Float32, 1024, 1024))
    x = lu(A) \ b
    @test A * x ≈ b
    X = lu(A) \ B
    @test A * X ≈ B

    A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
    x = A \ b
    @test A * x ≈ b
    X = A \ B
    @test A * X ≈ B

    A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
    x = A \ b
    @test A * x ≈ b
    X = A \ B
    @test A * X ≈ B

    A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
    x = A \ b
    @test A * x ≈ b
    X = A \ B
    @test A * X ≈ B

    A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
    x = A \ b
    @test A * x ≈ b
    X = A \ B
    @test A * X ≈ B
end

end

0 comments on commit 80af119

Please sign in to comment.