Reduce allocations for multiplying LazyTensor of sparse and dense (#80)
Avoid reshape by letting the sparse LazyTensor gemm routines work directly on vectors.

Also check dimensions.

---------

Co-authored-by: Amit Rotem <[email protected]>
AmitRotem and Amit Rotem authored Mar 27, 2023
1 parent 3ed0e84 commit 2506d0d
Showing 5 changed files with 78 additions and 26 deletions.
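For context, a minimal sketch of the call path this commit optimizes (illustrative names; assumes the QuantumOpticsBase public API: NLevelBasis, randoperator, randstate, and the 5-arg mul!):

    using QuantumOpticsBase, SparseArrays, LinearAlgebra

    b = NLevelBasis(2)
    h = LazyTensor(b^2, b^2, 2, sparse(randoperator(b)))  # LazyTensor with a pure-sparse factor
    ψ = randstate(b^2)
    out = copy(ψ)

    # Dispatches to _mul_puresparse!. Before this commit the Ket path reshaped
    # ψ.data into an n×1 Matrix (allocating) before calling the gemm kernel;
    # after it, the kernel works on the vector directly.
    mul!(out, h, ψ, 1.0, 0.0)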
8 changes: 6 additions & 2 deletions src/operators.jl
@@ -392,5 +392,9 @@ multiplicable(a::AbstractOperator, b::Ket) = multiplicable(a.basis_r, b.basis)
 multiplicable(a::Bra, b::AbstractOperator) = multiplicable(a.basis, b.basis_l)
 multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)
 
-Base.size(op::AbstractOperator) = prod(length(op.basis_l),length(op.basis_r))
-Base.size(op::AbstractOperator, i::Int) = (i==1 ? length(op.basis_l) : length(op.basis_r))
+Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
+function Base.size(op::AbstractOperator, i::Int)
+    i < 1 && throw(ErrorException(lazy"dimension out of range, should be strictly positive, got $i"))
+    i > 2 && return 1
+    i==1 ? length(op.basis_l) : length(op.basis_r)
+end
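The new method follows Base's AbstractArray conventions: trailing dimensions report 1 and non-positive dimensions throw. A hedged sketch, assuming op is any AbstractOperator with a 2-dimensional left basis and a 3-dimensional right basis:

    size(op)     # (2, 3)
    size(op, 1)  # 2
    size(op, 2)  # 3
    size(op, 3)  # 1 (trailing dimension, as for AbstractArray)
    size(op, 0)  # throws ErrorException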
4 changes: 4 additions & 0 deletions src/operators_dense.jl
@@ -274,6 +274,10 @@ function _strides(shape)
     return S
 end
 
+function _strides(shape::Ty)::Ty where Ty <: Tuple
+    accumulate(*, (1,Base.front(shape)...))
+end
+
 # Dense operator version
 @generated function _ptrace(::Type{Val{RANK}}, a,
         shape_l, shape_r,
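The new tuple method computes the same column-major strides as the Vector-based _strides above, but without allocating. A quick sketch of the cumulative product it evaluates (example shapes are illustrative):

    # The stride of dimension i is prod(shape[1:i-1]);
    # accumulate(*, (1, Base.front(shape)...)) is that exclusive cumulative product as a tuple.
    _strides((2, 3, 4))  # (1, 2, 6)
    _strides((5,))       # (1,)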
74 changes: 50 additions & 24 deletions src/operators_lazytensor.jl
@@ -572,9 +572,11 @@ end
 function _gemm_recursive_dense_lazy(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
-        for I=1:size(op, 1)
-            result[I, K] += val*op[I, J]
+        if isa(op, AbstractVector)
+            result[K] += val*op[J]
+        else
+            for I=1:size(op, 1)
+                result[I, K] += val*op[I, J]
+            end
         end
         return nothing
@@ -609,7 +611,7 @@ end
 function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
         for I=1:size(op, 2)
             result[J, I] += val*op[K, I]
@@ -641,45 +643,69 @@ function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
     end
 end
 
-function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+"""
+    check_mul!_compatibility(R, A, B)
+Check that `R`, `A`, `B` are dimensionally compatible for `R .= A*B`, and that `R` is not aliased with either `A` or `B`.
+"""
+function check_mul!_compatibility(R::AbstractVecOrMat, A, B)
+    _check_mul!_aliasing_compatibility(R, A, B)
+    _check_mul!_dim_compatibility(size(R), size(A), size(B))
+end
+function _check_mul!_dim_compatibility(sizeR::Tuple, sizeA::Tuple, sizeB::Tuple)
+    # R .= A*B
+    if sizeA[2] != sizeB[1]
+        throw(DimensionMismatch(lazy"A has dimensions $sizeA but B has dimensions $sizeB. Can't do `A*B`"))
+    end
+    if sizeR != (sizeA[1], Base.tail(sizeB)...) # using tail to account for vectors
+        throw(DimensionMismatch(lazy"R has dimensions $sizeR but A*B has dimensions $((sizeA[1], Base.tail(sizeB)...)). Can't do `R.=A*B`"))
+    end
+end
+function _check_mul!_aliasing_compatibility(R, A, B)
+    if R===A || R===B
+        throw(ArgumentError(lazy"output matrix must not be aliased with input matrix"))
+    end
+end
+
+
+function _gemm_puresparse(alpha, op::AbstractArray, h::LazyTensor{B1,B2,F,I,T}, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    if op isa AbstractVector
+        # _gemm_recursive_dense_lazy will treat `op` as a `Bra`
+        _check_mul!_aliasing_compatibility(result, op, h)
+        _check_mul!_dim_compatibility(size(result), reverse(size(h)), size(op))
+    else
+        check_mul!_compatibility(result, op, h)
+    end
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_r.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
     _gemm_recursive_dense_lazy(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end
 
-function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::Matrix, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::AbstractArray, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    check_mul!_compatibility(result, h, op)
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_l.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
    _gemm_recursive_lazy_dense(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end
 
+function _get_shape_and_strides(h)
+    shape_l, shape_r = _comp_size(h.basis_l), _comp_size(h.basis_r)
+    shape = min.(shape_l, shape_r)
+    strides_j, strides_k = _strides(shape_l), _strides(shape_r)
+    return shape, strides_j, strides_k
+end
+
 _mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
 _mul_puresparse!(result::DenseOpType{B1,B3},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)
+_mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a, b.data, beta, result.data); result)
+_mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a.data, b, beta, result.data); result)
-
-function _mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
-    b_data = reshape(b.data, length(b.data), 1)
-    result_data = reshape(result.data, length(result.data), 1)
-    _gemm_puresparse(alpha, a, b_data, beta, result_data)
-    result
-end
-
-function _mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
-    a_data = reshape(a.data, 1, length(a.data))
-    result_data = reshape(result.data, 1, length(result.data))
-    _gemm_puresparse(alpha, a_data, b, beta, result_data)
-    result
-end
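To make the new checks and vector fast paths concrete, a small usage sketch (basis sizes chosen to match the tests below; Ket(b)/Bra(b) zero-state constructors, dagger, and the 5-arg mul! are assumed from the QuantumOpticsBase API):

    using QuantumOpticsBase, SparseArrays, LinearAlgebra

    b1, b2 = NLevelBasis(2), NLevelBasis(3)
    h = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))  # size(h) == (4, 9)

    ψ = randstate(b2^2)          # length-9 Ket
    out = Ket(b1^2)              # length-4 zero Ket
    mul!(out, h, ψ, 1.0, 0.0)    # Ket path: b.data is passed as a plain vector, no reshape

    φ = dagger(randstate(b1^2))  # length-4 Bra, read by the kernel as a row vector
    outb = Bra(b2^2)             # length-9 zero Bra
    mul!(outb, φ, h, 1.0, 0.0)   # Bra path: dimensions checked against reverse(size(h))

    # check_mul!_compatibility also rejects aliasing, so passing the same array
    # as both input and output throws an ArgumentError.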
10 changes: 10 additions & 0 deletions test/test_operators.jl
@@ -130,4 +130,14 @@ op12 = destroy(bfock)⊗sigmap(bspin)
 @test embed(b, [1,2], op12) == destroy(bfock)⊗sigmap(bspin)⊗one(bspin)
 @test embed(b, [1,3], op12) == destroy(bfock)⊗one(bspin)⊗sigmap(bspin)
 
+# size of AbstractOperator
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test size(Lop1) == size(dense(Lop1)) == size(dense(Lop1).data)
+@test all(size(Lop1, k) == size(dense(Lop1), k) for k=1:4)
+@test_throws ErrorException size(Lop1, 0)
+@test_throws ErrorException size(Lop1, -1)
+@test_throws ErrorException size(dense(Lop1), 0) # check for consistency
+@test_throws ErrorException size(dense(Lop1), -1)
+
 end # testset
8 changes: 8 additions & 0 deletions test/test_operators_lazytensor.jl
@@ -404,5 +404,13 @@ dop = randoperator(b3a⊗b3b, b2a⊗b2b)
 @test dop*lop' ≈ Operator(dop.basis_l, lop.basis_l, dop.data*dense(lop).data')
 @test lop*dop' ≈ Operator(lop.basis_l, dop.basis_l, dense(lop).data*dop.data')
 
+# Dimension mismatches for LazyTensor with sparse
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test_throws DimensionMismatch Lop1*Lop1
+@test_throws DimensionMismatch dense(Lop1)*Lop1
+@test_throws DimensionMismatch sparse(Lop1)*Lop1
+@test_throws DimensionMismatch Lop1*dense(Lop1)
+@test_throws DimensionMismatch Lop1*sparse(Lop1)
+
 end # testset
