Skip to content

Commit

Permalink
Basic support of jax.lax.scatter operations in TessellateIPU. (#24)
Browse files Browse the repository at this point in the history
TessellateIPU `scatter` integration using popops `popops::MultiUpdateXXX<>` vertices.

At the moment, it only supports the most basic scatter configuration:
```python
ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)
)
```
  • Loading branch information
balancap authored Sep 25, 2023
1 parent 4c00ca1 commit 49dea55
Show file tree
Hide file tree
Showing 12 changed files with 343 additions and 13 deletions.
2 changes: 2 additions & 0 deletions tessellate_ipu/core/tile_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def register_ipu_tile_primitive(primitive: Primitive, translation: IpuVertexTran
translation: IPU vertex translation rule.
"""
global _ipu_tile_primitive_registry
if primitive.name in _ipu_tile_primitive_registry:
raise KeyError(f"The primitive '{primitive.name}' is already registered in TessellateIPU.")
_ipu_tile_primitive_registry[primitive.name] = (primitive, translation)


Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
from .tile_lax_unary import ( # tanh_inplace_p,
abs_inplace_p,
asin_inplace_p,
Expand Down
218 changes: 218 additions & 0 deletions tessellate_ipu/lax/tile_lax_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import logging
from typing import Any, Dict, List, Tuple

import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import (
GatherScatterMode,
ScatterDimensionNumbers,
scatter_add_p,
scatter_max_p,
scatter_min_p,
scatter_mul_p,
scatter_p,
)

from tessellate_ipu.core import (
IpuTileMapEquation,
make_ipu_vertex_attributes,
make_ipu_vertex_constant_info,
make_ipu_vertex_in_info,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
register_ipu_tile_primitive,
)
from tessellate_ipu.utils import DTypeLike

_scatter_primitive_to_properties: Dict[Primitive, Any] = {
scatter_add_p: (1, "ADD"),
scatter_min_p: (None, "MIN"),
scatter_max_p: (None, "MAX"),
scatter_mul_p: (None, "MUL"),
}
"""IPU translation properties for every JAX LAX scatter primitive.
"""


def make_scatter_vertex_fullname(dtype: DTypeLike, opname: str, scale: Any) -> str:
"""Generate popops Scatter/MultiUpdateOp vertex name."""
opname = f"popops::Operation::{opname}"
if scale is not None:
basename = "popops::ScaledMultiUpdateOp"
return make_ipu_vertex_name_templated(basename, dtype, dtype, False, opname)
else:
basename = "popops::MultiUpdateOp"
return make_ipu_vertex_name_templated(basename, dtype, False, opname)


def check_scatter_dimension_numbers(dimension_numbers: ScatterDimensionNumbers):
"""Check `scatter` dimension_numbers is supported on TessellateIPU.
At the moment: basically only supporting a single configuration!
We need to expand on this at some point!
"""
dim_numbers_default = ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)
)
if dimension_numbers != dim_numbers_default:
raise NotImplementedError(f"TessellateIPU `scatter` only support dimension numbers: {dim_numbers_default}.")


def ipu_scatter_op_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `scatter_xx` primitive translation rule to IPU vertex.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scatter.html
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input data + start indices arrays.
attributes: Gather operator attributes
Returns:
IPU tile map primitive structure.
"""
# TODO: query for JAX device.
num_context_workers = 6

assert len(inavals) == 3
assert attributes is not None
operand, scatter_indices, updates = inavals
# Extract scatter attributes
dimension_numbers = attributes["dimension_numbers"]
# Default values from JAX LAX interface.
indices_are_sorted = attributes.get("indices_are_sorted", False)
unique_indices = attributes.get("unique_indices", False)
mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS)

# Check scatter attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert scatter_indices.ndim == 2
assert operand.dtype == updates.dtype
assert scatter_indices.dtype == np.uint32, "TessellateIPU `scatter` only supports `uint32` indices."
if indices_are_sorted:
logging.warning("TessellateIPU `scatter` operation does not make use of `indices_are_sorted` argument.")
if unique_indices:
logging.warning("TessellateIPU `scatter` operation does not make use of `unique_indices` argument.")
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` scatter mode supported in TessellateIPU."
check_scatter_dimension_numbers(dimension_numbers)

# Primitive translation properties.
scale, opname = _scatter_primitive_to_properties[p]
vname = make_scatter_vertex_fullname(operand.dtype, opname, scale)
# Construct poplibs MultiSlice vertex attributes.
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)),
regionSize=1, # TODO: understand?
indicesAreSorted=False,
)

# Constant `scale` (if required by the vertex).
constants_info = []
if scale is not None:
constants_info = [make_ipu_vertex_constant_info("scale", np.array(scale, dtype=operand.dtype), vertex_dim2=-1)]
# For now: need to do it manually at the Python `tile_map` level.
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[
make_ipu_vertex_inout_info("baseT", operand),
make_ipu_vertex_in_info("offsets", scatter_indices),
make_ipu_vertex_in_info("subT", updates),
]
+ constants_info,
outputs_info=[make_ipu_vertex_inout_info("baseT", operand)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


def ipu_scatter_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `scatter` primitive translation rule to IPU vertex.
Note: using a specific translation, as the poplibs vertex is different.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scatter.html
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input data + start indices arrays.
attributes: Gather operator attributes
Returns:
IPU tile map primitive structure.
"""
# TODO: query for JAX device.
num_context_workers = 6

assert len(inavals) == 3
assert attributes is not None
operand, scatter_indices, updates = inavals
# Extract scatter attributes
dimension_numbers = attributes["dimension_numbers"]
# Default values from JAX LAX interface.
indices_are_sorted = attributes.get("indices_are_sorted", False)
unique_indices = attributes.get("unique_indices", False)
mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS)

# Check scatter attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert scatter_indices.ndim == 2
assert operand.dtype == updates.dtype
assert scatter_indices.dtype == np.uint32, "TessellateIPU `scatter` only supports `uint32` indices."
if indices_are_sorted:
logging.warning("TessellateIPU `scatter` operation does not make use of `indices_are_sorted` argument.")
if unique_indices:
logging.warning("TessellateIPU `scatter` operation does not make use of `unique_indices` argument.")
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` scatter mode supported in TessellateIPU."
check_scatter_dimension_numbers(dimension_numbers)

vname = make_ipu_vertex_name_templated("popops::MultiUpdate", operand.dtype)
# Construct poplibs MultiSlice vertex attributes.
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)),
regionSize=1, # TODO: understand?
indicesAreSorted=False,
splitSingleRegion=True,
)
# For now: need to do it manually at the Python `tile_map` level.
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[
make_ipu_vertex_inout_info("baseT", operand),
make_ipu_vertex_in_info("offsets", scatter_indices),
make_ipu_vertex_in_info("subT", updates),
],
outputs_info=[make_ipu_vertex_inout_info("baseT", operand)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


# Register JAX `scatter` primitives with update op.
for p in _scatter_primitive_to_properties.keys():
register_ipu_tile_primitive(p, ipu_scatter_op_primitive_translation)
# Specific translation for the simple `scatter` case
register_ipu_tile_primitive(scatter_p, ipu_scatter_primitive_translation)
3 changes: 2 additions & 1 deletion tessellate_ipu/lib/tessellate_ipu_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ NB_MODULE(pytessellate_ipu_core, m) {
nanobind::arg("shape"), nanobind::arg("dtype"),
nanobind::arg("constant_data"), nanobind::arg("slices2d"))
.def(nanobind::init<const std::string&, VertexIOType, const ShapeType&,
IpuType, std::size_t, const Base64Data&>(),
IpuType, int64_t, const Base64Data&>(),
nanobind::arg("name"), nanobind::arg("iotype"),
nanobind::arg("shape"), nanobind::arg("dtype"),
nanobind::arg("vertex_dim2") = 0,
Expand All @@ -118,6 +118,7 @@ NB_MODULE(pytessellate_ipu_core, m) {
.def_rw("aval", &VertexIOInfo::aval)
.def_rw("constant_data", &VertexIOInfo::constant_data)
.def_rw("slices2d", &VertexIOInfo::slices2d)
.def_rw("is_scalar", &VertexIOInfo::is_scalar)
.def_prop_ro("shape", [](const VertexIOInfo& v) { return v.aval.shape; })
.def_prop_ro("dtype", [](const VertexIOInfo& v) { return v.aval.dtype; })
.def_prop_ro("is_constant_input", &VertexIOInfo::isConstantInput);
Expand Down
4 changes: 3 additions & 1 deletion tessellate_ipu/lib/tile_map_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#include "tile_map_ops.hpp"

#include <iostream>
namespace ipu {

std::vector<poplar::Tensor> TileMapEquation::allocateInputTensors(
Expand Down Expand Up @@ -91,7 +92,8 @@ void TileMapEquation::add(poplar::Graph& graph, poplar::program::Sequence& prog,
// Map/connect vertex input tensors.
for (size_t k = 0; k < inputs.size(); ++k) {
const auto& info = inputs_info[k];
graph.connect(v[info.name], info.connectReshape(inputs[k][tidx]));
const auto tensor = info.connectReshape(inputs[k][tidx]);
graph.connect(v[info.name], tensor);
}
// Map/connect vertex output tensors.
for (size_t k = 0; k < outputs.size(); ++k) {
Expand Down
27 changes: 23 additions & 4 deletions tessellate_ipu/lib/tile_map_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ struct VertexIOInfo {
Base64Data constant_data = Base64Data();
/** Slices, in the case of 2d tensor input. */
std::vector<TensorSlice> slices2d;
/** Is the vertex IO tensor just a scalar? */
bool is_scalar = false;

/** Default constructors/assignment. */
VertexIOInfo() noexcept = default;
Expand Down Expand Up @@ -91,8 +93,8 @@ struct VertexIOInfo {
* @brief Build a vertex IO info (with vertex second dim info).
*/
VertexIOInfo(const std::string& _name, VertexIOType _iotype,
const ShapeType& _shape, IpuType _dtype,
std::size_t _vertex_dim2, const Base64Data& _constant_data)
const ShapeType& _shape, IpuType _dtype, int64_t _vertex_dim2,
const Base64Data& _constant_data)
: name{_name},
iotype{_iotype},
aval{_shape, _dtype},
Expand All @@ -102,6 +104,11 @@ struct VertexIOInfo {
slices2d = TensorSlice::makeTensor2dSlices(aval.size() / _vertex_dim2,
_vertex_dim2);
}
// Negative => code for scalar.
if (_vertex_dim2 < 0) {
is_scalar = true;
}
// Zero => normal flattened case.
}

/**
Expand Down Expand Up @@ -138,8 +145,20 @@ struct VertexIOInfo {

/**
* @brief Reshape a tensor to the proper rank for vertex connection.
*
* This bit of logic is necessary as Poplar vertices only support:
* rank 0: i.e. scalar entry;
* rank 1: flattened array;
* rank 2: collection of tensor slices;
*/
poplar::Tensor connectReshape(const poplar::Tensor& t) const {
if (is_scalar) {
if (t.numElements() != 1) {
throw std::logic_error(
"Expecting a single scalar element to connect to the vertex.");
}
return t.flatten()[0];
}
if (slices2d.empty()) {
// Rank 1 (no 2d slices): flatten the IO tensor.
return t.flatten();
Expand All @@ -159,12 +178,12 @@ struct VertexIOInfo {
}
};
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(VertexIOInfo, name, iotype, aval,
constant_data, slices2d)
constant_data, slices2d, is_scalar)

inline bool operator==(const VertexIOInfo& lhs, const VertexIOInfo& rhs) {
return lhs.name == rhs.name && lhs.iotype == rhs.iotype &&
lhs.aval.shape == rhs.aval.shape && lhs.aval.dtype == rhs.aval.dtype;
// TODO: compare 2d slices.
// TODO: compare 2d slices and is_scalar?
}

/**
Expand Down
4 changes: 2 additions & 2 deletions tests/core/custom_arange_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def custom_arange_tile_translation_ipu(
outaval = core.ShapedArray(outshape, outdtype)
gp_filename = custom_vertex_filename

global_scale_data = np.array([7], dtype=outdtype)
global_scale_data = np.array(7, dtype=outdtype)
ipu_dtype = from_numpy_dtype_to_ipu_type(outdtype)
vertex_name = f"CustomArangeVertex<{ipu_dtype.name.lower()}>"
# Translation rule to IPU vertex
Expand All @@ -69,7 +69,7 @@ def custom_arange_tile_translation_ipu(
# IO vertex infos.
inputs_info=[
make_ipu_vertex_in_info("scales", inavals[0], vertex_dim2=inavals[0].shape[1]),
make_ipu_vertex_constant_info("global_scale", global_scale_data),
make_ipu_vertex_constant_info("global_scale", global_scale_data, vertex_dim2=-1),
],
outputs_info=[make_ipu_vertex_out_info("out", outaval)],
# Additional attributes to pass to the vertex
Expand Down
5 changes: 3 additions & 2 deletions tests/core/custom_arange_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ class CustomArangeVertex : public Vertex {
// Testing 2d tensor IO supported.
Vector<Input<Vector<T>>, poplar::VectorLayout::ONE_PTR> scales; // (2, size)
// Testing constant vertex tensor.
Input<Vector<T, poplar::VectorLayout::ONE_PTR>> global_scale; // (1,)
// Input<Vector<T, poplar::VectorLayout::ONE_PTR>> global_scale; // (1,)
Input<T> global_scale; // (,) scalar
Output<Vector<T, poplar::VectorLayout::SPAN>> out; // (size, )

bool compute() {
const auto outsize = out.size();
for (std::size_t idx = 0; idx < outsize; ++idx) {
out[idx] = T(idx) * scales[0][idx] * scales[1][idx] * global_scale[0];
out[idx] = T(idx) * scales[0][idx] * scales[1][idx] * (*global_scale);
}
return true;
}
Expand Down
Loading

0 comments on commit 49dea55

Please sign in to comment.