From 4ac4224b1f9b3dd41065958c518bac8e001bf066 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Fri, 20 Dec 2024 11:58:35 +0000 Subject: [PATCH 01/11] Implement np.take and np.ndarray.take functions. --- unyt/_array_functions.py | 22 ++++++++++++++++++++++ unyt/array.py | 5 +++++ unyt/tests/test_array_functions.py | 1 - 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 03a4875d..7d7012b2 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1220,3 +1220,25 @@ def in1d(ar1, ar2, *args, **kwargs): return np.isin._implementation( np.asarray(ar1), np.asarray(ar2), *args, **kwargs ) + + +@implements(np.take) +def take(a, *args, out=None, **kwargs): + retu = get_units((a, )) + if out is None: + return ( + np.take._implementation( + a, *args, **kwargs + ) + * retu + ) + + res = np.take._implementation( + np.asarray(a), + *args, + out=np.asarray(out), + **kwargs, + ) + if getattr(out, "units", None) is not None: + out.units = retu + return unyt_array(res, retu, bypass_validation=True) diff --git a/unyt/array.py b/unyt/array.py index b3982dbf..8c9f0b23 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2125,6 +2125,11 @@ def dot(self, b, out=None): out.units = res_units return ret + def take(self, *args, **kwargs): + from ._array_functions import take as unyt_take + + return unyt_take(self, *args, **kwargs) + def __reduce__(self): """Pickle reduction method diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index a8b582c6..263983de 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -143,7 +143,6 @@ np.nancumprod, # we get it for free with np.cumprod (tested) np.bincount, # works out of the box (tested) np.unique, # works out of the box (tested) - np.take, # works out of the box (tested) np.min_scalar_type, # returns dtypes np.extract, # works out of the box (tested) np.setxor1d, # we get it for free with previously implemented functions (tested) From f75d25a9c00eb1e1949b236db6db93da3466b2b9 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Fri, 20 Dec 2024 21:11:16 +0000 Subject: [PATCH 02/11] Amend take helper arguments. --- unyt/_array_functions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 7d7012b2..c334dd39 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1223,20 +1223,22 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) -def take(a, *args, out=None, **kwargs): +def take(a, indices, axis=None, out=None, *args, **kwargs): retu = get_units((a, )) if out is None: return ( np.take._implementation( - a, *args, **kwargs + a, indices, axis=axis, out=out, *args, **kwargs ) * retu ) res = np.take._implementation( np.asarray(a), - *args, + indices, + axis=axis, out=np.asarray(out), + *args, **kwargs, ) if getattr(out, "units", None) is not None: From 3b02fd13c78ce4c3894ed591425cbc0f732a3d76 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Fri, 20 Dec 2024 21:15:13 +0000 Subject: [PATCH 03/11] Avoid infinite recursion. --- unyt/_array_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index c334dd39..05e9fd5c 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1228,7 +1228,7 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): if out is None: return ( np.take._implementation( - a, indices, axis=axis, out=out, *args, **kwargs + np.asarray(a), indices, axis=axis, out=out, *args, **kwargs ) * retu ) From 060a8eb0b5ac29c9608d1adb201bb44a842376d3 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Sat, 21 Dec 2024 08:58:06 +0000 Subject: [PATCH 04/11] Basic working implementation for take. --- unyt/_array_functions.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 05e9fd5c..91628cd1 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1224,23 +1224,18 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) def take(a, indices, axis=None, out=None, *args, **kwargs): - retu = get_units((a, )) - if out is None: - return ( - np.take._implementation( - np.asarray(a), indices, axis=axis, out=out, *args, **kwargs - ) - * retu - ) + ret_units = get_units((a, ))[0] + + if out is not None: + out_view = np.asarray(out) + else: + out_view = out res = np.take._implementation( - np.asarray(a), - indices, - axis=axis, - out=np.asarray(out), - *args, - **kwargs, + np.asarray(a), indices, axis=axis, out=out_view, *args, **kwargs ) + if getattr(out, "units", None) is not None: - out.units = retu - return unyt_array(res, retu, bypass_validation=True) + out.units = ret_units + + return unyt_array(res, ret_units, bypass_validation=True) From 62314711ca651bb96433f4b2aac9c9797e6d1334 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Sat, 21 Dec 2024 20:40:26 +0000 Subject: [PATCH 05/11] Test for and handle a scalar return from np.take. --- unyt/_array_functions.py | 7 +++++-- unyt/tests/test_array_functions.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 91628cd1..0e7eb8dc 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1229,7 +1229,7 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): if out is not None: out_view = np.asarray(out) else: - out_view = out + out_view = None res = np.take._implementation( np.asarray(a), indices, axis=axis, out=out_view, *args, **kwargs @@ -1238,4 +1238,7 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): if getattr(out, "units", None) is not None: out.units = ret_units - return unyt_array(res, ret_units, bypass_validation=True) + if res.ndim > 0: + return unyt_array(res, ret_units, bypass_validation=True) + else: + return unyt_quantity(res, ret_units, bypass_validation=True) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 263983de..7454c1a8 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -1847,10 +1847,13 @@ def test_unique_values(): assert_array_equal_units(res, values) -def test_take(): +@pytest.mark.parametrize("indices", [[0, 1], 0]) +def test_take(indices): a = [1, 2, 3] * cm - res = np.take(a, [0, 1]) - assert type(res) is unyt_array + res = np.take(a, indices) + assert isinstance(res, unyt_array) # can be subclass + if res.ndim == 0: + assert isinstance(res, unyt_quantity) assert res.units == cm From cfcde5392453516f3db4e37012cef98027132d18 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Sat, 21 Dec 2024 20:42:56 +0000 Subject: [PATCH 06/11] Also test np.ndarray.take. --- unyt/tests/test_array_functions.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 7454c1a8..ffa75d08 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -1857,6 +1857,16 @@ def test_take(indices): assert res.units == cm +@pytest.mark.parametrize("indices", [[0, 1], 0]) +def test_ndarray_take(indices): + a = [1, 2, 3] * cm + res = a.take(indices) + assert isinstance(res, unyt_array) # can be subclass + if res.ndim == 0: + assert isinstance(res, unyt_quantity) + assert res.units == cm + + def test_pad(): a = [1, 2, 3] * cm res = np.pad(a, [0, 1]) From 6ca49cb35c799398b3246ab898bcbbd7ea332986 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Dec 2024 09:45:58 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unyt/_array_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 0e7eb8dc..e561e728 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1224,7 +1224,7 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) def take(a, indices, axis=None, out=None, *args, **kwargs): - ret_units = get_units((a, ))[0] + ret_units = get_units((a,))[0] if out is not None: out_view = np.asarray(out) From 3c16a8e0f0bc6ca744ced30164bb6a847aeb79ae Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Mon, 23 Dec 2024 17:25:36 +0000 Subject: [PATCH 08/11] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Robert --- unyt/_array_functions.py | 8 +++----- unyt/array.py | 4 ++-- unyt/tests/test_array_functions.py | 10 ++++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index e561e728..88eac339 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1224,7 +1224,7 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) def take(a, indices, axis=None, out=None, *args, **kwargs): - ret_units = get_units((a,))[0] + ret_units = getattr(a, "units", NULL_UNIT) if out is not None: out_view = np.asarray(out) @@ -1238,7 +1238,5 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): if getattr(out, "units", None) is not None: out.units = ret_units - if res.ndim > 0: - return unyt_array(res, ret_units, bypass_validation=True) - else: - return unyt_quantity(res, ret_units, bypass_validation=True) + ret_cls = unyt_quantity if res.ndim == 0 else unyt_array + return ret_cls(res, ret_units, bypass_validation=True) diff --git a/unyt/array.py b/unyt/array.py index 8c9f0b23..c3260030 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2126,9 +2126,9 @@ def dot(self, b, out=None): return ret def take(self, *args, **kwargs): - from ._array_functions import take as unyt_take + from ._array_functions import take - return unyt_take(self, *args, **kwargs) + return take(self, *args, **kwargs) def __reduce__(self): """Pickle reduction method diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index ffa75d08..3da71efe 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -1851,9 +1851,10 @@ def test_unique_values(): def test_take(indices): a = [1, 2, 3] * cm res = np.take(a, indices) - assert isinstance(res, unyt_array) # can be subclass if res.ndim == 0: - assert isinstance(res, unyt_quantity) + assert type(res) is unyt_quantity + else: + assert type(res) is unyt_array assert res.units == cm @@ -1861,9 +1862,10 @@ def test_take(indices): def test_ndarray_take(indices): a = [1, 2, 3] * cm res = a.take(indices) - assert isinstance(res, unyt_array) # can be subclass if res.ndim == 0: - assert isinstance(res, unyt_quantity) + assert type(res) is unyt_quantity + else: + assert type(res) is unyt_array assert res.units == cm From e889370da4ae9f5d23ad258f9d6350fd620da256 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Mon, 23 Dec 2024 17:31:35 +0000 Subject: [PATCH 09/11] Add docstring to array take, mimic numpy take signature. --- unyt/_array_functions.py | 4 ++-- unyt/array.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 88eac339..b32d15fa 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1223,7 +1223,7 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) -def take(a, indices, axis=None, out=None, *args, **kwargs): +def take(a, indices, axis=None, out=None, mode="raise"): ret_units = getattr(a, "units", NULL_UNIT) if out is not None: @@ -1232,7 +1232,7 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): out_view = None res = np.take._implementation( - np.asarray(a), indices, axis=axis, out=out_view, *args, **kwargs + np.asarray(a), indices, axis=axis, out=out_view, mode=mode ) if getattr(out, "units", None) is not None: diff --git a/unyt/array.py b/unyt/array.py index c3260030..f15d6952 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2126,6 +2126,16 @@ def dot(self, b, out=None): return ret def take(self, *args, **kwargs): + """method + + Return an array formed from the elements of `a` at the given indices. + + Refer to :func:`numpy.take` for full documentation. + + See also + -------- + numpy.take : equivalent function + """ from ._array_functions import take return take(self, *args, **kwargs) From 2c850fbd5cad14b025c0e737c8d90b44376360d4 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Tue, 24 Dec 2024 09:55:37 +0000 Subject: [PATCH 10/11] Explicit arguments for unyt_array.take --- unyt/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unyt/array.py b/unyt/array.py index f15d6952..a1ddeda9 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2125,7 +2125,7 @@ def dot(self, b, out=None): out.units = res_units return ret - def take(self, *args, **kwargs): + def take(self, indices, axis=None, out=None, mode="raise"): """method Return an array formed from the elements of `a` at the given indices. From 8d6ba7121bc394c8e134be6705e05354726fcc5d Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Tue, 24 Dec 2024 10:17:29 +0000 Subject: [PATCH 11/11] ... and pass arguments through. --- unyt/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unyt/array.py b/unyt/array.py index a1ddeda9..c0182875 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2138,7 +2138,7 @@ def take(self, indices, axis=None, out=None, mode="raise"): """ from ._array_functions import take - return take(self, *args, **kwargs) + return take(self, indices, axis=axis, out=out, mode=mode) def __reduce__(self): """Pickle reduction method