Skip to content

Commit

Permalink
feat: add all direct from NumPy/CuPy functions (#10)
Browse files Browse the repository at this point in the history
* feat: add all direct from NumPy/CuPy functions

* empty, eye, from_dlpack

* full, linspace, ones, zeros

* Also test cases that create scalars.

* can_cast, finfo, iinfo

* result_type

* unnecessary mypy ignore

* test the minimal NumPy version

* fix YAML

* fix YAML 2

* fix YAML names and PIP_ONLY_BINARY

* fix YAML 3

* fix YAML 4

* Old Ubuntu image should have old NumPy, I hope.

* Take a lower minimum Python (3.8) to get NumPy 1.18.0.

* Don't subscript np.dtype.

* Revert "Don't subscript np.dtype."

This reverts commit 1fad6020178fa2c1d33af8a32af4a2fe505ffa5a.

* Revert "Take a lower minimum Python (3.8) to get NumPy 1.18.0."

This reverts commit efe63ac3f1b49392b61f8b46dfd59e57df4b8cb3.

* Nope. Instead, increase the minimum NumPy to 1.19.3.

* Don't subscript np.dtype.

* pre-commit

* Create a fake numpy.array_api for tests of np 1.19.3.

* NumPy started experimental support for array_api with 1.22.0.

* Original NumPy Array API support did not include complex.
  • Loading branch information
jpivarski authored Jan 9, 2024
1 parent 19ddd26 commit 96bb313
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 50 deletions.
18 changes: 16 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,24 @@ jobs:
pipx run nox -s pylint
checks:
name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
name:
"py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{
matrix.runs-on }}"
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit]
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.12"]
numpy-version: ["latest"]
runs-on: [ubuntu-latest, macos-latest, windows-latest]

include:
- python-version: pypy-3.10
- python-version: "pypy-3.10"
numpy-version: "latest"
runs-on: ubuntu-latest
- python-version: "3.9"
numpy-version: "1.22.0"
runs-on: ubuntu-latest

steps:
Expand All @@ -57,9 +64,16 @@ jobs:
python-version: ${{ matrix.python-version }}
allow-prereleases: true

- name: Install old NumPy
if: matrix.numpy-version != 'latest'
run: python -m pip install numpy==${{ matrix.numpy-version }}

- name: Install package
run: python -m pip install .[test]

- name: Print NumPy version
run: python -c 'import numpy as np; print(np.__version__)'

- name: Test package
run: >-
python -m pytest -ra --cov --cov-report=xml --cov-report=term
Expand Down
14 changes: 14 additions & 0 deletions src/ragged/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

from typing import Any

import numpy as np

from ._typing import Device


def device_namespace(device: None | Device = None) -> tuple[Device, Any]:
if device is None or device == "cpu":
return "cpu", np
elif device == "cuda":
return "cuda", cupy()

msg = f"unrecognized device: {device!r}" # type: ignore[unreachable]
raise ValueError(msg)


