diff --git a/docs/operations.md b/docs/operations.md index 9647eef..b009d2d 100644 --- a/docs/operations.md +++ b/docs/operations.md @@ -70,7 +70,7 @@ | `imag` | :x: | :x: | | | `index_in_dim` | :x: | :x: | | | `index_take` | :x: | :x: | | -| `iota` | :x: | :x: | | +| `iota` | :white_check_mark: | n/a | | | `is_finite` | :white_check_mark: | :x: | | | `le` | :white_check_mark: | :x: | | | `lt` | :white_check_mark: | :x: | | diff --git a/tessellate_ipu/lax/tile_lax_unary.py b/tessellate_ipu/lax/tile_lax_unary.py index 3cae2e1..377a74c 100644 --- a/tessellate_ipu/lax/tile_lax_unary.py +++ b/tessellate_ipu/lax/tile_lax_unary.py @@ -2,6 +2,7 @@ import os from typing import Any, Dict, List, Tuple +import numpy as np from jax import lax from jax._src.lax.lax import copy_p from jax.core import Primitive, ShapedArray @@ -13,6 +14,7 @@ get_ipu_tile_primitive_translation, get_ipu_type_name, make_ipu_vertex_attributes, + make_ipu_vertex_constant_info, make_ipu_vertex_in_info, make_ipu_vertex_inout_info, make_ipu_vertex_name_templated, @@ -232,6 +234,51 @@ def ipu_integer_pow_translation( register_ipu_tile_primitive(lax.integer_pow_p, ipu_integer_pow_translation) +def ipu_iota_translation( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU `iota` primitive translation rule to IPU vertex. + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Input shaped arrays. + attributes: (unused) attributes. + Returns: + IPU tile map primitive structure. + """ + assert len(inavals) == 0 + assert attributes is not None + print(attributes) + dtype = attributes["dtype"] + dimension = int(attributes["dimension"]) + shape = attributes["shape"] + + assert dimension == 0 + assert len(shape) == 1 + # Iota vertex in/outs + vname = make_ipu_vertex_name_templated("popops::Iota", dtype) + outaval = p.abstract_eval(dtype=dtype, dimension=dimension, shape=shape)[0] + inputs_info = [make_ipu_vertex_constant_info("offsets", np.array([0], dtype=dtype))] + outputs_info = [make_ipu_vertex_out_info("out", outaval, vertex_dim2=shape[0])] + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=inputs_info, + outputs_info=outputs_info, + attributes_i32=[], + attributes_f32=[], + ) + return ipu_prim_info + + +register_ipu_tile_primitive(lax.iota_p, ipu_iota_translation) + + # On tile (mem)copy primitive. def ipu_tile_memcpy( p: Primitive, diff --git a/tests/lax/test_tile_lax_basic.py b/tests/lax/test_tile_lax_basic.py index 6a70d04..dc91aee 100644 --- a/tests/lax/test_tile_lax_basic.py +++ b/tests/lax/test_tile_lax_basic.py @@ -192,6 +192,29 @@ def compute_fn(input): assert output_ipu.shape == inshape npt.assert_array_almost_equal(output_ipu, output_cpu, decimal=2) + @parameterized.parameters( + [ + (np.int32,), + ] + ) + def test__tile_map__iota__ipu_jitting__proper_result(self, dtype): + tiles = (3, 4, 5) + N = 64 + + def compute_fn(): + return tile_map(lax.iota_p, dtype=dtype, dimension=0, shape=(N,), tiles=tiles) + + # compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn) + compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) + + output_ipu = compute_fn_ipu() + expected_output = np.stack([np.arange(0, N, dtype=dtype)] * len(tiles)) + assert isinstance(output_ipu, TileShardedArray) + assert output_ipu.tiles == tiles + assert output_ipu.dtype == dtype + assert output_ipu.shape == (len(tiles), N) + npt.assert_array_equal(output_ipu, expected_output) + @pytest.mark.ipu_hardware class IpuTileUnaryPrimitiveHwTests(chex.TestCase):