Skip to content

Commit

Permalink
Improve TessellateIPU gather support.
Browse files Browse the repository at this point in the history
Now fully support gather on the first axis, for any input shape.
  • Loading branch information
balancap committed Sep 26, 2023
1 parent b7bb361 commit 512b592
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions tessellate_ipu/lax/tile_lax_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,22 @@ def make_gather_vertex_fullname(dtype: DTypeLike) -> str:
return make_ipu_vertex_name_templated(basename, dtype)


def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers):
def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers, inshape: Tuple[int]):
"""Check `gather` dimension_numbers is supported on TessellateIPU.
At the moment: basically only supporting a single configuration!
We need to expand on this at some point!
"""
dim_numbers_default = GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
if dimension_numbers != dim_numbers_default:
raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.")
if dimension_numbers.start_index_map != (0,):
raise NotImplementedError(
f"TessellateIPU `gather` only supports start index map (0,), not {dimension_numbers}."
)
if dimension_numbers.collapsed_slice_dims != (0,):
raise NotImplementedError(
f"TessellateIPU `gather` only supports collapse slice dims (0,), not {dimension_numbers}."
)
if dimension_numbers.offset_dims != tuple(range(1, len(inshape))):
raise NotImplementedError(f"TessellateIPU only supports `gather` on the first axis. Not {dimension_numbers}.")


def ipu_gather_primitive_translation(
Expand Down Expand Up @@ -67,14 +74,20 @@ def ipu_gather_primitive_translation(
fill_value = attributes.get("fill_value", None)

# Check gather attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert start_indices.ndim == 2
assert slice_sizes == (1,)
assert start_indices.ndim == 2, "Only supporting gather indices of shape (N, 1)"
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU."
assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices."
check_gather_dimension_numbers(dimension_numbers)
check_gather_dimension_numbers(dimension_numbers, operand.shape)
# Expected slice sizes if gather on first axis.
assert operand.ndim == len(slice_sizes)
expected_slice_sizes = (1, *operand.shape[1:])
if slice_sizes != expected_slice_sizes:
raise NotImplementedError(
f"TessellateIPU only supports `gather` on the first axis, i.e. with slice sizes {expected_slice_sizes}, not {slice_sizes}."
)

# Gather output aval.
outaval = p.abstract_eval(
*inavals,
Expand All @@ -92,7 +105,7 @@ def ipu_gather_primitive_translation(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)),
regionSize=1, # TODO: understand?
regionSize=np.prod(slice_sizes), # Total slice size.
splitSingleRegion=False, # Split regions between threads? TODO: understand!
)
# TODO: should we use `split offsets` between threads?
Expand Down

0 comments on commit 512b592

Please sign in to comment.