Skip to content

Commit

Permalink
BUG: fix unit handling for np.take and unyt_array.take (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman authored Dec 25, 2024
1 parent 41b4b36 commit d78655f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
20 changes: 20 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,23 @@ def in1d(ar1, ar2, *args, **kwargs):
return np.isin._implementation(
np.asarray(ar1), np.asarray(ar2), *args, **kwargs
)


@implements(np.take)
def take(a, indices, axis=None, out=None, mode="raise"):
ret_units = getattr(a, "units", NULL_UNIT)

if out is not None:
out_view = np.asarray(out)
else:
out_view = None

res = np.take._implementation(
np.asarray(a), indices, axis=axis, out=out_view, mode=mode
)

if getattr(out, "units", None) is not None:
out.units = ret_units

ret_cls = unyt_quantity if res.ndim == 0 else unyt_array
return ret_cls(res, ret_units, bypass_validation=True)
15 changes: 15 additions & 0 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,21 @@ def dot(self, b, out=None):
out.units = res_units
return ret

def take(self, indices, axis=None, out=None, mode="raise"):
"""method
Return an array formed from the elements of `a` at the given indices.
Refer to :func:`numpy.take` for full documentation.
See also
--------
numpy.take : equivalent function
"""
from ._array_functions import take

return take(self, indices, axis=axis, out=out, mode=mode)

def __reduce__(self):
"""Pickle reduction method
Expand Down
22 changes: 18 additions & 4 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@
np.nancumprod, # we get it for free with np.cumprod (tested)
np.bincount, # works out of the box (tested)
np.unique, # works out of the box (tested)
np.take, # works out of the box (tested)
np.min_scalar_type, # returns dtypes
np.extract, # works out of the box (tested)
np.setxor1d, # we get it for free with previously implemented functions (tested)
Expand Down Expand Up @@ -1848,10 +1847,25 @@ def test_unique_values():
assert_array_equal_units(res, values)


def test_take():
@pytest.mark.parametrize("indices", [[0, 1], 0])
def test_take(indices):
a = [1, 2, 3] * cm
res = np.take(a, [0, 1])
assert type(res) is unyt_array
res = np.take(a, indices)
if res.ndim == 0:
assert type(res) is unyt_quantity
else:
assert type(res) is unyt_array
assert res.units == cm


@pytest.mark.parametrize("indices", [[0, 1], 0])
def test_ndarray_take(indices):
a = [1, 2, 3] * cm
res = a.take(indices)
if res.ndim == 0:
assert type(res) is unyt_quantity
else:
assert type(res) is unyt_array
assert res.units == cm


Expand Down

0 comments on commit d78655f

Please sign in to comment.