Skip to content

Commit

Permalink
add solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 22, 2023
1 parent 67dcf41 commit 2edbc71
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ include("decomposition.jl")
# matrix copy
include("copy.jl")

# solver
include("solve.jl")

end
105 changes: 105 additions & 0 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,108 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}

return B
end


function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
# TODO
end

function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

X = MtlMatrix{T}(undef, size(B))

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0) # TODO: likely N, M is the correct order
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)

return X
end

function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

Bh = reshape(B, )
X = MtlMatrix{T}(undef, size(B))

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)

return X
end

function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

X = MtlMatrix{T}(undef, size(B))

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)

return X
end

function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

X = MtlMatrix{T}(undef, size(B))

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)

return X
end

# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
# require_one_based_indexing(A, B)
# m, n = size(A)
# if m == n
# if istril(A)
# if istriu(A)
# return Diagonal(A) \ B
# else
# return LowerTriangular(A) \ B
# end
# end
# if istriu(A)
# return UpperTriangular(A) \ B
# end
# return lu(A) \ B
# end
# return qr(A, ColumnNorm()) \ B
# end
72 changes: 72 additions & 0 deletions lib/mps/solve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
export MPSMatrixSolveLU

@objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixBinaryKernel

function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides)
kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU}
obj = MPSMatrixSolveLU(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice}
transpose:transpose::Bool
order:order::NSUInteger
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix)
@objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceMatrix:sourceMatrix::id{MPSMatrix}
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
pivotIndices:pivotIndices::id{MPSMatrix}
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
end


export MPSMatrixSolveTriangular

@objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel

function MPSMatrixSolveTriangular(device, right, upper, transpose, unit, order, numberOfRightHandSides, alpha)
kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular}
obj = MPSMatrixSolveTriangular(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice}
right:right::Bool
upper:upper::Bool
transpose:transpose::Bool
unit:unit::Bool
order:order::NSUInteger
numberOfRightHandSides:numberOfRightHandSides::NSUInteger
alpha:alpha::Cdouble]::id{MPSMatrixSolveTriangular}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, rightHandSideMatrix, solutionMatrix)
@objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceMatrix:sourceMatrix::id{MPSMatrix}
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
end


export MPSMatrixSolveCholesky

@objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel

function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides)
kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky}
obj = MPSMatrixSolveCholesky(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice}
upper:upper::Bool
order:order::NSUInteger
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky}
return obj
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix)
@objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
sourceMatrix:sourceMatrix::id{MPSMatrix}
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
end

0 comments on commit 2edbc71

Please sign in to comment.