Skip to content

Commit

Permalink
Basic support of jax.lax.scatter ops in TessellateIPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 22, 2023
1 parent bab7e7a commit a78b5f3
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tessellate_ipu/lax/tile_lax_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple

import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import (
GatherScatterMode,
ScatterDimensionNumbers,
scatter_add_p,
scatter_max_p,
scatter_min_p,
scatter_mul_p,
)

from tessellate_ipu.core import (
IpuTileMapEquation,
make_ipu_vertex_attributes,
make_ipu_vertex_in_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
register_ipu_tile_primitive,
)
from tessellate_ipu.utils import DTypeLike


def make_gather_vertex_fullname(dtype: DTypeLike) -> str:
"""Generate popops Gather/MultiSlice vertex name."""
basename = "popops::MultiSlice"
return make_ipu_vertex_name_templated(basename, dtype)

0 comments on commit a78b5f3

Please sign in to comment.