From a78b5f3961cb94526941f79b4de4e507a6bf8ce2 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 22 Sep 2023 16:03:51 +0000 Subject: [PATCH] Basic support of jax.lax.scatter ops in TessellateIPU. --- tessellate_ipu/lax/tile_lax_scatter.py | 29 ++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tessellate_ipu/lax/tile_lax_scatter.py diff --git a/tessellate_ipu/lax/tile_lax_scatter.py b/tessellate_ipu/lax/tile_lax_scatter.py new file mode 100644 index 0000000..6ecb165 --- /dev/null +++ b/tessellate_ipu/lax/tile_lax_scatter.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Any, Dict, List, Tuple + +import numpy as np +from jax.core import Primitive, ShapedArray +from jax.lax import ( + GatherScatterMode, + ScatterDimensionNumbers, + scatter_add_p, + scatter_max_p, + scatter_min_p, + scatter_mul_p, +) + +from tessellate_ipu.core import ( + IpuTileMapEquation, + make_ipu_vertex_attributes, + make_ipu_vertex_in_info, + make_ipu_vertex_name_templated, + make_ipu_vertex_out_info, + register_ipu_tile_primitive, +) +from tessellate_ipu.utils import DTypeLike + + +def make_gather_vertex_fullname(dtype: DTypeLike) -> str: + """Generate popops Gather/MultiSlice vertex name.""" + basename = "popops::MultiSlice" + return make_ipu_vertex_name_templated(basename, dtype)