Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BlockSparseArrays] Improve the design of block views #1481

Merged
merged 9 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.18"
version = "0.3.19"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), r)
a[Block(1, 1)] = randn(size(a[Block(1, 1)]))
a[Block(2, 2)] = randn(size(a[Block(2, 2)]))
a[Block(1, 1)] = randn(elt, size(a[Block(1, 1)]))
a[Block(2, 2)] = randn(elt, size(a[Block(2, 2)]))
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using BlockArrays:
BlockRange,
BlockedUnitRange,
BlockVector,
BlockSlice,
block,
blockaxes,
blockedrange,
Expand All @@ -29,6 +30,36 @@ function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockRange{1}})
return sub_axis(a, indices.block)
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}})
return sub_axis(a, Block(indices))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockIndexRange{1}})
return sub_axis(a, indices.block)
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::Block)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockIndexRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block})
Expand Down Expand Up @@ -131,6 +162,14 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

function blockrange(axis::AbstractUnitRange, r::Block{1})
return r:r
end

function blockrange(axis::AbstractUnitRange, r::BlockIndexRange)
return Block(r):Block(r)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Make more generic for GPU.
a_dest = BlockSparseArray{eltype(a)}(axes)
a_dest .= a
return a_dest
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
# TODO: Make more generic for GPU.
a_dest = Array{eltype(a)}(undef, length.(axes))
a_dest .= a
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ function SparseArrayInterface.sparse_map!(
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
block_dest = @view a_dest[_block(BI_dest)]
block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
# TODO: Use `map!!` to handle immutable blocks.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using BlockArrays: BlockIndexRange, BlockRange, BlockSlice, block

function blocksparse_view(a::AbstractArray, index::Block)
return blocks(a)[Int.(Tuple(index))...]
end

# TODO: Define `AnyBlockSparseVector`.
function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N}
return blocksparse_view(a, index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,20 @@ function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitR
return ArrayLayouts.layout_getindex(a, I...)
end

function Base.isassigned(a::BlockSparseArrayLike, index::Vararg{Block})
function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, block::Block{N}) where {N}
return blocksparse_getindex(a, block)
end
function Base.getindex(
a::BlockSparseArrayLike{<:Any,N}, block::Vararg{Block{1},N}
) where {N}
return blocksparse_getindex(a, block...)
end

# TODO: Define `issasigned(a, ::Block{N})`.
function Base.isassigned(
a::BlockSparseArrayLike{<:Any,N}, index::Vararg{Block{1},N}
) where {N}
# TODO: Define `blocksparse_isassigned`.
return isassigned(blocks(a), Int.(index)...)
end

Expand All @@ -64,6 +77,12 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N
return a
end

function Base.setindex!(
a::BlockSparseArrayLike{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
a[Block(Int.(I))] = value
return a
end
function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::Block{N}) where {N}
blocksparse_setindex!(a, value, I)
return a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where
return a[findblockindex.(axes(a), I)...]
end

function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return blocksparse_getindex(a, Tuple(I)...)
end
function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
# TODO: Avoid copy if the block isn't stored.
return copy(blocks(a)[Int.(I)...])
end

# TODO: Implement as `copy(@view a[I...])`, which is then implemented
# through `ArrayLayouts.sub_materialize`.
using ..SparseArrayInterface: set_getindex_zero_function
Expand Down Expand Up @@ -59,21 +67,41 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N
end

function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N}) where {N}
a_b = view(a, block(I))
i = Int.(Tuple(block(I)))
a_b = blocks(a)[i...]
a_b[I.α...] = value
# Set the block, required if it is structurally zero
a[block(I)] = a_b
# Set the block, required if it is structurally zero.
blocks(a)[i...] = a_b
return a
end

