Skip to content

Commit

Permalink
Support iota LAX operation in TessellateIPU (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap authored Sep 27, 2023
1 parent bf3e8e1 commit 51e90a0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
| `imag` | :x: | :x: | |
| `index_in_dim` | :x: | :x: | |
| `index_take` | :x: | :x: | |
| `iota` | :x: | :x: | |
| `iota` | :white_check_mark: | n/a | |
| `is_finite` | :white_check_mark: | :x: | |
| `le` | :white_check_mark: | :x: | |
| `lt` | :white_check_mark: | :x: | |
Expand Down
47 changes: 47 additions & 0 deletions tessellate_ipu/lax/tile_lax_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Any, Dict, List, Tuple

import numpy as np
from jax import lax
from jax._src.lax.lax import copy_p
from jax.core import Primitive, ShapedArray
Expand All @@ -13,6 +14,7 @@
get_ipu_tile_primitive_translation,
get_ipu_type_name,
make_ipu_vertex_attributes,
make_ipu_vertex_constant_info,
make_ipu_vertex_in_info,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
Expand Down Expand Up @@ -232,6 +234,51 @@ def ipu_integer_pow_translation(
register_ipu_tile_primitive(lax.integer_pow_p, ipu_integer_pow_translation)


def ipu_iota_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `iota` primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 0
assert attributes is not None
print(attributes)
dtype = attributes["dtype"]
dimension = int(attributes["dimension"])
shape = attributes["shape"]

assert dimension == 0
assert len(shape) == 1
# Iota vertex in/outs
vname = make_ipu_vertex_name_templated("popops::Iota", dtype)
outaval = p.abstract_eval(dtype=dtype, dimension=dimension, shape=shape)[0]
inputs_info = [make_ipu_vertex_constant_info("offsets", np.array([0], dtype=dtype))]
outputs_info = [make_ipu_vertex_out_info("out", outaval, vertex_dim2=shape[0])]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


register_ipu_tile_primitive(lax.iota_p, ipu_iota_translation)


# On tile (mem)copy primitive.
def ipu_tile_memcpy(
p: Primitive,
Expand Down
23 changes: 23 additions & 0 deletions tests/lax/test_tile_lax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,29 @@ def compute_fn(input):
assert output_ipu.shape == inshape
npt.assert_array_almost_equal(output_ipu, output_cpu, decimal=2)

@parameterized.parameters(
[
(np.int32,),
]
)
def test__tile_map__iota__ipu_jitting__proper_result(self, dtype):
tiles = (3, 4, 5)
N = 64

def compute_fn():
return tile_map(lax.iota_p, dtype=dtype, dimension=0, shape=(N,), tiles=tiles)

# compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

output_ipu = compute_fn_ipu()
expected_output = np.stack([np.arange(0, N, dtype=dtype)] * len(tiles))
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == dtype
assert output_ipu.shape == (len(tiles), N)
npt.assert_array_equal(output_ipu, expected_output)


@pytest.mark.ipu_hardware
class IpuTileUnaryPrimitiveHwTests(chex.TestCase):
Expand Down

0 comments on commit 51e90a0

Please sign in to comment.