def cupy() -> Any:
try:
Expand Down
2 changes: 2 additions & 0 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,5 +1144,7 @@ def _box(
else:
dtype = dtype_observed
device = "cpu" if isinstance(output, np.ndarray) else "cuda"
if shape != ():
impl = ak.Array(impl)

return cls._new(impl, shape, dtype, device) # pylint: disable=W0212
76 changes: 37 additions & 39 deletions src/ragged/_spec_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

from __future__ import annotations

import enum

import awkward as ak
import numpy as np

from ._spec_array_object import array
from . import _import
from ._import import device_namespace
from ._spec_array_object import _box, array
from ._typing import (
Device,
Dtype,
Expand Down Expand Up @@ -53,12 +58,8 @@ def arange(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html
"""

start # noqa: B018, pylint: disable=W0104
stop # noqa: B018, pylint: disable=W0104
step # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 35") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.arange(start, stop, step, dtype=dtype))


def asarray(
Expand Down Expand Up @@ -137,10 +138,8 @@ def empty(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 36") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.empty(shape, dtype=dtype))


def empty_like(
Expand Down Expand Up @@ -197,12 +196,8 @@ def eye(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html
"""

n_rows # noqa: B018, pylint: disable=W0104
n_cols # noqa: B018, pylint: disable=W0104
k # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 38") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.eye(n_rows, n_cols, k, dtype=dtype))


def from_dlpack(x: object, /) -> array:
Expand All @@ -218,8 +213,21 @@ def from_dlpack(x: object, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 39") # noqa: EM101
device_type, _ = x.__dlpack_device__() # type: ignore[attr-defined]
if (
isinstance(device_type, enum.Enum) and device_type.value == 1
) or device_type == 1:
y = np.from_dlpack(x)
elif (
isinstance(device_type, enum.Enum) and device_type.value == 2
) or device_type == 2:
cp = _import.cupy()
y = cp.from_dlpack(x)
else:
msg = f"unsupported __dlpack_device__ type: {device_type}"
raise TypeError(msg)

return _box(array, y)


def full(
Expand Down Expand Up @@ -254,11 +262,8 @@ def full(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html
"""

shape # noqa: B018, pylint: disable=W0104
fill_value # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 40") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.full(shape, fill_value, dtype=dtype))


def full_like(
Expand Down Expand Up @@ -344,13 +349,10 @@ def linspace(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html
"""

start # noqa: B018, pylint: disable=W0104
stop # noqa: B018, pylint: disable=W0104
num # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
endpoint # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 42") # noqa: EM101
device, ns = device_namespace(device)
return _box(
array, ns.linspace(start, stop, num=num, endpoint=endpoint, dtype=dtype)
)


def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]:
Expand Down Expand Up @@ -415,10 +417,8 @@ def ones(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 44") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.ones(shape, dtype=dtype))


def ones_like(
Expand Down Expand Up @@ -518,10 +518,8 @@ def zeros(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 48") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.zeros(shape, dtype=dtype))


def zeros_like(
Expand Down
31 changes: 22 additions & 9 deletions src/ragged/_spec_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from ._spec_array_object import array
from ._typing import Dtype

_type = type


def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
"""
Expand Down Expand Up @@ -56,9 +58,7 @@ def can_cast(from_: Dtype | array, to: Dtype, /) -> bool:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html
"""

from_ # noqa: B018, pylint: disable=W0104
to # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 51") # noqa: EM101
return bool(np.can_cast(from_, to))


@dataclass
Expand Down Expand Up @@ -114,8 +114,16 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html
"""

type # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 52") # noqa: EM101
if not isinstance(type, np.dtype):
if not isinstance(type, _type) and hasattr(type, "dtype"):
out = np.finfo(type.dtype)
else:
out = np.finfo(np.dtype(type))
else:
out = np.finfo(type)
return finfo_object(
out.bits, out.eps, out.max, out.min, out.smallest_normal, out.dtype
)


@dataclass
Expand Down Expand Up @@ -155,8 +163,14 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html
"""

type # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 53") # noqa: EM101
if not isinstance(type, np.dtype):
if not isinstance(type, _type) and hasattr(type, "dtype"):
out = np.iinfo(type.dtype)
else:
out = np.iinfo(np.dtype(type))
else:
out = np.iinfo(type)
return iinfo_object(out.bits, out.max, out.min, out.dtype)


def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool:
Expand Down Expand Up @@ -218,5 +232,4 @@ def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html
"""

arrays_and_dtypes # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 55") # noqa: EM101
return np.result_type(*arrays_and_dtypes)
108 changes: 108 additions & 0 deletions tests/test_spec_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,21 @@

from __future__ import annotations

import numpy as np
import pytest

import ragged

devices = ["cpu"]
ns = {"cpu": np}
try:
import cupy as cp

devices.append("cuda")
ns["cuda"] = cp
except ModuleNotFoundError:
cp = None


def test_existence():
assert ragged.arange is not None
Expand All @@ -26,3 +39,98 @@ def test_existence():
assert ragged.triu is not None
assert ragged.zeros is not None
assert ragged.zeros_like is not None


@pytest.mark.parametrize("device", devices)
def test_arange(device):
a = ragged.arange(5, 10, 2, device=device)
assert a.tolist() == [5, 7, 9]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_empty(device):
a = ragged.empty((2, 3, 5), device=device)
assert a.shape == (2, 3, 5)
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_empty_ndim0(device):
a = ragged.empty((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_eye(device):
a = ragged.eye(3, 5, k=1, device=device)
assert a.tolist() == [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.skipif(
not hasattr(np, "from_dlpack"), reason=f"np.from_dlpack not in {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_from_dlpack(device):
a = ns[device].array([1, 2, 3, 4, 5])
b = ragged.from_dlpack(a)
assert b.tolist() == [1, 2, 3, 4, 5]
assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_full(device):
a = ragged.full(5, 3, device=device)
assert a.tolist() == [3, 3, 3, 3, 3]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_full_ndim0(device):
a = ragged.full((), 3, device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 3
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_linspace(device):
a = ragged.linspace(5, 8, 5, device=device)
assert a.tolist() == [5, 5.75, 6.5, 7.25, 8]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_ones(device):
a = ragged.ones(5, device=device)
assert a.tolist() == [1, 1, 1, 1, 1]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_ones_ndim0(device):
a = ragged.ones((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 1
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_zeros(device):
a = ragged.zeros(5, device=device)
assert a.tolist() == [0, 0, 0, 0, 0]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_zeros_ndim0(device):
a = ragged.zeros((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 0
assert isinstance(a._impl, ns[device].ndarray)
Loading

0 comments on commit 96bb313

Please sign in to comment.