diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fc948f77f..2636bcdd57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,17 +34,24 @@ jobs: pipx run nox -s pylint checks: - name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + name: + "py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ + matrix.runs-on }}" runs-on: ${{ matrix.runs-on }} needs: [pre-commit] strategy: fail-fast: false matrix: python-version: ["3.9", "3.12"] + numpy-version: ["latest"] runs-on: [ubuntu-latest, macos-latest, windows-latest] include: - - python-version: pypy-3.10 + - python-version: "pypy-3.10" + numpy-version: "latest" + runs-on: ubuntu-latest + - python-version: "3.9" + numpy-version: "1.22.0" runs-on: ubuntu-latest steps: @@ -57,9 +64,16 @@ jobs: python-version: ${{ matrix.python-version }} allow-prereleases: true + - name: Install old NumPy + if: matrix.numpy-version != 'latest' + run: python -m pip install numpy==${{ matrix.numpy-version }} + - name: Install package run: python -m pip install .[test] + - name: Print NumPy version + run: python -c 'import numpy as np; print(np.__version__)' + - name: Test package run: >- python -m pytest -ra --cov --cov-report=xml --cov-report=term diff --git a/src/ragged/_import.py b/src/ragged/_import.py index 05c73ef3c4..9c568278f6 100644 --- a/src/ragged/_import.py +++ b/src/ragged/_import.py @@ -4,6 +4,20 @@ from typing import Any +import numpy as np + +from ._typing import Device + + +def device_namespace(device: None | Device = None) -> tuple[Device, Any]: + if device is None or device == "cpu": + return "cpu", np + elif device == "cuda": + return "cuda", cupy() + + msg = f"unrecognized device: {device!r}" # type: ignore[unreachable] + raise ValueError(msg) + def cupy() -> Any: try: diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index fd1966077a..b300fc5bc1 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -1144,5 +1144,7 @@ def _box( else: dtype = dtype_observed device = "cpu" if isinstance(output, np.ndarray) else "cuda" + if shape != (): + impl = ak.Array(impl) return cls._new(impl, shape, dtype, device) # pylint: disable=W0212 diff --git a/src/ragged/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py index c6895c01d6..189a787abf 100644 --- a/src/ragged/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -6,9 +6,14 @@ from __future__ import annotations +import enum + import awkward as ak +import numpy as np -from ._spec_array_object import array +from . import _import +from ._import import device_namespace +from ._spec_array_object import _box, array from ._typing import ( Device, Dtype, @@ -53,12 +58,8 @@ def arange( https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html """ - start # noqa: B018, pylint: disable=W0104 - stop # noqa: B018, pylint: disable=W0104 - step # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 35") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.arange(start, stop, step, dtype=dtype)) def asarray( @@ -137,10 +138,8 @@ def empty( https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 36") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.empty(shape, dtype=dtype)) def empty_like( @@ -197,12 +196,8 @@ def eye( https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html """ - n_rows # noqa: B018, pylint: disable=W0104 - n_cols # noqa: B018, pylint: disable=W0104 - k # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 38") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.eye(n_rows, n_cols, k, dtype=dtype)) def from_dlpack(x: object, /) -> array: @@ -218,8 +213,21 @@ def from_dlpack(x: object, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html """ - x # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 39") # noqa: EM101 + device_type, _ = x.__dlpack_device__() # type: ignore[attr-defined] + if ( + isinstance(device_type, enum.Enum) and device_type.value == 1 + ) or device_type == 1: + y = np.from_dlpack(x) + elif ( + isinstance(device_type, enum.Enum) and device_type.value == 2 + ) or device_type == 2: + cp = _import.cupy() + y = cp.from_dlpack(x) + else: + msg = f"unsupported __dlpack_device__ type: {device_type}" + raise TypeError(msg) + + return _box(array, y) def full( @@ -254,11 +262,8 @@ def full( https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html """ - shape # noqa: B018, pylint: disable=W0104 - fill_value # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 40") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.full(shape, fill_value, dtype=dtype)) def full_like( @@ -344,13 +349,10 @@ def linspace( https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html """ - start # noqa: B018, pylint: disable=W0104 - stop # noqa: B018, pylint: disable=W0104 - num # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - endpoint # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 42") # noqa: EM101 + device, ns = device_namespace(device) + return _box( + array, ns.linspace(start, stop, num=num, endpoint=endpoint, dtype=dtype) + ) def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: @@ -415,10 +417,8 @@ def ones( https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 44") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.ones(shape, dtype=dtype)) def ones_like( @@ -518,10 +518,8 @@ def zeros( https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 48") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.zeros(shape, dtype=dtype)) def zeros_like( diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py index fb61fdfc27..4106ef80dd 100644 --- a/src/ragged/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -13,6 +13,8 @@ from ._spec_array_object import array from ._typing import Dtype +_type = type + def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: """ @@ -56,9 +58,7 @@ def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html """ - from_ # noqa: B018, pylint: disable=W0104 - to # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 51") # noqa: EM101 + return bool(np.can_cast(from_, to)) @dataclass @@ -114,8 +114,16 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html """ - type # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 52") # noqa: EM101 + if not isinstance(type, np.dtype): + if not isinstance(type, _type) and hasattr(type, "dtype"): + out = np.finfo(type.dtype) + else: + out = np.finfo(np.dtype(type)) + else: + out = np.finfo(type) + return finfo_object( + out.bits, out.eps, out.max, out.min, out.smallest_normal, out.dtype + ) @dataclass @@ -155,8 +163,14 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html """ - type # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 53") # noqa: EM101 + if not isinstance(type, np.dtype): + if not isinstance(type, _type) and hasattr(type, "dtype"): + out = np.iinfo(type.dtype) + else: + out = np.iinfo(np.dtype(type)) + else: + out = np.iinfo(type) + return iinfo_object(out.bits, out.max, out.min, out.dtype) def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: @@ -218,5 +232,4 @@ def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html """ - arrays_and_dtypes # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 55") # noqa: EM101 + return np.result_type(*arrays_and_dtypes) diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index 45c1b57834..a1c08b67a4 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -6,8 +6,21 @@ from __future__ import annotations +import numpy as np +import pytest + import ragged +devices = ["cpu"] +ns = {"cpu": np} +try: + import cupy as cp + + devices.append("cuda") + ns["cuda"] = cp +except ModuleNotFoundError: + cp = None + def test_existence(): assert ragged.arange is not None @@ -26,3 +39,98 @@ def test_existence(): assert ragged.triu is not None assert ragged.zeros is not None assert ragged.zeros_like is not None + + +@pytest.mark.parametrize("device", devices) +def test_arange(device): + a = ragged.arange(5, 10, 2, device=device) + assert a.tolist() == [5, 7, 9] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_empty(device): + a = ragged.empty((2, 3, 5), device=device) + assert a.shape == (2, 3, 5) + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_empty_ndim0(device): + a = ragged.empty((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert isinstance(a._impl, ns[device].ndarray) + + +@pytest.mark.parametrize("device", devices) +def test_eye(device): + a = ragged.eye(3, 5, k=1, device=device) + assert a.tolist() == [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.skipif( + not hasattr(np, "from_dlpack"), reason=f"np.from_dlpack not in {np.__version__}" +) +@pytest.mark.parametrize("device", devices) +def test_from_dlpack(device): + a = ns[device].array([1, 2, 3, 4, 5]) + b = ragged.from_dlpack(a) + assert b.tolist() == [1, 2, 3, 4, 5] + assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_full(device): + a = ragged.full(5, 3, device=device) + assert a.tolist() == [3, 3, 3, 3, 3] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_full_ndim0(device): + a = ragged.full((), 3, device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 3 + assert isinstance(a._impl, ns[device].ndarray) + + +@pytest.mark.parametrize("device", devices) +def test_linspace(device): + a = ragged.linspace(5, 8, 5, device=device) + assert a.tolist() == [5, 5.75, 6.5, 7.25, 8] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_ones(device): + a = ragged.ones(5, device=device) + assert a.tolist() == [1, 1, 1, 1, 1] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_ones_ndim0(device): + a = ragged.ones((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 1 + assert isinstance(a._impl, ns[device].ndarray) + + +@pytest.mark.parametrize("device", devices) +def test_zeros(device): + a = ragged.zeros(5, device=device) + assert a.tolist() == [0, 0, 0, 0, 0] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_zeros_ndim0(device): + a = ragged.zeros((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 0 + assert isinstance(a._impl, ns[device].ndarray) diff --git a/tests/test_spec_data_type_functions.py b/tests/test_spec_data_type_functions.py index b31bc52b48..526181f9d9 100644 --- a/tests/test_spec_data_type_functions.py +++ b/tests/test_spec_data_type_functions.py @@ -6,6 +6,8 @@ from __future__ import annotations +import numpy as np + import ragged @@ -16,3 +18,55 @@ def test_existence(): assert ragged.iinfo is not None assert ragged.isdtype is not None assert ragged.result_type is not None + + +def test_can_cast(): + assert ragged.can_cast(np.float32, np.complex128) + assert not ragged.can_cast(np.complex128, np.float32) + + +def test_finfo(): + f = ragged.finfo(np.float64) + assert f.bits == 64 + assert f.eps == 2.220446049250313e-16 + assert f.max == 1.7976931348623157e308 + assert f.min == -1.7976931348623157e308 + assert f.smallest_normal == 2.2250738585072014e-308 + assert f.dtype == np.dtype(np.float64) + + +def test_finfo_array(): + f = ragged.finfo(np.array([1.1, 2.2, 3.3])) + assert f.bits == 64 + assert f.dtype == np.dtype(np.float64) + + +def test_finfo_array2(): + f = ragged.finfo(ragged.array([1.1, 2.2, 3.3])) + assert f.bits == 64 + assert f.dtype == np.dtype(np.float64) + + +def test_iinfo(): + f = ragged.iinfo(np.int16) + assert f.bits == 16 + assert f.max == 32767 + assert f.min == -32768 + assert f.dtype == np.dtype(np.int16) + + +def test_iinfo_array(): + f = ragged.iinfo(np.array([1, 2, 3], np.int16)) + assert f.bits == 16 + assert f.dtype == np.dtype(np.int16) + + +def test_iinfo_array2(): + f = ragged.iinfo(ragged.array([1, 2, 3], np.int16)) + assert f.bits == 16 + assert f.dtype == np.dtype(np.int16) + + +def test_result_type(): + dt = ragged.result_type(ragged.array([1, 2, 3]), ragged.array([1.1, 2.2, 3.3])) + assert dt == np.dtype(np.float64) diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 089443e4f5..383877866c 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -440,6 +440,10 @@ def test_ceil_int(device, x_int): assert xp.ceil(first(x_int)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_conj(device, x_complex): result = ragged.conj(x_complex.to_device(device)) @@ -623,6 +627,10 @@ def test_greater_equal_method(device, x, y): assert xp.greater_equal(first(x), first(y)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_imag(device, x_complex): result = ragged.imag(x_complex.to_device(device)) @@ -886,6 +894,10 @@ def test_pow_inplace_method(device, x, y): assert x.dtype == z.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_real(device, x_complex): result = ragged.real(x_complex.to_device(device)) @@ -932,6 +944,10 @@ def test_round(device, x): assert xp.round(first(x)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_round_complex(device, x_complex): result = ragged.round(x_complex.to_device(device))