From 579a0d49b5a944ad9ca61fe62723aebcc6895724 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 26 Sep 2023 13:49:41 +0000 Subject: [PATCH] Improve TessellateIPU `gather` support. Now fully supporting `gather` on the first axis, for any input shape. Should allow to programmatically implement GPT-like on device embeddings using TessellateIPU. --- tessellate_ipu/lax/tile_lax_gather.py | 39 +++++++++++++++++++-------- tests/lax/test_tile_lax_gather.py | 24 ++++++++++------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/tessellate_ipu/lax/tile_lax_gather.py b/tessellate_ipu/lax/tile_lax_gather.py index 33c46e3..c9d9506 100644 --- a/tessellate_ipu/lax/tile_lax_gather.py +++ b/tessellate_ipu/lax/tile_lax_gather.py @@ -22,15 +22,25 @@ def make_gather_vertex_fullname(dtype: DTypeLike) -> str: return make_ipu_vertex_name_templated(basename, dtype) -def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers): +def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers, inshape: Tuple[int]): """Check `gather` dimension_numbers is supported on TessellateIPU. At the moment: basically only supporting a single configuration! We need to expand on this at some point! """ - dim_numbers_default = GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) - if dimension_numbers != dim_numbers_default: - raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.") + if dimension_numbers.start_index_map != (0,): + raise NotImplementedError( + f"TessellateIPU `gather` only supports `start_index_map` (0,), not {dimension_numbers}." + ) + if dimension_numbers.collapsed_slice_dims != (0,): + raise NotImplementedError( + f"TessellateIPU `gather` only supports `collapse_slice_dims` (0,), not {dimension_numbers}." + ) + expected_offset_dims = tuple(range(1, len(inshape))) + if dimension_numbers.offset_dims != expected_offset_dims: + raise NotImplementedError( + f"TessellateIPU only supports `gather` on the first axis. Expecting `offset_dims` {expected_offset_dims}, not {dimension_numbers}." + ) def ipu_gather_primitive_translation( @@ -52,7 +62,7 @@ def ipu_gather_primitive_translation( IPU tile map primitive structure. """ # TODO: query for JAX device. - num_context_workers = 6 + # num_context_workers = 6 assert len(inavals) == 2 assert attributes is not None @@ -67,14 +77,20 @@ def ipu_gather_primitive_translation( fill_value = attributes.get("fill_value", None) # Check gather attributes are supported by TessellateIPU. - assert operand.ndim == 1 - assert start_indices.ndim == 2 - assert slice_sizes == (1,) + assert start_indices.ndim == 2, "Only supporting gather indices of shape (N, 1)" assert ( mode == GatherScatterMode.PROMISE_IN_BOUNDS ), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU." assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices." - check_gather_dimension_numbers(dimension_numbers) + check_gather_dimension_numbers(dimension_numbers, operand.shape) + # Expected slice sizes if gather on first axis. + assert operand.ndim == len(slice_sizes) + expected_slice_sizes = (1, *operand.shape[1:]) + if slice_sizes != expected_slice_sizes: + raise NotImplementedError( + f"TessellateIPU only supports `gather` on the first axis, i.e. with slice sizes {expected_slice_sizes}, not {slice_sizes}." + ) + # Gather output aval. outaval = p.abstract_eval( *inavals, @@ -91,8 +107,9 @@ def ipu_gather_primitive_translation( attrs_i32, attrs_f32 = make_ipu_vertex_attributes( baseOffset=0, # unused? numBaseElements=operand.size, # Number of elements in input. - maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)), - regionSize=1, # TODO: understand? + # maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)), + maxElementsPerWorker=operand.size, # Need more understanding here? + regionSize=np.prod(slice_sizes), # Total slice size. splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # TODO: should we use `split offsets` between threads? diff --git a/tests/lax/test_tile_lax_gather.py b/tests/lax/test_tile_lax_gather.py index 006abc1..52e7997 100644 --- a/tests/lax/test_tile_lax_gather.py +++ b/tests/lax/test_tile_lax_gather.py @@ -13,7 +13,7 @@ @pytest.mark.ipu_hardware -class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase): +class IpuTilePrimitivesLaxGatherHwTests(chex.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self.device = jax.devices("ipu")[0] @@ -22,18 +22,22 @@ def setUp(self): np.random.seed(123) @parameterized.parameters( - {"num_elements": 8, "num_indices": 3}, - {"num_elements": 8, "num_indices": 12}, - {"num_elements": 256, "num_indices": 512}, + {"data_shape": (8,), "num_indices": 3}, + {"data_shape": (8,), "num_indices": 12}, + {"data_shape": (256,), "num_indices": 512}, + {"data_shape": (256, 17), "num_indices": 123}, + {"data_shape": (256, 5, 3), "num_indices": 373}, ) - def test__tile_map__gather__jitting__proper_result(self, num_elements, num_indices): + def test__tile_map__gather__first_axis_cases__jitting__proper_result(self, data_shape, num_indices): tiles = (0,) - data = np.random.randn(num_elements).astype(np.float32) - indices = np.random.randint(low=0, high=num_elements, size=num_indices) + data = np.random.randn(*data_shape).astype(np.float32) + indices = np.random.randint(low=0, high=data_shape[0], size=num_indices) indices = indices.reshape(-1, 1).astype(np.uint32) - # Only supported configuration! - dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) + # First axis gather only supported configuration! + dim_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=tuple(range(1, len(data_shape))), collapsed_slice_dims=(0,), start_index_map=(0,) + ) def gather_fn(data, indices): data = tile_put_replicated(data, tiles) @@ -43,7 +47,7 @@ def gather_fn(data, indices): data, indices, dimension_numbers=dim_numbers, - slice_sizes=(1,), + slice_sizes=(1, *data_shape[1:]), mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS, unique_indices=False, indices_are_sorted=False,