Skip to content

Commit

Permalink
[NDTensors] Fix contracting dense with diag on GPU (#1453)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored May 31, 2024
1 parent e6cdf37 commit 99baf1d
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 15 deletions.
10 changes: 6 additions & 4 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -35,17 +34,19 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
NDTensorsAMDGPUExt = "AMDGPU"
NDTensorsCUDAExt = "CUDA"
NDTensorsAMDGPUExt = ["AMDGPU","GPUArraysCore"]
NDTensorsCUDAExt = ["CUDA","GPUArraysCore"]
NDTensorsGPUArraysCoreExt = "GPUArraysCore"
NDTensorsHDF5Ext = "HDF5"
NDTensorsMetalExt = "Metal"
NDTensorsMetalExt = ["GPUArraysCore", "Metal"]
NDTensorsOctavianExt = "Octavian"
NDTensorsTBLISExt = "TBLIS"
NDTensorscuTENSORExt = "cuTENSOR"
Expand Down Expand Up @@ -90,6 +91,7 @@ julia = "1.6"
[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsAMDGPUExt/append.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using GPUArraysCore: @allowscalar
using AMDGPU: ROCArray
using GPUArraysCore: @allowscalar
using NDTensors.Expose: Exposed, unexpose

## Warning this append function uses scalar indexing and is therefore extremely slow
Expand Down
4 changes: 2 additions & 2 deletions NDTensors/ext/NDTensorsAMDGPUExt/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using NDTensors.Expose: Exposed, expose, parent, unexpose
using NDTensors.GPUArraysCoreExtensions: cpu
using AMDGPU: AMDGPU, ROCArray
using GPUArraysCore: @allowscalar
using NDTensors.Expose: Exposed, expose, parent, unexpose
using NDTensors.GPUArraysCoreExtensions: cpu

function Base.getindex(E::Exposed{<:ROCArray})
return @allowscalar unexpose(E)[]
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/append.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using GPUArraysCore: @allowscalar
using CUDA: CuArray
using GPUArraysCore: @allowscalar
using NDTensors.Expose: Exposed, unexpose

## Warning this append function uses scalar indexing and is therefore extremely slow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Package extension activated when GPUArraysCore is loaded (see the
# `NDTensorsGPUArraysCoreExt = "GPUArraysCore"` entry in NDTensors/Project.toml).
# Currently it only provides the generic GPU dense/diag contraction overloads.
module NDTensorsGPUArraysCoreExt
include("contract.jl")
end
48 changes: 48 additions & 0 deletions NDTensors/ext/NDTensorsGPUArraysCoreExt/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Adapt: adapt
using GPUArraysCore: AbstractGPUArray
using NDTensors: NDTensors, DenseTensor, DiagTensor, contract!, dense, inds, Tensor
using NDTensors.Expose: Exposed, expose, unexpose
using NDTensors.TypeParameterAccessors: parenttype, set_ndims

## In this function we convert the DiagTensor to a dense tensor and
## Feed it back into contract
"""
    contract!(output, labelsout, diag::Exposed{<:Number,<:DiagTensor}, labels1,
              dense_gpu::Exposed{<:AbstractGPUArray,<:DenseTensor}, labels2, α, β)

Contract a CPU-resident `DiagTensor` against a GPU-resident `DenseTensor`.
The diagonal tensor is first densified and adapted to the same GPU storage
type as `tensor2`, then the generic dense-dense `contract!` is re-dispatched.
"""
function NDTensors.contract!(
  output_tensor::Exposed{<:AbstractGPUArray,<:DenseTensor},
  labelsoutput_tensor,
  tensor1::Exposed{<:Number,<:DiagTensor},
  labelstensor1,
  tensor2::Exposed{<:AbstractGPUArray,<:DenseTensor},
  labelstensor2,
  α::Number=one(Bool),
  β::Number=zero(Bool),
)
  ## Densify the diagonal tensor and move it onto the GPU, using the
  ## one-dimensional variant of `tensor2`'s parent array type as the target.
  ## TODO: this allocates on CPU first and then transfers, which could be slow.
  gpu_storagetype = set_ndims(parenttype(typeof(tensor2)), 1)
  densified = adapt(gpu_storagetype, dense(unexpose(tensor1)))
  ## Re-dispatch now that both operands are dense GPU tensors.
  return contract!(
    output_tensor,
    labelsoutput_tensor,
    expose(densified),
    labelstensor1,
    tensor2,
    labelstensor2,
    α,
    β,
  )
end

## Dense(GPU) × Diag(CPU): contraction is symmetric in the operand/label
## pairs, so simply swap the operands and let the Diag-first method handle
## the densification and GPU transfer.
function NDTensors.contract!(
  output_tensor::Exposed{<:AbstractGPUArray,<:DenseTensor},
  labelsoutput_tensor,
  tensor1::Exposed{<:AbstractGPUArray,<:DenseTensor},
  labelstensor1,
  tensor2::Exposed{<:Number,<:DiagTensor},
  labelstensor2,
  α::Number=one(Bool),
  β::Number=zero(Bool),
)
  return contract!(
    output_tensor,
    labelsoutput_tensor,
    tensor2,
    labelstensor2,
    tensor1,
    labelstensor1,
    α,
    β,
  )
end
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsMetalExt/permutedims.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Metal: MtlArray
using GPUArraysCore: @allowscalar
using NDTensors.Expose: Exposed, expose, unexpose
## Theres an issue in metal that `ReshapedArray' wrapped arrays cannot be permuted using
## permutedims (failing in that Metal uses scalar indexing)
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,9 @@ function contract!(
# Overwrite the block of R
β = zero(ElR)
end
contract!(Rblock, labelsR, T1block, labelsT1, T2block, labelsT2, α, β)
contract!(
expose(Rblock), labelsR, expose(T1block), labelsT1, expose(T2block), labelsT2, α, β
)
end
return R
end
Expand Down
1 change: 0 additions & 1 deletion NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ using Base.Threads
using Compat
using Dictionaries
using Folds
using GPUArraysCore
using InlineStrings
using Random
using LinearAlgebra
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using GPUArraysCore: AbstractGPUArray, @allowscalar
using NDTensors.Expose: Exposed, unexpose
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors, type_parameter, set_type_parameter
Expand Down
14 changes: 11 additions & 3 deletions NDTensors/test/test_diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
using NDTensors
using Test: @testset, @test, @test_throws
using GPUArraysCore: @allowscalar
using Adapt: adapt
include("NDTensorsTestUtils/NDTensorsTestUtils.jl")
using .NDTensorsTestUtils: devices_list, is_supported_eltype
using LinearAlgebra: dot

@testset "DiagTensor basic functionality" begin
@testset "test device: $dev" for dev in devices_list(copy(ARGS)),
Expand Down Expand Up @@ -56,13 +58,19 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test @allowscalar S1 S2
end
end
@testset "DiagTensor contractions" begin
t = tensor(Diag([1.0, 1.0, 1.0]), (3, 3))
A = randomTensor(Dense, (3, 3))
@testset "DiagTensor contractions" for dev in devices_list(copy(ARGS))
  ## TODO add more GPU tests
  ## Metal (mtl) only supports single precision, so drop to Float32 there.
  elt = (dev == NDTensors.mtl ? Float32 : Float64)
  t = tensor(Diag(elt[1.0, 1.0, 1.0]), (3, 3))
  A = randomTensor(Dense{elt}, (3, 3))

  ## Identity-diagonal contractions against dense tensors on CPU.
  @test contract(t, (1, -2), t, (-2, 3)) == t
  @test contract(A, (1, -2), t, (-2, 3)) == A
  @test contract(A, (-2, 1), t, (-2, 3)) == transpose(A)

  ## Testing sparse contractions on GPU
  ## Full contraction of a uniform Diag with a dense tensor should match `dot`.
  t = tensor(Diag(one(elt)), (3, 3))
  @test contract(t, (-1, -2), dev(A), (-1, -2))[] ≈ dot(t, A) rtol = sqrt(eps(elt))
end
nothing
end
39 changes: 38 additions & 1 deletion NDTensors/test/test_diagblocksparse.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
@eval module $(gensym())
using Dictionaries: Dictionary
using GPUArraysCore: @allowscalar
using NDTensors:
NDTensors,
Block,
BlockSparseTensor,
Diag,
DiagBlockSparse,
Tensor,
blockoffsets,
contract,
dense,
inds,
nzblocks
using Random: randn!
using Test: @test, @test_throws, @testset
using Test: @test, @test_broken, @test_throws, @testset
@testset "UniformDiagBlockSparseTensor basic functionality" begin
NeverAlias = NDTensors.NeverAlias
AllowAlias = NDTensors.AllowAlias
Expand Down Expand Up @@ -48,4 +52,37 @@ end
@test_throws ErrorException contract(a1, labels1, a2, labels2)
end
end

include("NDTensorsTestUtils/NDTensorsTestUtils.jl")
using .NDTensorsTestUtils: devices_list
@testset "DiagBlockSparse contract" for dev in devices_list(copy(ARGS))
  ## Metal (mtl) only supports single precision, so drop to Float32 there.
  elt = dev == NDTensors.mtl ? Float32 : Float64
  A = dev(BlockSparseTensor{elt}([(1, 1), (2, 2)], [2, 2], [2, 2]))
  randn!(A)
  ## Uniform diagonal block-sparse tensor sharing `A`'s block structure,
  ## plus an equivalent dense-storage Diag for a reference contraction.
  t = Tensor(DiagBlockSparse(one(elt), blockoffsets(A)), inds(A))
  tdense = Tensor(Diag(one(elt)), inds(A))

  a = dense(contract(A, (1, -2), t, (3, -2)))
  b = contract(dense(A), (1, -2), tdense, (3, -2))
  @test @allowscalar a ≈ b

  a = dense(contract(A, (-2, 1), t, (-2, 3)))
  b = contract(dense(A), (-2, 1), tdense, (-2, 3))
  @test @allowscalar a ≈ b

  a = contract(A, (-1, -2), t, (-1, -2))[]
  b = contract(dense(A), (-1, -2), tdense, (-1, -2))[]
  @test @allowscalar a ≈ b

  ## TODO fix these kinds of contractions (non-uniform block dimensions)
  A = BlockSparseTensor{elt}([(1, 1), (2, 2)], [3, 2, 3], [2, 2])
  randn!(A)
  t = Tensor(DiagBlockSparse(one(elt), blockoffsets(A)), inds(A))
  @test_broken dense(contract(A, (1, -2), (t), (3, -2))) ≈
    contract(dense(A), (1, -2), dense(t), (3, -2))
  @test_broken dense(contract(A, (-2, 1), t, (-2, 3))) ≈
    contract(dense(A), (-2, 1), dense(t), (-2, 3))
  @test_broken contract(dev(A), (-1, -2), dev(t), (-1, -2))[] ≈
    contract(dense(A), (-1, -2), dense(t), (-1, -2))[]
end
end

0 comments on commit 99baf1d

Please sign in to comment.