Skip to content

Commit

Permalink
feat: add all direct from Awkward functions (#11)
Browse files Browse the repository at this point in the history
* feat: add all direct from Awkward functions

* empty_like

* full_like, ones_like, zeros_like

* Also check against original 'device'.

* broadcast_arrays

* concat

* where

* nonzero
  • Loading branch information
jpivarski authored Jan 10, 2024
1 parent 96bb313 commit 1890589
Show file tree
Hide file tree
Showing 11 changed files with 368 additions and 100 deletions.
59 changes: 56 additions & 3 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ def __init__(
elif isinstance(self._impl, np.ndarray) and device == "cuda":
cp = _import.cupy()
self._impl = cp.array(self._impl)
self._device = device
else:
if isinstance(self._impl, ak.Array):
self._device = ak.backend(self._impl)
elif isinstance(self._impl, np.ndarray):
self._device = "cpu"
else:
self._device = "cuda"

if copy is not None:
raise NotImplementedError("TODO 1") # noqa: EM101
Expand Down Expand Up @@ -1101,6 +1109,32 @@ def __irshift__(self, other: int | array, /) -> array:
__rrshift__ = __rshift__


def _root_buffer(x: array | ak.Array | SupportsDLPack) -> object:
    """Return the object that ultimately owns the flat data buffer of ``x``.

    Steps:
      1. If ``x`` is a ``ragged.array``, use its ``_impl``.
      2. If that is an ``ak.Array``, descend through ``.content`` until a
         ``NumpyArray`` node is reached and take its ``.data``.
      3. Follow the ``.base`` chain until an object with no ``base`` (or a
         ``base`` of ``None``) is found.
    """
    buf = x._impl if isinstance(x, array) else x  # pylint: disable=W0212

    if isinstance(buf, ak.Array):
        node = buf.layout
        while not isinstance(node, NumpyArray):
            node = node.content
        buf = node.data

    # The base of an array is not necessarily another array (for example,
    # the ``bytes`` object behind ``np.frombuffer``), so the walk must stop
    # gracefully at any object that has no ``base`` attribute instead of
    # raising AttributeError.
    while True:
        parent = getattr(buf, "base", None)
        if parent is None:
            return buf
        buf = parent


def _is_shared(
    x1: array | ak.Array | SupportsDLPack, x2: array | ak.Array | SupportsDLPack
) -> bool:
    """Tell whether ``x1`` and ``x2`` are backed by the same root buffer.

    Args:
        x1: A ``ragged.array``, an ``ak.Array``, or a flat array object.
        x2: Same kinds of object as ``x1``.

    Returns:
        ``True`` if, after unwrapping both arguments down to the object at
        the end of their ``.base`` chains, the two results are the very same
        Python object (identity, not equality); ``False`` otherwise.
    """
    return _root_buffer(x1) is _root_buffer(x2)


def _unbox(*inputs: array) -> tuple[ak.Array | SupportsDLPack, ...]:
if len(inputs) > 1 and any(type(inputs[0]) is not type(x) for x in inputs):
types = "\n".join(f"{type(x).__module__}.{type(x).__name__}" for x in inputs)
Expand All @@ -1115,6 +1149,7 @@ def _box(
output: ak.Array | np.number | SupportsDLPack,
*,
dtype: None | Dtype = None,
device: None | Device = None,
) -> array:
if isinstance(output, ak.Array):
impl = output
Expand All @@ -1123,7 +1158,11 @@ def _box(
impl = ak.values_astype(impl, dtype)
else:
dtype = dtype_observed
device = ak.backend(output)
device_observed = ak.backend(output)
if device is None:
device = device_observed
elif device != device_observed:
output = ak.to_backend(output, device)

elif isinstance(output, np.number):
impl = np.array(output)
Expand All @@ -1133,7 +1172,12 @@ def _box(
impl = impl.astype(dtype)
else:
dtype = dtype_observed
device = "cpu"
device_observed = "cpu"
if device is None:
device = device_observed
elif device != device_observed:
cp = _import.cupy()
output = cp.array(output)

else:
impl = output
Expand All @@ -1143,7 +1187,16 @@ def _box(
impl = impl.astype(dtype)
else:
dtype = dtype_observed
device = "cpu" if isinstance(output, np.ndarray) else "cuda"
device_observed = "cpu" if isinstance(output, np.ndarray) else "cuda"
if device is None:
device = device_observed
elif device != device_observed:
if device == "cpu":
output = np.array(output)
else:
cp = _import.cupy()
output = cp.array(output)

if shape != ():
impl = ak.Array(impl)

Expand Down
43 changes: 25 additions & 18 deletions src/ragged/_spec_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import _import
from ._import import device_namespace
from ._spec_array_object import _box, array
from ._spec_array_object import _box, _unbox, array
from ._typing import (
Device,
Dtype,
Expand Down Expand Up @@ -161,10 +161,12 @@ def empty_like(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty_like.html
"""

x # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 37") # noqa: EM101
(impl,) = _unbox(x)
if isinstance(impl, ak.Array):
return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device)
else:
_, ns = device_namespace(x.device if device is None else device)
return _box(type(x), ns.empty_like(impl), dtype=dtype, device=device)


def eye(
Expand Down Expand Up @@ -292,11 +294,12 @@ def full_like(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.full_like.html
"""

x # noqa: B018, pylint: disable=W0104
fill_value # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 41") # noqa: EM101
(impl,) = _unbox(x)
if isinstance(impl, ak.Array):
return _box(type(x), ak.full_like(impl, fill_value), dtype=dtype, device=device)
else:
_, ns = device_namespace(x.device if device is None else device)
return _box(type(x), ns.full_like(impl, fill_value), dtype=dtype, device=device)


def linspace(
Expand Down Expand Up @@ -441,10 +444,12 @@ def ones_like(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html
"""

x # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 45") # noqa: EM101
(impl,) = _unbox(x)
if isinstance(impl, ak.Array):
return _box(type(x), ak.ones_like(impl), dtype=dtype, device=device)
else:
_, ns = device_namespace(x.device if device is None else device)
return _box(type(x), ns.ones_like(impl), dtype=dtype, device=device)


def tril(x: array, /, *, k: int = 0) -> array:
Expand Down Expand Up @@ -542,7 +547,9 @@ def zeros_like(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html
"""

x # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 49") # noqa: EM101
(impl,) = _unbox(x)
if isinstance(impl, ak.Array):
return _box(type(x), ak.zeros_like(impl), dtype=dtype, device=device)
else:
_, ns = device_namespace(x.device if device is None else device)
return _box(type(x), ns.zeros_like(impl), dtype=dtype, device=device)
15 changes: 5 additions & 10 deletions src/ragged/_spec_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from ._spec_array_object import array
from ._spec_array_object import _box, _unbox, array
from ._typing import Dtype

_type = type
Expand All @@ -23,11 +23,7 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
Args:
x: Array to cast.
dtype: Desired data type.
copy: Specifies whether to copy an array when the specified `dtype`
matches the data type of the input array `x`. If `True`, a newly
allocated array is always returned. If `False` and the specified
`dtype` matches the data type of the input array, the input array
is returned; otherwise, a newly allocated array is returned.
copy: Ignored because `ragged.array` data buffers are immutable.
Returns:
An array having the specified data type. The returned array has the
Expand All @@ -36,10 +32,9 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html
"""

x # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
copy # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 50") # noqa: EM101
copy # noqa: B018, argument is ignored, pylint: disable=W0104

return _box(type(x), *_unbox(x), dtype=dtype)


def can_cast(from_: Dtype | array, to: Dtype, /) -> bool:
Expand Down
38 changes: 32 additions & 6 deletions src/ragged/_spec_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from __future__ import annotations

from ._spec_array_object import array
import awkward as ak

from ._spec_array_object import _box, _unbox, array


def broadcast_arrays(*arrays: array) -> list[array]:
Expand All @@ -23,8 +25,14 @@ def broadcast_arrays(*arrays: array) -> list[array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_arrays.html
"""

arrays # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 114") # noqa: EM101
impls = _unbox(*arrays)
if all(not isinstance(x, ak.Array) for x in impls):
return [_box(type(arrays[i]), x) for i, x in enumerate(impls)]
else:
out = [x if isinstance(x, ak.Array) else x.reshape((1,)) for x in impls] # type: ignore[union-attr]
return [
_box(type(arrays[i]), x) for i, x in enumerate(ak.broadcast_arrays(*out))
]


def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array:
Expand Down Expand Up @@ -71,9 +79,27 @@ def concat(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.concat.html
"""

arrays # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 116") # noqa: EM101
if len(arrays) == 0:
msg = "need at least one array to concatenate"
raise ValueError(msg)

first = arrays[0]
if not all(first.ndim == x.ndim for x in arrays[1:]):
msg = "all the input arrays must have the same number of dimensions"
raise ValueError(msg)

if first.ndim == 0:
msg = "zero-dimensional arrays cannot be concatenated"
raise ValueError(msg)

impls = _unbox(*arrays)
assert all(isinstance(x, ak.Array) for x in impls)

if axis is None:
impls = [ak.ravel(x) for x in impls] # type: ignore[assignment]
axis = 0

return _box(type(first), ak.concatenate(impls, axis=axis))


def expand_dims(x: array, /, *, axis: int = 0) -> array:
Expand Down
29 changes: 23 additions & 6 deletions src/ragged/_spec_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import awkward as ak
import numpy as np

from ._import import device_namespace
from ._spec_array_object import _box, _unbox, array


Expand Down Expand Up @@ -123,8 +124,11 @@ def nonzero(x: array, /) -> tuple[array, ...]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 126") # noqa: EM101
(impl,) = _unbox(x)
if not isinstance(impl, ak.Array):
impl = ak.Array(impl.reshape((1,))) # type: ignore[union-attr]

return tuple(_box(type(x), item) for item in ak.where(impl))


def where(condition: array, x1: array, x2: array, /) -> array:
Expand All @@ -146,7 +150,20 @@ def where(condition: array, x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html
"""

condition # noqa: B018, pylint: disable=W0104
x1 # noqa: B018, pylint: disable=W0104
x2 # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 127") # noqa: EM101
if condition.ndim == x1.ndim == x2.ndim == 0:
cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2)
_, ns = device_namespace(condition.device)
return _box(type(condition), ns.where(cond_impl, x1_impl, x2_impl))

else:
cond_impl, x1_impl, x2_impl = _unbox(condition, x1, x2)
if not isinstance(cond_impl, ak.Array):
cond_impl = ak.Array(cond_impl.reshape((1,))) # type: ignore[union-attr]
if not isinstance(x1_impl, ak.Array):
x1_impl = ak.Array(x1_impl.reshape((1,))) # type: ignore[union-attr]
if not isinstance(x2_impl, ak.Array):
x2_impl = ak.Array(x2_impl.reshape((1,))) # type: ignore[union-attr]

cond_impl, x1_impl, x2_impl = ak.broadcast_arrays(cond_impl, x1_impl, x2_impl)

return _box(type(condition), ak.where(cond_impl, x1_impl, x2_impl))
59 changes: 59 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import reprlib

import awkward as ak
import numpy as np
import pytest

import ragged
Expand All @@ -20,3 +22,60 @@ def repr1(self, x, level):

reprlib.Repr.repr1_original = reprlib.Repr.repr1 # type: ignore[attr-defined]
reprlib.Repr.repr1 = repr1 # type: ignore[method-assign]


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x(request):
    """Float-valued ``ragged.array`` in three flavors.

    ``regular`` wraps a 1-D NumPy array, ``irregular`` wraps an Awkward
    Array with variable-length lists (including an empty one), and any
    other parameter (``scalar``) wraps a zero-dimensional NumPy array.
    """
    flavor = request.param
    if flavor == "irregular":
        data = ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])
    elif flavor == "regular":
        data = np.array([1.0, 2.0, 3.0])
    else:  # flavor == "scalar"
        data = np.array(10.0)
    return ragged.array(data)


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_lt1(request):
    """Float-valued ``ragged.array`` whose elements are all below 1.

    Provided as a 1-D NumPy-backed array (``regular``), an Awkward-backed
    array of variable-length lists (``irregular``), or a zero-dimensional
    NumPy-backed value (``scalar``).
    """
    flavor = request.param
    if flavor == "irregular":
        data = ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]])
    elif flavor == "regular":
        data = np.array([0.1, 0.2, 0.3])
    else:  # flavor == "scalar"
        data = np.array(0.5)
    return ragged.array(data)


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_bool(request):
    """Boolean ``ragged.array`` in regular, irregular and scalar flavors.

    ``regular`` is a 1-D NumPy array, ``irregular`` is an Awkward Array of
    variable-length lists (one of them empty), and ``scalar`` is a
    zero-dimensional NumPy array holding ``True``.
    """
    flavor = request.param
    if flavor == "irregular":
        data = ak.Array([[True, True, False], [], [False, False]])
    elif flavor == "regular":
        data = np.array([False, True, False])
    else:  # flavor == "scalar"
        data = np.array(True)
    return ragged.array(data)


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_int(request):
    """Integer ``ragged.array`` in regular, irregular and scalar flavors.

    The NumPy-backed cases (``regular`` and ``scalar``) are built with an
    explicit ``int64`` dtype; the ``irregular`` case is built from nested
    Python integer lists via ``ak.Array``.
    """
    flavor = request.param
    if flavor == "irregular":
        data = ak.Array([[1, 2, 3], [], [4, 5]])
    elif flavor == "regular":
        data = np.array([0, 1, 2], dtype=np.int64)
    else:  # flavor == "scalar"
        data = np.array(10, dtype=np.int64)
    return ragged.array(data)


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_complex(request):
    """Complex-valued ``ragged.array`` in regular, irregular and scalar flavors.

    ``regular`` is a 1-D NumPy array, ``irregular`` is an Awkward Array of
    variable-length lists (one of them empty), and ``scalar`` is a
    zero-dimensional NumPy array.
    """
    flavor = request.param
    if flavor == "irregular":
        data = ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]])
    elif flavor == "regular":
        data = np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j])
    else:  # flavor == "scalar"
        data = np.array(10 + 1j)
    return ragged.array(data)


# Bind each fixture above to a second module-level name (``y*`` for ``x*``).
# NOTE(review): presumably this lets a single test request two independently
# parametrized operands (e.g. ``x`` and ``y``) — confirm against the tests
# and the pytest version in use.
y = x
y_lt1 = x_lt1
y_bool = x_bool
y_int = x_int
y_complex = x_complex
Loading

0 comments on commit 1890589

Please sign in to comment.