diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index b300fc5bc1..ccf53b0966 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -216,6 +216,14 @@ def __init__( elif isinstance(self._impl, np.ndarray) and device == "cuda": cp = _import.cupy() self._impl = cp.array(self._impl) + self._device = device + else: + if isinstance(self._impl, ak.Array): + self._device = ak.backend(self._impl) + elif isinstance(self._impl, np.ndarray): + self._device = "cpu" + else: + self._device = "cuda" if copy is not None: raise NotImplementedError("TODO 1") # noqa: EM101 @@ -1101,6 +1109,32 @@ def __irshift__(self, other: int | array, /) -> array: __rrshift__ = __rshift__ +def _is_shared( + x1: array | ak.Array | SupportsDLPack, x2: array | ak.Array | SupportsDLPack +) -> bool: + x1_buf = x1._impl if isinstance(x1, array) else x1 # pylint: disable=W0212 + x2_buf = x2._impl if isinstance(x2, array) else x2 # pylint: disable=W0212 + + if isinstance(x1_buf, ak.Array): + x1_buf = x1_buf.layout + while not isinstance(x1_buf, NumpyArray): + x1_buf = x1_buf.content + x1_buf = x1_buf.data + + if isinstance(x2_buf, ak.Array): + x2_buf = x2_buf.layout + while not isinstance(x2_buf, NumpyArray): + x2_buf = x2_buf.content + x2_buf = x2_buf.data + + while x1_buf.base is not None: # type: ignore[union-attr] + x1_buf = x1_buf.base # type: ignore[union-attr] + while x2_buf.base is not None: # type: ignore[union-attr] + x2_buf = x2_buf.base # type: ignore[union-attr] + + return x1_buf is x2_buf + + def _unbox(*inputs: array) -> tuple[ak.Array | SupportsDLPack, ...]: if len(inputs) > 1 and any(type(inputs[0]) is not type(x) for x in inputs): types = "\n".join(f"{type(x).__module__}.{type(x).__name__}" for x in inputs) @@ -1115,6 +1149,7 @@ def _box( output: ak.Array | np.number | SupportsDLPack, *, dtype: None | Dtype = None, + device: None | Device = None, ) -> array: if isinstance(output, ak.Array): impl = output @@ -1123,7 +1158,11 @@ def _box( impl = ak.values_astype(impl, dtype) else: dtype = dtype_observed - device = ak.backend(output) + device_observed = ak.backend(output) + if device is None: + device = device_observed + elif device != device_observed: + output = ak.to_backend(output, device) elif isinstance(output, np.number): impl = np.array(output) @@ -1133,7 +1172,12 @@ def _box( impl = impl.astype(dtype) else: dtype = dtype_observed - device = "cpu" + device_observed = "cpu" + if device is None: + device = device_observed + elif device != device_observed: + cp = _import.cupy() + output = cp.array(output) else: impl = output @@ -1143,7 +1187,16 @@ def _box( impl = impl.astype(dtype) else: dtype = dtype_observed - device = "cpu" if isinstance(output, np.ndarray) else "cuda" + device_observed = "cpu" if isinstance(output, np.ndarray) else "cuda" + if device is None: + device = device_observed + elif device != device_observed: + if device == "cpu": + output = np.array(output) + else: + cp = _import.cupy() + output = cp.array(output) + if shape != (): impl = ak.Array(impl) diff --git a/src/ragged/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py index 189a787abf..844eceb6a1 100644 --- a/src/ragged/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -13,7 +13,7 @@ from . import _import from ._import import device_namespace -from ._spec_array_object import _box, array +from ._spec_array_object import _box, _unbox, array from ._typing import ( Device, Dtype, @@ -161,10 +161,12 @@ def empty_like( https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty_like.html """ - x # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 37") # noqa: EM101 + (impl,) = _unbox(x) + if isinstance(impl, ak.Array): + return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device) + else: + _, ns = device_namespace(x.device if device is None else device) + return _box(type(x), ns.empty_like(impl), dtype=dtype, device=device) def eye( @@ -292,11 +294,12 @@ def full_like( https://data-apis.org/array-api/latest/API_specification/generated/array_api.full_like.html """ - x # noqa: B018, pylint: disable=W0104 - fill_value # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 41") # noqa: EM101 + (impl,) = _unbox(x) + if isinstance(impl, ak.Array): + return _box(type(x), ak.full_like(impl, fill_value), dtype=dtype, device=device) + else: + _, ns = device_namespace(x.device if device is None else device) + return _box(type(x), ns.full_like(impl, fill_value), dtype=dtype, device=device) def linspace( @@ -441,10 +444,12 @@ def ones_like( https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html """ - x # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 45") # noqa: EM101 + (impl,) = _unbox(x) + if isinstance(impl, ak.Array): + return _box(type(x), ak.ones_like(impl), dtype=dtype, device=device) + else: + _, ns = device_namespace(x.device if device is None else device) + return _box(type(x), ns.ones_like(impl), dtype=dtype, device=device) def tril(x: array, /, *, k: int = 0) -> array: @@ -542,7 +547,9 @@ def zeros_like( https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html """ - x # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 49") # noqa: EM101 + (impl,) = _unbox(x) + if isinstance(impl, ak.Array): + return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device) + else: + _, ns = device_namespace(x.device if device is None else device) + return _box(type(x), ns.zeros_like(impl), dtype=dtype, device=device) diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py index 4106ef80dd..9c9858b5e0 100644 --- a/src/ragged/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -10,7 +10,7 @@ import numpy as np -from ._spec_array_object import array +from ._spec_array_object import _box, _unbox, array from ._typing import Dtype _type = type @@ -23,11 +23,7 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: Args: x: Array to cast. dtype: Desired data type. - copy: Specifies whether to copy an array when the specified `dtype` - matches the data type of the input array `x`. If `True`, a newly - allocated array is always returned. If `False` and the specified - `dtype` matches the data type of the input array, the input array - is returned; otherwise, a newly allocated array is returned. + copy: Ignored because `ragged.array` data buffers are immutable. Returns: An array having the specified data type. The returned array has the @@ -36,10 +32,9 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html """ - x # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - copy # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 50") # noqa: EM101 + copy # noqa: B018, argument is ignored, pylint: disable=W0104 + + return _box(type(x), *_unbox(x), dtype=dtype) def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: diff --git a/src/ragged/_spec_manipulation_functions.py b/src/ragged/_spec_manipulation_functions.py index d53604c622..39657eb035 100644 --- a/src/ragged/_spec_manipulation_functions.py +++ b/src/ragged/_spec_manipulation_functions.py @@ -6,7 +6,9 @@ from __future__ import annotations -from ._spec_array_object import array +import awkward as ak + +from ._spec_array_object import _box, _unbox, array def broadcast_arrays(*arrays: array) -> list[array]: @@ -23,8 +25,14 @@ def broadcast_arrays(*arrays: array) -> list[array]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_arrays.html """ - arrays # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 114") # noqa: EM101 + impls = _unbox(*arrays) + if all(not isinstance(x, ak.Array) for x in impls): + return [_box(type(arrays[i]), x) for i, x in enumerate(impls)] + else: + out = [x if isinstance(x, ak.Array) else x.reshape((1,)) for x in impls] # type: ignore[union-attr] + return [ + _box(type(arrays[i]), x) for i, x in enumerate(ak.broadcast_arrays(*out)) + ] def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array: @@ -71,9 +79,27 @@ def concat( https://data-apis.org/array-api/latest/API_specification/generated/array_api.concat.html """ - arrays # noqa: B018, pylint: disable=W0104 - axis # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 116") # noqa: EM101 + if len(arrays) == 0: + msg = "need at least one array to concatenate" + raise ValueError(msg) + + first = arrays[0] + if not all(first.ndim == x.ndim for x in arrays[1:]): + msg = "all the input arrays must have the same number of dimensions" + raise ValueError(msg) + + if first.ndim == 0: + msg = "zero-dimensional arrays cannot be concatenated" + raise ValueError(msg) + + impls = _unbox(*arrays) + assert all(isinstance(x, ak.Array) for x in impls) + + if axis is None: + impls = [ak.ravel(x) for x in impls] # type: ignore[assignment] + axis = 0 + + return _box(type(first), ak.concatenate(impls, axis=axis)) def expand_dims(x: array, /, *, axis: int = 0) -> array: diff --git a/src/ragged/_spec_searching_functions.py b/src/ragged/_spec_searching_functions.py index 651850ceca..22aed41d3e 100644 --- a/src/ragged/_spec_searching_functions.py +++ b/src/ragged/_spec_searching_functions.py @@ -9,6 +9,7 @@ import awkward as ak import numpy as np +from ._import import device_namespace from ._spec_array_object import _box, _unbox, array @@ -123,8 +124,11 @@ def nonzero(x: array, /) -> tuple[array, ...]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html """ - x # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 126") # noqa: EM101 + (impl,) = _unbox(x) + if not isinstance(impl, ak.Array): + impl = ak.Array(impl.reshape((1,))) # type: ignore[union-attr] + + return tuple(_box(type(x), item) for item in ak.where(impl)) def where(condition: array, x1: array, x2: array, /) -> array: @@ -146,7 +150,20 @@ def where(condition: array, x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html """ - condition # noqa: B018, pylint: disable=W0104 - x1 # noqa: B018, pylint: disable=W0104 - x2 # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 127") # noqa: EM101 + if condition.ndim == x1.ndim == x2.ndim == 0: + cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2) + _, ns = device_namespace(condition.device) + return _box(type(condition), ns.where(cond_impl, x1_impl, x2_impl)) + + else: + cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2) + if not isinstance(cond_impl, ak.Array): + cond_impl = ak.Array(cond_impl.reshape((1,))) # type: ignore[union-attr] + if not isinstance(x1_impl, ak.Array): + x1_impl = ak.Array(x1_impl.reshape((1,))) # type: ignore[union-attr] + if not isinstance(x2_impl, ak.Array): + x2_impl = ak.Array(x2_impl.reshape((1,))) # type: ignore[union-attr] + + cond_impl, x1_impl, x2_impl = ak.broadcast_arrays(cond_impl, x1_impl, x2_impl) + + return _box(type(condition), ak.where(cond_impl, x1_impl, x2_impl)) diff --git a/tests/conftest.py b/tests/conftest.py index 4dd6d6ea3d..0bebee7d2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import reprlib +import awkward as ak +import numpy as np import pytest import ragged @@ -20,3 +22,60 @@ def repr1(self, x, level): reprlib.Repr.repr1_original = reprlib.Repr.repr1 # type: ignore[attr-defined] reprlib.Repr.repr1 = repr1 # type: ignore[method-assign] + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x(request): + if request.param == "regular": + return ragged.array(np.array([1.0, 2.0, 3.0])) + elif request.param == "irregular": + return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])) + else: # request.param == "scalar" + return ragged.array(np.array(10.0)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_lt1(request): + if request.param == "regular": + return ragged.array(np.array([0.1, 0.2, 0.3])) + elif request.param == "irregular": + return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]])) + else: # request.param == "scalar" + return ragged.array(np.array(0.5)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_bool(request): + if request.param == "regular": + return ragged.array(np.array([False, True, False])) + elif request.param == "irregular": + return ragged.array(ak.Array([[True, True, False], [], [False, False]])) + else: # request.param == "scalar" + return ragged.array(np.array(True)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_int(request): + if request.param == "regular": + return ragged.array(np.array([0, 1, 2], dtype=np.int64)) + elif request.param == "irregular": + return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]])) + else: # request.param == "scalar" + return ragged.array(np.array(10, dtype=np.int64)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_complex(request): + if request.param == "regular": + return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j])) + elif request.param == "irregular": + return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]])) + else: # request.param == "scalar" + return ragged.array(np.array(10 + 1j)) + + +y = x +y_lt1 = x_lt1 +y_bool = x_bool +y_int = x_int +y_complex = x_complex diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index a1c08b67a4..02059baee6 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -63,6 +63,15 @@ def test_empty_ndim0(device): assert isinstance(a._impl, ns[device].ndarray) +@pytest.mark.parametrize("device", devices) +def test_empty_like(device): + a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) + b = ragged.empty_like(a) + assert (b * 0).tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] + assert a.dtype == b.dtype + assert a.device == b.device == device + + @pytest.mark.parametrize("device", devices) def test_eye(device): a = ragged.eye(3, 5, k=1, device=device) @@ -97,6 +106,15 @@ def test_full_ndim0(device): assert isinstance(a._impl, ns[device].ndarray) +@pytest.mark.parametrize("device", devices) +def test_full_like(device): + a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) + b = ragged.full_like(a, 5) + assert b.tolist() == [[5, 5, 5], [], [5, 5]] # type: ignore[comparison-overlap] + assert a.dtype == b.dtype + assert a.device == b.device == device + + @pytest.mark.parametrize("device", devices) def test_linspace(device): a = ragged.linspace(5, 8, 5, device=device) @@ -120,6 +138,15 @@ def test_ones_ndim0(device): assert isinstance(a._impl, ns[device].ndarray) +@pytest.mark.parametrize("device", devices) +def test_ones_like(device): + a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) + b = ragged.ones_like(a) + assert b.tolist() == [[1, 1, 1], [], [1, 1]] # type: ignore[comparison-overlap] + assert a.dtype == b.dtype + assert a.device == b.device == device + + @pytest.mark.parametrize("device", devices) def test_zeros(device): a = ragged.zeros(5, device=device) @@ -134,3 +161,12 @@ def test_zeros_ndim0(device): assert a.shape == () assert a == 0 assert isinstance(a._impl, ns[device].ndarray) + + +@pytest.mark.parametrize("device", devices) +def test_zeros_like(device): + a = ragged.array([[1, 2, 3], [], [4, 5]], device=device) + b = ragged.zeros_like(a) + assert b.tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] + assert a.dtype == b.dtype + assert a.device == b.device == device diff --git a/tests/test_spec_data_type_functions.py b/tests/test_spec_data_type_functions.py index 526181f9d9..53ac14d90d 100644 --- a/tests/test_spec_data_type_functions.py +++ b/tests/test_spec_data_type_functions.py @@ -6,10 +6,27 @@ from __future__ import annotations +from typing import Any + +import awkward as ak import numpy as np +import pytest import ragged +devices = ["cpu"] +try: + import cupy as cp + + devices.append("cuda") +except ModuleNotFoundError: + cp = None + + +def first(x: ragged.array) -> Any: + out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl + return np.asarray(out.item(), dtype=x.dtype) + def test_existence(): assert ragged.astype is not None @@ -20,6 +37,16 @@ def test_existence(): assert ragged.result_type is not None +@pytest.mark.parametrize("device", devices) +@pytest.mark.parametrize("dt", ["float64", np.float64, np.dtype(np.float64)]) +def test_astype(device, x_int, dt): + x = x_int.to_device(device) + y = ragged.astype(x, dt) + assert first(y) == first(x) + assert y.dtype == np.dtype(np.float64) + assert y.device == x.device + + def test_can_cast(): assert ragged.can_cast(np.float32, np.complex128) assert not ragged.can_cast(np.complex128, np.float32) diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 383877866c..fdac4a92c2 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -29,63 +29,6 @@ cp = None -@pytest.fixture(params=["regular", "irregular", "scalar"]) -def x(request): - if request.param == "regular": - return ragged.array(np.array([1.0, 2.0, 3.0])) - elif request.param == "irregular": - return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])) - else: # request.param == "scalar" - return ragged.array(np.array(10.0)) - - -@pytest.fixture(params=["regular", "irregular", "scalar"]) -def x_lt1(request): - if request.param == "regular": - return ragged.array(np.array([0.1, 0.2, 0.3])) - elif request.param == "irregular": - return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]])) - else: # request.param == "scalar" - return ragged.array(np.array(0.5)) - - -@pytest.fixture(params=["regular", "irregular", "scalar"]) -def x_bool(request): - if request.param == "regular": - return ragged.array(np.array([False, True, False])) - elif request.param == "irregular": - return ragged.array(ak.Array([[True, True, False], [], [False, False]])) - else: # request.param == "scalar" - return ragged.array(np.array(True)) - - -@pytest.fixture(params=["regular", "irregular", "scalar"]) -def x_int(request): - if request.param == "regular": - return ragged.array(np.array([0, 1, 2], dtype=np.int64)) - elif request.param == "irregular": - return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]])) - else: # request.param == "scalar" - return ragged.array(np.array(10, dtype=np.int64)) - - -@pytest.fixture(params=["regular", "irregular", "scalar"]) -def x_complex(request): - if request.param == "regular": - return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j])) - elif request.param == "irregular": - return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]])) - else: # request.param == "scalar" - return ragged.array(np.array(10 + 1j)) - - -y = x -y_lt1 = x_lt1 -y_bool = x_bool -y_int = x_int -y_complex = x_complex - - def first(x: ragged.array) -> Any: out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl return xp.asarray(out.item(), dtype=x.dtype) diff --git a/tests/test_spec_manipulation_functions.py b/tests/test_spec_manipulation_functions.py index 20b5ff20c5..2487b40d94 100644 --- a/tests/test_spec_manipulation_functions.py +++ b/tests/test_spec_manipulation_functions.py @@ -6,8 +6,18 @@ from __future__ import annotations +import pytest + import ragged +devices = ["cpu"] +try: + import cupy as cp + + devices.append("cuda") +except ModuleNotFoundError: + cp = None + def test_existence(): assert ragged.broadcast_arrays is not None @@ -20,3 +30,58 @@ def test_existence(): assert ragged.roll is not None assert ragged.squeeze is not None assert ragged.stack is not None + + +@pytest.mark.parametrize("device", devices) +def test_broadcast_arrays(device, x, y): + x_bc, y_bc = ragged.broadcast_arrays(x.to_device(device), y.to_device(device)) + if x.shape == () and y.shape == (): + assert x_bc.shape == () + assert y_bc.shape == () + else: + assert x_bc.shape == y_bc.shape + if x_bc.shape == (3,): + assert (x_bc * 0).tolist() == (y_bc * 0).tolist() == [0, 0, 0] + if x_bc.shape == (3, None): + assert (x_bc * 0).tolist() == (y_bc * 0).tolist() == [[0, 0, 0], [], [0, 0]] # type: ignore[comparison-overlap] + + +def test_concat(x, y): + if x.ndim != y.ndim: + with pytest.raises(ValueError, match="same number of dimensions"): + ragged.concat([x, y]) + + elif x.ndim == 0: + with pytest.raises(ValueError, match="zero-dimensional"): + ragged.concat([x, y]) + + elif x.ndim == 1: + assert ragged.concat([x, y], axis=None).tolist() == x.tolist() + y.tolist() + assert ragged.concat([x, y], axis=0).tolist() == x.tolist() + y.tolist() + + else: + assert ragged.concat([x, y], axis=None).tolist() == [ + 1.1, + 2.2, + 3.3, + 4.4, + 5.5, + 1.1, + 2.2, + 3.3, + 4.4, + 5.5, + ] + assert ragged.concat([x, y], axis=0).tolist() == [ # type: ignore[comparison-overlap] + [1.1, 2.2, 3.3], + [], + [4.4, 5.5], + [1.1, 2.2, 3.3], + [], + [4.4, 5.5], + ] + assert ragged.concat([x, y], axis=1).tolist() == [ # type: ignore[comparison-overlap] + [1.1, 2.2, 3.3, 1.1, 2.2, 3.3], + [], + [4.4, 5.5, 4.4, 5.5], + ] diff --git a/tests/test_spec_searching_functions.py b/tests/test_spec_searching_functions.py index b9ecf0d761..9f6e1ae7c3 100644 --- a/tests/test_spec_searching_functions.py +++ b/tests/test_spec_searching_functions.py @@ -6,11 +6,20 @@ from __future__ import annotations +from typing import Any + +import awkward as ak +import numpy as np import pytest import ragged +def first(x: ragged.array) -> Any: + out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl + return np.asarray(out.item(), dtype=x.dtype) + + def test_existence(): assert ragged.argmax is not None assert ragged.argmin is not None @@ -58,3 +67,34 @@ def test_argmin(): ragged.argmin(data, axis=2) with pytest.raises(ValueError, match=".*axis.*"): ragged.argmin(data, axis=-1) + + +def test_nonzero(): + (result,) = ragged.nonzero(ragged.array(0)) + assert result.tolist() == [] + + (result,) = ragged.nonzero(ragged.array(123)) + assert result.tolist() == [0] + + (result,) = ragged.nonzero(ragged.array([0])) + assert result.tolist() == [] + + (result,) = ragged.nonzero(ragged.array([123])) + assert result.tolist() == [0] + + (result,) = ragged.nonzero(ragged.array([111, 222, 0, 333])) + assert result.tolist() == [0, 1, 3] + + result1, result2 = ragged.nonzero(ragged.array([[111, 222, 0], [333, 0, 444]])) + assert result1.tolist() == [0, 0, 1, 1] + assert result2.tolist() == [0, 1, 0, 2] + + +def test_where(x_bool, x, y): + z = ragged.where(x_bool, x, y) + if x_bool.ndim == x.ndim == y.ndim == 0: + assert z.ndim == 0 + assert z == x if x_bool else y + else: + assert z.ndim == max(x_bool.ndim, x.ndim, y.ndim) + assert first(z) == first(x) if first(x_bool) else first(y)