[JAX] Make a one hot mode of take along axis.
PiperOrigin-RevId: 681139055
blakehechtman authored and Google-ML-Automation committed Oct 1, 2024
1 parent afed9f4 commit ce21a12
Showing 3 changed files with 45 additions and 18 deletions.
3 changes: 3 additions & 0 deletions jax/_src/lax/slicing.py
@@ -267,6 +267,7 @@ class GatherScatterMode(enum.Enum):
CLIP = enum.auto()
FILL_OR_DROP = enum.auto()
PROMISE_IN_BOUNDS = enum.auto()
ONE_HOT = enum.auto()

@staticmethod
def from_any(s: str | GatherScatterMode | None):
@@ -278,6 +279,8 @@ def from_any(s: str | GatherScatterMode | None):
return GatherScatterMode.FILL_OR_DROP
if s == "promise_in_bounds":
return GatherScatterMode.PROMISE_IN_BOUNDS
if s == "one_hot":
return GatherScatterMode.ONE_HOT
else:
raise ValueError(f'Unknown gather mode "{s}"')

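For reference, a minimal sketch of how the new string value resolves to the enum member added above. GatherScatterMode lives in the private jax._src.lax.slicing module, so this is illustrative only and assumes the diff above is applied:

    # Illustrative sketch; exercises the private GatherScatterMode API patched above.
    from jax._src.lax.slicing import GatherScatterMode

    assert GatherScatterMode.from_any("one_hot") is GatherScatterMode.ONE_HOT
    # Strings handled before this change still resolve as before.
    assert GatherScatterMode.from_any("promise_in_bounds") is GatherScatterMode.PROMISE_IN_BOUNDS
    # Any other string falls through to the ValueError("Unknown gather mode ...") branch.
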
55 changes: 37 additions & 18 deletions jax/_src/numpy/lax_numpy.py
@@ -32,48 +32,47 @@
import importlib
import math
import operator
import string
import types
from typing import (overload, Any, Literal, NamedTuple,
Protocol, TypeVar, Union)
from typing import (Any, Literal, NamedTuple,
                    Protocol, TypeVar, Union, overload)
import warnings

import numpy as np
import opt_einsum

import jax
from jax import jit
from jax import errors
from jax import jit
from jax import lax
from jax.sharding import Sharding, SingleDeviceSharding
from jax.tree_util import tree_leaves, tree_flatten, tree_map

from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import custom_jvp
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import xla_bridge
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import (PrecisionLike, _array_copy,
                              _sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import (
Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, StaticScalar,
Array, ArrayLike,
DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
from jax._src.util import (
NumpyComplexWarning,
canonicalize_axis as _canonicalize_axis,
NumpyComplexWarning)
    ceil_of_ratio, partition_list, safe_zip, subvals, unzip2)
from jax.sharding import Sharding, SingleDeviceSharding
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np
import opt_einsum

for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
try:
@@ -10333,6 +10332,26 @@ def replace(tup, val):
out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
if axis_size == 0:
return zeros(out_shape, a.dtype)

if mode == "one_hot":
indices = _normalize_index(indices, axis_size)
hot = jax.nn.one_hot(indices, axis_size, dtype=bool_)
if a.ndim == 1:
return einsum("...b,b->...", hot, a, preferred_element_type=a.dtype)
if axis_int > len(string.ascii_letters) - 2:
raise ValueError(
"One Hot indexing is only supported for up to 50 leading dimensions."
)
labels = "".join([string.ascii_letters[i] for i in range(axis_int)])
eq = labels + "y...z," + labels + "z...->" + labels + "y..."
return einsum(
eq,
hot,
a,
precision=lax.Precision.HIGHEST,
preferred_element_type=a.dtype,
)

index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)]

gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
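The new mode == "one_hot" branch replaces the gather with an einsum: the normalized indices are expanded to a boolean one-hot tensor over the indexed axis and contracted against the input, with letter labels generated for the leading batch dimensions in the multi-dimensional case (hence the 50-leading-dimension limit). A minimal standalone sketch of the underlying 1-D contraction, using illustrative variable names not taken from the diff:

    import jax
    import jax.numpy as jnp

    a = jnp.arange(5.0)                    # values along the indexed axis, shape (5,)
    idx = jnp.array([[3, 1], [0, 4]])      # arbitrary integer index shape

    # Build a boolean one-hot tensor of shape idx.shape + (5,) and contract it
    # against `a` over the last axis, mirroring the 1-D branch above.
    hot = jax.nn.one_hot(idx, a.shape[0], dtype=bool)
    picked = jnp.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype)

    # The contraction selects the same elements as ordinary integer indexing.
    assert jnp.array_equal(picked, a[idx])
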
5 changes: 5 additions & 0 deletions tests/lax_numpy_test.py
@@ -4626,11 +4626,16 @@ def args_maker():
return x, i

jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis)
jnp_one_hot_op = lambda x, i: jnp.take_along_axis(
x, i, axis=axis, mode='one_hot'
)

if hasattr(np, "take_along_axis"):
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_one_hot_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
self._CompileAndCheck(jnp_one_hot_op, args_maker)

def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
# https://github.com/jax-ml/jax/issues/5088
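For completeness, a hedged sketch of the user-facing call exercised by the new test, assuming a jax build that includes this commit; the array values are illustrative:

    import jax.numpy as jnp

    x = jnp.arange(12.0).reshape(3, 4)
    i = jnp.array([[0], [2], [1]])    # one index per row along axis 1

    default_path = jnp.take_along_axis(x, i, axis=1)                  # gather lowering
    one_hot_path = jnp.take_along_axis(x, i, axis=1, mode="one_hot")  # new one-hot/einsum lowering

    # Both paths should return the same values; the test above additionally
    # checks them against np.take_along_axis.
    assert jnp.array_equal(default_path, one_hot_path)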
