diff --git a/docs/operations.md b/docs/operations.md index 31441a3..dc00268 100644 --- a/docs/operations.md +++ b/docs/operations.md @@ -30,7 +30,7 @@ | `broadcast_in_dim` | :x: | :x: | | | `cbrt` | :white_check_mark: | :white_check_mark: | | | `ceil` | :white_check_mark: | :white_check_mark: | | -| `clamp` | :x: | :x: | | +| `clamp` | :white_check_mark: | :x: | | | `collapse` | :x: | :x: | | | `complex` | :x: | :x: | | | `concatenate` | :white_check_mark: | :x: | | diff --git a/tessellate_ipu/lax/tile_lax_binary.py b/tessellate_ipu/lax/tile_lax_binary.py index 391c113..6530c1f 100644 --- a/tessellate_ipu/lax/tile_lax_binary.py +++ b/tessellate_ipu/lax/tile_lax_binary.py @@ -327,3 +327,49 @@ def ipu_select_primitive_translation( # Register JAX LAX select primitive. register_ipu_tile_primitive(lax.select_n_p, ipu_select_primitive_translation) + + +def ipu_clamp_primitive_translation( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU `clamp` LAX primitive translation rule to IPU vertex. + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Input shaped arrays. + attributes: (unused) attributes. + Returns: + IPU tile map primitive structure. + """ + assert len(inavals) == 3 + min, x, max = inavals + # A couple of initial checks! + assert max.shape == x.shape + assert min.shape == x.shape + + vname = make_ipu_vertex_name_templated("popops::Clamp", x.dtype) + # Note: using `vertex_dim2=1` as Select vertex expecting vector of vector. + inputs_info = [ + make_ipu_vertex_in_info("in2", min, vertex_dim2=1), + make_ipu_vertex_in_info("in1", x, vertex_dim2=1), + make_ipu_vertex_in_info("in3", max, vertex_dim2=1), + ] + outputs_info = [make_ipu_vertex_out_info("out", x, vertex_dim2=1)] + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=inputs_info, + outputs_info=outputs_info, + attributes_i32=[], + attributes_f32=[], + ) + return ipu_prim_info + + +# Register JAX LAX clamp primitive. +register_ipu_tile_primitive(lax.clamp_p, ipu_clamp_primitive_translation) diff --git a/tests/lax/test_tile_lax_basic.py b/tests/lax/test_tile_lax_basic.py index e849d59..4679c36 100644 --- a/tests/lax/test_tile_lax_basic.py +++ b/tests/lax/test_tile_lax_basic.py @@ -394,7 +394,7 @@ def compute_fn(A, B, sB): assert output.dtype == A.dtype npt.assert_array_almost_equal(output.array, scale_op_p.impl(A, B, sB), decimal=2) - @parameterized.parameters([np.float32]) + @parameterized.parameters([np.float32, np.float16]) def test__tile_map__select__ipu_jitting__proper_result(self, dtype): tiles = (3, 4, 5) inshape = (len(tiles), 7, 9) @@ -416,6 +416,28 @@ def compute_fn(mask, in0, in1): assert output.dtype == input0.dtype npt.assert_array_almost_equal(output.array, np.where(mask, input0, input1)) + @parameterized.parameters([np.float32, np.float16]) + def test__tile_map__clamp__ipu_jitting__proper_result(self, dtype): + tiles = (3, 4, 5) + inshape = (len(tiles), 7, 9) + min = np.random.randn(*inshape).astype(dtype) + input = np.random.randn(*inshape).astype(dtype) + max = np.random.randn(*inshape).astype(dtype) + + @partial(jax.jit, backend="ipu") + def compute_fn(min, x, max): + min = tile_put_sharded(min, tiles) + x = tile_put_sharded(x, tiles) + max = tile_put_sharded(max, tiles) + output = tile_map(lax.clamp_p, min, x, max) + return output + + output = compute_fn(min, input, max) + assert isinstance(output, TileShardedArray) + assert output.tiles == tiles + assert output.dtype == input.dtype + npt.assert_array_almost_equal(output.array, np.clip(input, min, max)) + class IpuTileShiftPrimitivesTests(chex.TestCase): def setUp(self):