function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) where {N}
# TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`.
i = I.n
blocksparse_setindex!(a, value, Tuple(I)...)
return a
end
function blocksparse_setindex!(
a::AbstractArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
i = Int.(I)
@boundscheck blockcheckbounds(a, i...)
# TODO: Use `blocksizes(a)[i...]` when we upgrade to
# BlockArrays.jl v1.
if size(value) ≠ size(view(a, I...))
return throw(
DimensionMismatch("Trying to set a block with an array of the wrong size.")
)
end
blocks(a)[i...] = value
return a
end

function blocksparse_view(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return blocksparse_view(a, Tuple(I)...)
end
function blocksparse_view(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
return SubArray(a, to_indices(a, I))
end

function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
# TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`.
i = I.n
Expand Down
54 changes: 39 additions & 15 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra: mul!
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored, block_reshape
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @testset, @test_broken
using Test: @test, @test_broken, @test_throws, @testset
include("TestBlockSparseArraysUtils.jl")
@testset "BlockSparseArrays (eltype=$elt)" for elt in
(Float32, Float64, ComplexF32, ComplexF64)
Expand All @@ -20,6 +20,7 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(a) == 0
@test iszero(a)
@test all(I -> iszero(a[I]), eachindex(a))
@test_throws DimensionMismatch a[Block(1, 1)] = randn(elt, 2, 3)

a = BlockSparseArray{elt}([2, 3], [2, 3])
a[3, 3] = 33
Expand Down Expand Up @@ -225,36 +226,59 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(c) == 2
@test Array(c) == 2 * transpose(Array(a))

## Broken, need to fix.

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
@test_broken a[Block(1), Block(1):Block(2)]
b = a[Block(1), Block(1):Block(2)]
@test size(b) == (2, 7)
@test blocksize(b) == (1, 2)
@test b[Block(1, 1)] == a[Block(1, 1)]
@test b[Block(1, 2)] == a[Block(1, 2)]

# This is outputting only zero blocks.
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[Block(2):Block(2), Block(1):Block(2)]
@test_broken block_nstored(b) == 1
@test_broken b == Array(a)[3:5, 1:end]
b = copy(a)
x = randn(elt, size(@view(a[Block(2, 2)])))
b[Block(2), Block(2)] = x
@test b[Block(2, 2)] == x

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
x = randn(size(@view(a[Block(2, 2)])))
b[Block(2), Block(2)] = x
@test_broken b[Block(2, 2)] == x
b[Block(1, 1)] .= 1
# TODO: Use `blocksizes(b)[1, 1]` once we upgrade to
# BlockArrays.jl v1.
@test b[Block(1, 1)] == trues(size(@view(b[Block(1, 1)])))

# Doesnt' set the block
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
x = randn(elt, 1, 2)
@view(a[Block(2, 2)])[1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
x = randn(elt, 1, 2)
@views a[Block(2, 2)][1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x

## Broken, need to fix.

# This is outputting only zero blocks.
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
b[Block(1, 1)] .= 1
@test_broken b[1, 1] == trues(size(@view(b[1, 1])))
b = a[Block(2):Block(2), Block(1):Block(2)]
@test_broken block_nstored(b) == 1
@test_broken b == Array(a)[3:5, 1:end]
end
@testset "LinearAlgebra" begin
a1 = BlockSparseArray{elt}([2, 3], [2, 3])
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ function blockedunitrange_getindices(
return mortar(map(index -> a[index], indices))
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices::Block{1})
return a[indices]
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::BlockedUnitRange, indices)
return error("Not implemented.")
Expand Down
11 changes: 11 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ function unitrangedual_getindices_blocks(a, indices)
return mortar([dual(b) for b in blocks(a_indices)])
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1})
return a[indices]
end

function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}})
return unitrangedual_getindices_blocks(a, indices)
end
Expand All @@ -54,6 +59,12 @@ function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange)
return BlockSlice(b, unlabel(a))
end

using BlockArrays: BlockArrays, BlockSlice
using NDTensors.GradedAxes: UnitRangeDual, dual
function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual)
return BlockSlice(b, dual(r))
end

using NDTensors.LabelledNumbers: LabelledNumbers, label
LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a)))

Expand Down
29 changes: 29 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,32 @@ Base.:-(x::LabelledInteger) = labelled_minus(x)
# TODO: This is only needed for older Julia versions, like Julia 1.6.
# Delete once we drop support for older Julia versions.
Base.hash(x::LabelledInteger, h::UInt64) = labelled_hash(x, h)

using Random: AbstractRNG, default_rng
default_eltype() = Float64
for f in [:rand, :randn]
@eval begin
function Base.$f(
rng::AbstractRNG,
elt::Type{<:Number},
dims::Tuple{LabelledInteger,Vararg{LabelledInteger}},
)
return a = $f(rng, elt, unlabel.(dims))
end
function Base.$f(
rng::AbstractRNG,
elt::Type{<:Number},
dim1::LabelledInteger,
dims::Vararg{LabelledInteger},
)
return $f(rng, elt, (dim1, dims...))
end
Base.$f(elt::Type{<:Number}, dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
$f(default_rng(), elt, dims)
Base.$f(elt::Type{<:Number}, dim1::LabelledInteger, dims::Vararg{LabelledInteger}) =
$f(elt, (dim1, dims...))
Base.$f(dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
$f(default_eltype(), dims)
Base.$f(dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = $f((dim1, dims...))
end
end
Loading
Loading