Skip to content

Commit

Permalink
Support select operation in TessellateIPU (#28)
Browse files Browse the repository at this point in the history
Mapping `select` LAX operator to `popops::Select` vertex.
Note: still need to support broadcast in the future.
  • Loading branch information
balancap authored Sep 27, 2023
1 parent 5b0b37b commit 39d8d03
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
| `bitwise_and` | :white_check_mark: | :x: | |
| `bitwise_or` | :white_check_mark: | :x: | |
| `bitwise_xor` | :white_check_mark: | :x: | |
| `population_count` | :x: | :x: | |
| `population_count` | :white_check_mark: | :x: | |
| `broadcast` | :x: | :x: | |
| `broadcast_in_dim` | :x: | :x: | |
| `cbrt` | :white_check_mark: | :white_check_mark: | |
Expand Down Expand Up @@ -100,7 +100,7 @@
| `scatter_max` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_min` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_mul` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `select` | :x: | :x: | |
| `select` | :white_check_mark: | :x: | |
| `shift_left` | :white_check_mark: | :x: | |
| `shift_right_arithmetic`| :white_check_mark: | :x: | |
| `shift_right_logical` | :white_check_mark: | :x: | |
Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
population_count_p,
round_p,
rsqrt_p,
select_n_p,
sign_p,
sin_p,
sqrt_p,
Expand Down
49 changes: 49 additions & 0 deletions tessellate_ipu/lax/tile_lax_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
make_ipu_vertex_attributes,
make_ipu_vertex_in_info,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
primitive_clone,
register_ipu_tile_primitive,
Expand Down Expand Up @@ -278,3 +279,51 @@ def register_ipu_binary_inplace_tile_primitive(orig_prim):
pow_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.pow_p)
rem_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.rem_p)
sub_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.sub_p)


def ipu_select_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU select_n LAX primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 3
cond, x, y = inavals
# A couple of initial checks!
assert cond.shape == x.shape
assert cond.shape == y.shape
assert cond.dtype == np.bool_
assert x.dtype == y.dtype

vname = make_ipu_vertex_name_templated("popops::Select", x.dtype)
# Note: using `vertex_dim2=1` as Select vertex expecting vector of vector.
inputs_info = [
make_ipu_vertex_in_info("in3", cond, vertex_dim2=1),
make_ipu_vertex_in_info("in1", x, vertex_dim2=1),
make_ipu_vertex_in_info("in2", y, vertex_dim2=1),
]
outputs_info = [make_ipu_vertex_out_info("out", x, vertex_dim2=1)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX select primitive.
register_ipu_tile_primitive(lax.select_n_p, ipu_select_primitive_translation)
22 changes: 22 additions & 0 deletions tests/lax/test_tile_lax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,28 @@ def compute_fn(A, B, sB):
assert output.dtype == A.dtype
npt.assert_array_almost_equal(output.array, scale_op_p.impl(A, B, sB), decimal=2)

@parameterized.parameters([np.float32])
def test__tile_map__select__ipu_jitting__proper_result(self, dtype):
tiles = (3, 4, 5)
inshape = (len(tiles), 7, 9)
mask = np.random.rand(*inshape) >= 0.5
input0 = np.random.randn(*inshape).astype(dtype)
input1 = np.random.randn(*inshape).astype(dtype)

@partial(jax.jit, backend="ipu")
def compute_fn(mask, in0, in1):
mask = tile_put_sharded(mask, tiles)
input0 = tile_put_sharded(in0, tiles)
input1 = tile_put_sharded(in1, tiles)
output = tile_map(lax.select_n_p, mask, input0, input1)
return output

output = compute_fn(mask, input0, input1)
assert isinstance(output, TileShardedArray)
assert output.tiles == tiles
assert output.dtype == input0.dtype
npt.assert_array_almost_equal(output.array, np.where(mask, input0, input1))


class IpuTileShiftPrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down

0 comments on commit 39d8d03

Please sign in to comment.