From 512b5920d13dbc70121c52ba8c194bb7589b02fb Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 26 Sep 2023 13:49:41 +0000 Subject: [PATCH] Improve TessellateIPU `gather` support. Now fully support gather on the first axis, for any input shape. --- tessellate_ipu/lax/tile_lax_gather.py | 31 +++++++++++++++++++-------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tessellate_ipu/lax/tile_lax_gather.py b/tessellate_ipu/lax/tile_lax_gather.py index 33c46e3..9081c4a 100644 --- a/tessellate_ipu/lax/tile_lax_gather.py +++ b/tessellate_ipu/lax/tile_lax_gather.py @@ -22,15 +22,22 @@ def make_gather_vertex_fullname(dtype: DTypeLike) -> str: return make_ipu_vertex_name_templated(basename, dtype) -def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers): +def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers, inshape: Tuple[int]): """Check `gather` dimension_numbers is supported on TessellateIPU. At the moment: basically only supporting a single configuration! We need to expand on this at some point! """ - dim_numbers_default = GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) - if dimension_numbers != dim_numbers_default: - raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.") + if dimension_numbers.start_index_map != (0,): + raise NotImplementedError( + f"TessellateIPU `gather` only supports start index map (0,), not {dimension_numbers}." + ) + if dimension_numbers.collapsed_slice_dims != (0,): + raise NotImplementedError( + f"TessellateIPU `gather` only supports collapse slice dims (0,), not {dimension_numbers}." + ) + if dimension_numbers.offset_dims != tuple(range(1, len(inshape))): + raise NotImplementedError(f"TessellateIPU only supports `gather` on the first axis. Not {dimension_numbers}.") def ipu_gather_primitive_translation( @@ -67,14 +74,20 @@ def ipu_gather_primitive_translation( fill_value = attributes.get("fill_value", None) # Check gather attributes are supported by TessellateIPU. - assert operand.ndim == 1 - assert start_indices.ndim == 2 - assert slice_sizes == (1,) + assert start_indices.ndim == 2, "Only supporting gather indices of shape (N, 1)" assert ( mode == GatherScatterMode.PROMISE_IN_BOUNDS ), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU." assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices." - check_gather_dimension_numbers(dimension_numbers) + check_gather_dimension_numbers(dimension_numbers, operand.shape) + # Expected slice sizes if gather on first axis. + assert operand.ndim == len(slice_sizes) + expected_slice_sizes = (1, *operand.shape[1:]) + if slice_sizes != expected_slice_sizes: + raise NotImplementedError( + f"TessellateIPU only supports `gather` on the first axis, i.e. with slice sizes {expected_slice_sizes}, not {slice_sizes}." + ) + # Gather output aval. outaval = p.abstract_eval( *inavals, @@ -92,7 +105,7 @@ def ipu_gather_primitive_translation( baseOffset=0, # unused? numBaseElements=operand.size, # Number of elements in input. maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)), - regionSize=1, # TODO: understand? + regionSize=np.prod(slice_sizes), # Total slice size. splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # TODO: should we use `split offsets` between threads?