Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add `fill` TessellateIPU primitive, mapped to the `popops::Fill` vertex. #38

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
from .tile_lax_array import bitcast_convert_type_p, reshape_p
from .tile_lax_array import bitcast_convert_type_p, fill, fill_p, reshape_p, tile_fill, tile_sharded_identity
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
Expand Down
146 changes: 143 additions & 3 deletions tessellate_ipu/lax/tile_lax_array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Sequence, Tuple, Union

import jax.lax
import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import bitcast_convert_type_p, reshape_p
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir
from jax.lax import bitcast_convert_type_p, reshape_p, scatter_p

from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive
from tessellate_ipu.core import (
IpuTileMapEquation,
TileShardedArray,
make_ipu_vertex_attributes,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
register_ipu_tile_primitive,
tile_constant_replicated,
tile_constant_sharded,
tile_map,
)
from tessellate_ipu.utils import DTypeLike


def ipu_reshape_primitive_translation(
Expand Down Expand Up @@ -95,3 +111,127 @@ def ipu_bitcast_convert_type_primitive_translation(

# Register JAX LAX bitcast_convert_type_p primitive.
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation)


# JAX primitive bound by `fill` and mapped on tiles by `tile_fill`; its implementation,
# abstract evaluation and lowering rules are registered further down in this module.
fill_p = Primitive("fill")
"""Fill primitive: create an array, and fill it with a constant.
Note: compared to `jax.lax.full`, it guarantees allocation of the full array instead of broadcasting.
"""


def fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
    """Create an array of the given shape with every entry set to `fill_value`.

    Args:
        shape: Shape of the array to create.
        fill_value: Constant written to every entry.
        dtype: Dtype of the array.
    Returns:
        Result of binding the `fill` primitive with these parameters.
    """
    params = {"shape": shape, "fill_value": fill_value, "dtype": dtype}
    return fill_p.bind(**params)


def fill_numpy_impl(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
    """NumPy implementation of `fill`.

    Args:
        shape: Shape of the array to create.
        fill_value: Constant written to every entry.
        dtype: Dtype of the array.
    Returns:
        NumPy array of shape `shape` and dtype `dtype` filled with `fill_value`.
    """
    filled = np.full(shape=shape, fill_value=fill_value, dtype=dtype)
    return filled


def fill_abstract_eval(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike):
    """Abstract evaluation of `fill`: compute the output shape and dtype.

    Args:
        shape: Shape of the array to create.
        fill_value: Constant written to every entry.
        dtype: Dtype of the array.
    Returns:
        `ShapedArray` with the shape and dtype `jax.lax.full` would produce for
        these arguments.
    """
    # Trace `jax.lax.full` abstractly instead of executing it: an eager call would
    # materialize the whole array on the default device only to read back its shape
    # and dtype. `jax.eval_shape` applies the same shape/dtype canonicalization.
    out = jax.eval_shape(lambda: jax.lax.full(shape, fill_value=fill_value, dtype=dtype))
    return ShapedArray(out.shape, dtype=out.dtype)


def ipu_fill_primitive_translation_ipu(
    p: Primitive,
    tiles: Tuple[int, ...],
    inavals: List[ShapedArray],
    attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
    """Translate the `fill` primitive into an IPU tile map equation.

    Args:
        p: JAX primitive being translated.
        tiles: Collection of tiles the vertex is mapped on.
        inavals: Input shaped arrays; must be empty, as `fill` takes no array operand.
        attributes: Primitive parameters; must provide `shape`, `fill_value` and `dtype`.
    Returns:
        IPU tile map equation with no input, a single `out` output and a vertex name
        built from `popops::Fill` and the output dtype.
    """
    assert len(inavals) == 0
    assert attributes is not None
    shape, fill_value, dtype = (attributes[key] for key in ("shape", "fill_value", "dtype"))

    out_aval = fill_abstract_eval(shape, fill_value, dtype)
    vertex_name = make_ipu_vertex_name_templated("popops::Fill", out_aval.dtype)
    # `in` is a Python keyword, hence the attribute is passed through dict unpacking.
    vertex_attrs = {"in": fill_value}
    attrs_i32, attrs_f32 = make_ipu_vertex_attributes(**vertex_attrs)
    return IpuTileMapEquation(
        vname=vertex_name,
        pname=p.name,
        tiles=tiles,
        inputs_info=[],
        outputs_info=[make_ipu_vertex_out_info("out", out_aval)],
        attributes_i32=attrs_i32,
        attributes_f32=attrs_f32,
    )


def fill_mlir_translation_default(
    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
    """Default MLIR lowering of `fill`, expressed through `jax.lax.full`.

    Args:
        ctx: Lowering rule context; its first output aval gives the shape and dtype.
        *args: MLIR operand values, forwarded unchanged to the lowered function.
        **params: Primitive parameters; only `fill_value` is read here.
    Returns:
        MLIR values obtained by lowering the equivalent `jax.lax.full` call.
    """
    out_aval = ctx.avals_out[0]
    constant = params["fill_value"]

    def fill_fn(*inputs):
        # Operands are ignored: the result only depends on the output aval and the constant.
        return jax.lax.full(out_aval.shape, constant, out_aval.dtype)

    # Reuse JAX tooling to lower the Python function to MLIR. TODO: cache lowering?
    return mlir.lower_fun(fill_fn, multiple_results=False)(ctx, *args)


# NOTE(review): `map_primitive` is a custom attribute set on the primitive object; what
# consumes it is not visible in this file -- confirm its meaning before changing it.
fill_p.map_primitive = False
# Register the primal implementation with JAX (NumPy-based, see `fill_numpy_impl`).
fill_p.def_impl(fill_numpy_impl)
# Register the abstract evaluation with JAX (output shape and dtype).
fill_p.def_abstract_eval(fill_abstract_eval)
# Default MLIR translation for all backends.
mlir.register_lowering(fill_p, fill_mlir_translation_default)
# Register TessellateIPU translation (used when the primitive is mapped on IPU tiles).
register_ipu_tile_primitive(fill_p, ipu_fill_primitive_translation_ipu)


def tile_fill(shape: Tuple[int, ...], fill_value: Any, dtype: DTypeLike, tiles: Tuple[int, ...]) -> TileShardedArray:
    """Map the `fill` primitive on a collection of tiles.

    Args:
        shape: Shape passed to `fill`.
        fill_value: Constant written to every entry.
        dtype: Dtype of the array.
        tiles: Tiles on which to map the primitive.
    Returns:
        Tile sharded array returned by `tile_map`.
    """
    fill_params = dict(shape=shape, fill_value=fill_value, dtype=dtype)
    return tile_map(fill_p, **fill_params, tiles=tiles)  # type:ignore


def tile_sharded_identity(dtype: DTypeLike, tiles: Tuple[int, ...]) -> TileShardedArray:
    """Build an identity matrix sharded on tiles across its first axis.

    Args:
        dtype: Dtype of the identity matrix.
        tiles: Sharding tiles.
    Returns:
        Sharded identity matrix (N, N), with N = len(tiles).
    """
    with jax.named_scope("tile_sharded_identity"):
        num_tiles = len(tiles)
        # Start from zeros, with an array of `num_tiles` entries mapped on every tile.
        rows = tile_fill((num_tiles,), 0, dtype=dtype, tiles=tiles)
        # The scatter needs constant indices and updates: the sharded indices hold
        # 0..N-1 (one per tile), the replicated update is a single `1`.
        # Is there something more efficient?
        diag_indices = np.arange(0, num_tiles, dtype=np.uint32).reshape(num_tiles, 1, 1)
        indices = tile_constant_sharded(diag_indices, tiles=tiles)
        ones = tile_constant_replicated(np.array([1], dtype=dtype), tiles=tiles)
        # Not the simplest way of setting diagonal terms: a scatter mapped on tiles.
        scatter_params = dict(
            dimension_numbers=jax.lax.ScatterDimensionNumbers(
                update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)
            ),
            indices_are_sorted=False,
            unique_indices=False,
            mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
            update_jaxpr=None,
            update_consts=None,
        )
        return tile_map(scatter_p, rows, indices, ones, **scatter_params)  # type:ignore
54 changes: 53 additions & 1 deletion tests/lax/test_tile_lax_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p
from tessellate_ipu.lax import bitcast_convert_type_p, fill, reshape_p, tile_fill, tile_sharded_identity


class IpuTileArrayPrimitiveTests(chex.TestCase):
Expand Down Expand Up @@ -54,3 +55,54 @@ def compute_fn(input):
assert output_ipu.tiles == tiles
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)

@parameterized.parameters(
    [
        (np.int32,),
        (np.float16,),
        (np.float32,),
    ]
)
def test__tile_map__fill__ipu_jitting__proper_result(self, dtype):
    # Checks `tile_fill` under `jax.jit` on the IPU backend, for every supported dtype.
    tiles = (3, 1, 5)
    shape = (4, 5)
    fill_value = 1

    def compute_fn():
        return tile_fill(shape, fill_value, dtype=dtype, tiles=tiles)

    output_ipu = jax.jit(compute_fn, backend="ipu")()
    # One filled array of shape `shape` is expected per tile.
    expected = np.full((len(tiles), *shape), fill_value, dtype=dtype)
    assert isinstance(output_ipu, TileShardedArray)
    assert output_ipu.tiles == tiles
    assert output_ipu.dtype == dtype
    npt.assert_array_equal(output_ipu, expected)

def test__tile_map__fill__cpu_jitting__proper_result(self):
    # Checks `fill` under `jax.jit` on the CPU backend.
    shape, fill_value = (4, 5), 2

    def compute_fn():
        return fill(shape, fill_value, np.float32)

    output_cpu = jax.jit(compute_fn, backend="cpu")()
    expected = np.full(shape, fill_value, dtype=np.float32)
    assert output_cpu.dtype == np.float32
    npt.assert_array_equal(output_cpu, expected)

def test__tile_sharded_identity__ipu_jitting__proper_result(self):
    # Checks `tile_sharded_identity` under `jax.jit` on the IPU backend.
    dtype = np.float32
    tiles = (1, 2, 5)
    num_tiles = len(tiles)

    def fn():
        # Comparison point, the "obvious" way using JAX Numpy:
        #   tile_put_sharded(jax.numpy.identity(N, dtype), tiles=tiles)
        return tile_sharded_identity(dtype, tiles)

    output_ipu = jax.jit(fn, backend="ipu")()
    expected = np.identity(num_tiles, dtype=dtype)
    assert isinstance(output_ipu, TileShardedArray)
    assert output_ipu.tiles == tiles
    assert output_ipu.dtype == dtype
    npt.assert_array_equal(output_ipu, expected)