Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 29, 2023
1 parent 2482e33 commit 0dfe0d5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
from .tile_lax_array import bitcast_convert_type_p, fill_p, reshape_p
from .tile_lax_array import bitcast_convert_type_p, fill, fill_p, reshape_p
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
Expand Down
27 changes: 23 additions & 4 deletions tessellate_ipu/lax/tile_lax_array.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Sequence, Tuple, Union

import jax.lax
import numpy as np
from jax.core import Primitive, ShapedArray
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir
from jax.lax import bitcast_convert_type_p, reshape_p

from tessellate_ipu.core import (
Expand Down Expand Up @@ -164,10 +166,27 @@ def ipu_fill_primitive_translation_ipu(
return ipu_prim_info


def fill_mlir_translation_default(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
"""`fill` default MLIR translation, for CPU/GPU/IPU/... backends."""
outaval = ctx.avals_out[0]
fill_value = params["fill_value"]

def fill_fn(*inputs):
return jax.lax.full(outaval.shape, fill_value, outaval.dtype)

# Lower to MLIR using JAX tooling. TODO: cache lowering?
fill_lower_fn = mlir.lower_fun(fill_fn, multiple_results=False)
return fill_lower_fn(ctx, *args)


fill_p.map_primitive = False
# Register the primal implementation with JAX
# Register the primal implementation with JAX.
fill_p.def_impl(fill_numpy_impl)
# Register the abstract evaluation with JAX
# Register the abstract evaluation with JAX.
fill_p.def_abstract_eval(fill_abstract_eval)
# Register tile IPU translation.
# Default MLIR translation for all backends.
mlir.register_lowering(fill_p, fill_mlir_translation_default)
# Register TessellateIPU translation.
register_ipu_tile_primitive(fill_p, ipu_fill_primitive_translation_ipu)
37 changes: 27 additions & 10 deletions tests/lax/test_tile_lax_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, fill_p, reshape_p
from tessellate_ipu.lax import bitcast_convert_type_p, fill, fill_p, reshape_p


class IpuTileArrayPrimitiveTests(chex.TestCase):
Expand Down Expand Up @@ -55,20 +56,36 @@ def compute_fn(input):
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)

def test__tile_map__fill__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.int32
# inshape = (len(tiles), 6, 4)
@parameterized.parameters(
[
(np.int32,),
(np.float16,),
(np.float32,),
]
)
def test__tile_map__fill__ipu_jitting__proper_result(self, dtype):
tiles = (3, 1, 5)
shape = (4, 5)
fill_value = 1

def compute_fn():
return tile_map(fill_p, shape=(4, 5), fill_value=1, dtype=dtype, tiles=tiles)
return tile_map(fill_p, shape=shape, fill_value=fill_value, dtype=dtype, tiles=tiles)

compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
output_ipu = compute_fn_ipu()
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
# assert output_ipu.dtype == np.int32
assert output_ipu.dtype == dtype
npt.assert_array_equal(output_ipu, np.full((len(tiles), *shape), fill_value, dtype=dtype))

def test__tile_map__fill__cpu_jitting__proper_result(self):
shape = (4, 5)
fill_value = 2

def compute_fn():
return fill(shape, fill_value, np.float32)

print(output_ipu)
assert False
# npt.assert_array_equal(output_ipu, output_cpu)
fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
output_cpu = fn_cpu()
assert output_cpu.dtype == np.float32
npt.assert_array_equal(output_cpu, np.full(shape, fill_value, dtype=np.float32))

0 comments on commit 0dfe0d5

Please sign in to comment.