Skip to content

Commit

Permalink
Support clamp LAX operation in TessellateIPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 27, 2023
1 parent 0e27d03 commit 9ef8490
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
| `broadcast_in_dim` | :x: | :x: | |
| `cbrt` | :white_check_mark: | :white_check_mark: | |
| `ceil` | :white_check_mark: | :white_check_mark: | |
| `clamp` | :x: | :x: | |
| `clamp` | :white_check_mark: | :x: | |
| `collapse` | :x: | :x: | |
| `complex` | :x: | :x: | |
| `concatenate` | :white_check_mark: | :x: | |
Expand Down
46 changes: 46 additions & 0 deletions tessellate_ipu/lax/tile_lax_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,49 @@ def ipu_select_primitive_translation(

# Register JAX LAX select primitive.
register_ipu_tile_primitive(lax.select_n_p, ipu_select_primitive_translation)


def ipu_clamp_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `clamp` LAX primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 3
min, x, max = inavals
# A couple of initial checks!
assert max.shape == x.shape
assert min.shape == x.shape

vname = make_ipu_vertex_name_templated("popops::Clamp", x.dtype)
# Note: using `vertex_dim2=1` as Select vertex expecting vector of vector.
inputs_info = [
make_ipu_vertex_in_info("in2", min, vertex_dim2=1),
make_ipu_vertex_in_info("in1", x, vertex_dim2=1),
make_ipu_vertex_in_info("in3", max, vertex_dim2=1),
]
outputs_info = [make_ipu_vertex_out_info("out", x, vertex_dim2=1)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX clamp primitive.
register_ipu_tile_primitive(lax.clamp_p, ipu_clamp_primitive_translation)
24 changes: 23 additions & 1 deletion tests/lax/test_tile_lax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def compute_fn(A, B, sB):
assert output.dtype == A.dtype
npt.assert_array_almost_equal(output.array, scale_op_p.impl(A, B, sB), decimal=2)

@parameterized.parameters([np.float32])
@parameterized.parameters([np.float32, np.float16])
def test__tile_map__select__ipu_jitting__proper_result(self, dtype):
tiles = (3, 4, 5)
inshape = (len(tiles), 7, 9)
Expand All @@ -416,6 +416,28 @@ def compute_fn(mask, in0, in1):
assert output.dtype == input0.dtype
npt.assert_array_almost_equal(output.array, np.where(mask, input0, input1))

@parameterized.parameters([np.float32, np.float16])
def test__tile_map__clamp__ipu_jitting__proper_result(self, dtype):
tiles = (3, 4, 5)
inshape = (len(tiles), 7, 9)
min = np.random.randn(*inshape).astype(dtype)
input = np.random.randn(*inshape).astype(dtype)
max = np.random.randn(*inshape).astype(dtype)

@partial(jax.jit, backend="ipu")
def compute_fn(min, x, max):
min = tile_put_sharded(min, tiles)
x = tile_put_sharded(x, tiles)
max = tile_put_sharded(max, tiles)
output = tile_map(lax.clamp_p, min, x, max)
return output

output = compute_fn(min, input, max)
assert isinstance(output, TileShardedArray)
assert output.tiles == tiles
assert output.dtype == input.dtype
npt.assert_array_almost_equal(output.array, np.clip(input, min, max))


class IpuTileShiftPrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down

0 comments on commit 9ef8490

Please sign in to comment.