diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 000000000..8bf716ed1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,97 @@ +name: Bug report +description: Submit a bug report +title: "[BUG] " +labels: TRIAGE +body: + - type: markdown + attributes: + value: "# Bug report" + - type: markdown + attributes: + value: Thank you for reporting a bug and helping us improve Cunumeric! + - type: markdown + attributes: + value: > + Please fill out all of the required information. + - type: markdown + attributes: + value: | + --- + ## Environment information + - type: textarea + id: legate_issue + attributes: + label: Software versions + description: >- + Run `legate-issue` and paste the output here. + placeholder: | + Python : 3.10.11 | packaged by conda-forge | (main, May 10 2023, 18:58:44) [GCC 11.3.0] + Platform : Linux-5.14.0-1042-oem-x86_64-with-glibc2.31 + Legion : v23.11.00.dev-16-g2499f878 + Legate : 23.11.00.dev+17.gb7b50313 + Cunumeric : (ImportError: cannot import name 'LogicalArray' from 'legate.core') + Numpy : 1.24.4 + Scipy : 1.10.1 + Numba : (not installed) + CTK package : cuda-version-11.8-h70ddcb2_2 (conda-forge) + GPU Driver : 515.65.01 + GPU Devices : + GPU 0: Quadro RTX 8000 + GPU 1: Quadro RTX 8000 + validations: + required: true + - type: input + id: jupyter + attributes: + label: Jupyter notebook / Jupyter Lab version + description: >- + Please supply if the issue you are reporting is related to Jupyter + notebook or Jupyter Lab. + validations: + required: false + - type: markdown + attributes: + value: | + ## Issue details + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: What did you expect to happen? + validations: + required: true + - type: textarea + id: observed-behavior + attributes: + label: Observed behavior + description: What did actually happen? + validations: + required: true + - type: markdown + attributes: + value: | + ## Directions to reproduce + - type: textarea + id: example + attributes: + label: Example code or instructions + description: > + Please provide detailed instructions to reproduce the issue. Ideally this includes a + [Complete, minimal, self-contained example code](https://stackoverflow.com/help/minimal-reproducible-example) + given here or as a link to code in another repository. + render: Python + validations: + required: true + - type: markdown + attributes: + value: | + ## Additional information + - type: textarea + id: traceback-console + attributes: + label: Stack traceback or browser console output + description: > + Add any error messages or logs that might be helpful in reproducing and + identifying the bug, for example a Python stack traceback. + validations: + required: false diff --git a/README.md b/README.md index 3ed163b57..7516331ff 100644 --- a/README.md +++ b/README.md @@ -40,15 +40,15 @@ If you have questions, please contact us at legate(at)nvidia.com. cuNumeric is available [on conda](https://anaconda.org/legate/cunumeric): ``` -conda install -c nvidia -c conda-forge -c legate cunumeric +mamba install -c nvidia -c conda-forge -c legate cunumeric ``` Only linux-64 packages are available at the moment. The default package contains GPU support, and is compatible with CUDA >= 11.8 (CUDA driver version >= r520), and Volta or later GPU architectures. There are -also CPU-only packages available, and will be automatically selected by `conda` -when installing on a machine without GPUs. 
+also CPU-only packages available, and will be automatically selected when +installing on a machine without GPUs. See the build instructions at https://nv-legate.github.io/cunumeric for details about building cuNumeric from source. @@ -119,7 +119,7 @@ with cuNumeric going forward: new features to cuNumeric. * We plan to add support for sharded file I/O for loading and storing large data sets that could never be loaded on a single node. - Initially this will begin with native support for [h5py](https://www.h5py.org/) + Initially this will begin with native support for hdf5 and zarr, but will grow to accommodate other formats needed by our lighthouse applications. * Strong scaling: while cuNumeric is currently implemented in a way that diff --git a/cmake/versions.json b/cmake/versions.json index b99da26cb..43d60fa5e 100644 --- a/cmake/versions.json +++ b/cmake/versions.json @@ -5,7 +5,7 @@ "git_url" : "https://github.com/nv-legate/legate.core.git", "git_shallow": false, "always_download": false, - "git_tag" : "8997f997be02936304b3ac23fe785f1de7a3424b" + "git_tag" : "6fa0acc9dcfa89be2702f1de6c045bc262f752b1" } } } diff --git a/continuous_integration/scripts/build-cunumeric-all b/continuous_integration/scripts/build-cunumeric-all index bcdbf62ec..66f5ccb6e 100755 --- a/continuous_integration/scripts/build-cunumeric-all +++ b/continuous_integration/scripts/build-cunumeric-all @@ -3,12 +3,6 @@ setup_env() { yaml_file=$(find ~/.artifacts -name "environment*.yaml" | head -n 1) - [ "${USE_CUDA:-}" = "ON" ] && - echo " - libcublas-dev" >> "${yaml_file}" && - echo " - libcufft-dev" >> "${yaml_file}" && - echo " - libcurand-dev" >> "${yaml_file}" && - echo " - libcusolver-dev" >> "${yaml_file}"; - echo "YAML file..." cat "${yaml_file}" diff --git a/cunumeric/__init__.py b/cunumeric/__init__.py index 13c8504b8..3ad86dd9f 100644 --- a/cunumeric/__init__.py +++ b/cunumeric/__init__.py @@ -28,7 +28,7 @@ import numpy as _np -from cunumeric import linalg, random, fft +from cunumeric import linalg, random, fft, ma from cunumeric.array import maybe_convert_to_np_ndarray, ndarray from cunumeric.bits import packbits, unpackbits from cunumeric.module import * diff --git a/cunumeric/array.py b/cunumeric/array.py index 0b7ce23c0..2767e1a0b 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -53,7 +53,13 @@ from .coverage import FALLBACK_WARNING, clone_class, is_implemented from .runtime import runtime from .types import NdShape -from .utils import deep_apply, dot_modes, to_core_dtype +from .utils import ( + calculate_volume, + deep_apply, + dot_modes, + to_core_dtype, + tuple_pop, +) if TYPE_CHECKING: from pathlib import Path @@ -159,7 +165,9 @@ def maybe_convert_to_np_ndarray(obj: Any) -> Any: """ Converts cuNumeric arrays into NumPy arrays, otherwise has no effect. 
""" - if isinstance(obj, ndarray): + from .ma import MaskedArray + + if isinstance(obj, (ndarray, MaskedArray)): return obj.__array__() return obj @@ -1664,8 +1672,6 @@ def __setitem__(self, key: Any, value: ndarray) -> None: """ check_writeable(self) - if key is None: - raise KeyError("invalid key passed to cunumeric.ndarray") if value.dtype != self.dtype: temp = ndarray(value.shape, dtype=self.dtype, inputs=(value,)) temp._thunk.convert(value._thunk) @@ -3086,12 +3092,54 @@ def _count_nonzero(self, axis: Any = None) -> Union[int, ndarray]: axis=axis, ) + def _summation_dtype( + self, dtype: Optional[np.dtype[Any]] + ) -> np.dtype[Any]: + # Pick our dtype if it wasn't picked yet + if dtype is None: + if self.dtype.kind != "f" and self.dtype.kind != "c": + return np.dtype(np.float64) + else: + return self.dtype + return dtype + + def _normalize_summation( + self, + sum_array: Any, + axis: Any, + dtype: np.dtype[Any], + ddof: int = 0, + keepdims: bool = False, + where: Union[ndarray, None] = None, + ) -> None: + if axis is None: + if where is not None: + divisor = where._count_nonzero() - ddof + else: + divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof + else: + if where is not None: + divisor = ( + where.sum(axis=axis, dtype=dtype, keepdims=keepdims) - ddof + ) + else: + divisor = self.shape[axis] - ddof + + # Divide by the number of things in the collapsed dimensions + # Pick the right kinds of division based on the dtype + if np.ndim(divisor) == 0: + divisor = np.array(divisor, dtype=sum_array.dtype) # type: ignore [assignment] # noqa + if dtype.kind == "f" or dtype.kind == "c": + sum_array.__itruediv__(divisor) + else: + sum_array.__ifloordiv__(divisor) + @add_boilerplate() def mean( self, axis: Any = None, - dtype: Union[np.dtype[Any], None] = None, - out: Union[ndarray, None] = None, + dtype: Optional[np.dtype[Any]] = None, + out: Optional[ndarray] = None, keepdims: bool = False, where: Union[ndarray, None] = None, ) -> ndarray: @@ -3113,16 +3161,12 @@ def mean( if axis is not None and not isinstance(axis, int): raise NotImplementedError( "cunumeric.mean only supports int types for " - "'axis' currently" + "`axis` currently" ) - # Pick our dtype if it wasn't picked yet - if dtype is None: - if self.dtype.kind != "f" and self.dtype.kind != "c": - dtype = np.dtype(np.float64) - else: - dtype = self.dtype + dtype = self._summation_dtype(dtype) where_array = broadcast_where(where, self.shape) + # Do the sum sum_array = ( self.sum( @@ -3138,28 +3182,10 @@ def mean( ) ) - if axis is None: - if where_array is not None: - divisor = where_array._count_nonzero() - else: - divisor = reduce(lambda x, y: x * y, self.shape, 1) - - else: - if where_array is not None: - divisor = where_array.sum( - axis=axis, dtype=dtype, keepdims=keepdims - ) - else: - divisor = self.shape[axis] + self._normalize_summation( + sum_array, axis, dtype, keepdims=keepdims, where=where_array + ) - # Divide by the number of things in the collapsed dimensions - # Pick the right kinds of division based on the dtype - if dtype.kind == "f" or dtype.kind == "c": - sum_array.__itruediv__( - divisor, - ) - else: - sum_array.__ifloordiv__(divisor) # Convert to the output we didn't already put it there if out is not None and sum_array is not out: assert out.dtype != sum_array.dtype @@ -3196,6 +3222,91 @@ def _nanmean( where=nan_mask, ) + @add_boilerplate() + def var( + self, + axis: Optional[Union[int, tuple[int, ...]]] = None, + dtype: Optional[np.dtype[Any]] = None, + out: Optional[ndarray] = None, + ddof: int = 0, + 
keepdims: bool = False, + *, + where: Union[ndarray, None] = None, + ) -> ndarray: + """a.var(axis=None, dtype=None, out=None, ddof=0, keepdims=False) + + Returns the variance of the array elements along the given axis. + + Refer to :func:`cunumeric.var` for full documentation. + + See Also + -------- + cunumeric.var : equivalent function + + Availability + -------- + Multiple GPUs, Multiple CPUs + + """ + if axis is not None and not isinstance(axis, int): + raise NotImplementedError( + "cunumeric.var only supports int types for `axis` currently" + ) + + # this could be computed as a single pass through the array + # by computing both <x> and <x^2> and then computing <x^2> - <x>^2. + # this would take the difference of two large numbers and is unstable + # the mean needs to be computed first and the variance computed + # directly as <(x-mu)^2>, which then requires two passes through the + # data to first compute the mean and then compute the variance + # see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + # TODO(https://github.com/nv-legate/cunumeric/issues/590) + + dtype = self._summation_dtype(dtype) + # calculate the mean, but keep the dimensions so that the + # mean can be broadcast against the original array + mu = self.mean(axis=axis, dtype=dtype, keepdims=True, where=where) + + # 1D arrays (or equivalent) should benefit from this unary reduction: + # + if axis is None or calculate_volume(tuple_pop(self.shape, axis)) == 1: + # this is a scalar reduction and we can optimize this as a single + # pass through a scalar reduction + result = self._perform_unary_reduction( + UnaryRedCode.VARIANCE, + self, + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + where=where, + args=(mu,), + ) + else: + # TODO(https://github.com/nv-legate/cunumeric/issues/591) + # there isn't really support for generic binary reductions + # right now all of the current binary reductions are boolean + # reductions like allclose. To implement this in a single pass would + # require a variant of einsum/dot that produces + # (self-mu)*(self-mu) rather than self*mu.
For now, we have to + # compute delta = self-mu in a first pass and then compute + # delta*delta in second pass + delta = self - mu + + result = self._perform_unary_reduction( + UnaryRedCode.SUM_SQUARES, + delta, + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + where=where, + ) + + self._normalize_summation(result, axis=axis, dtype=dtype, ddof=ddof) + + return result + @add_boilerplate() def min( self, diff --git a/cunumeric/config.py b/cunumeric/config.py index 6c5bbbb18..635544bd8 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -32,6 +32,7 @@ class _CunumericSharedLib: CUNUMERIC_ADVANCED_INDEXING: int CUNUMERIC_ARANGE: int CUNUMERIC_ARGWHERE: int + CUNUMERIC_BATCHED_CHOLESKY: int CUNUMERIC_BINARY_OP: int CUNUMERIC_BINARY_RED: int CUNUMERIC_BINCOUNT: int @@ -187,6 +188,8 @@ class _CunumericSharedLib: CUNUMERIC_RED_NANSUM: int CUNUMERIC_RED_PROD: int CUNUMERIC_RED_SUM: int + CUNUMERIC_RED_SUM_SQUARES: int + CUNUMERIC_RED_VARIANCE: int CUNUMERIC_REPEAT: int CUNUMERIC_SCALAR_UNARY_RED: int CUNUMERIC_SCAN_GLOBAL: int @@ -331,6 +334,7 @@ class CuNumericOpCode(IntEnum): ADVANCED_INDEXING = _cunumeric.CUNUMERIC_ADVANCED_INDEXING ARANGE = _cunumeric.CUNUMERIC_ARANGE ARGWHERE = _cunumeric.CUNUMERIC_ARGWHERE + BATCHED_CHOLESKY = _cunumeric.CUNUMERIC_BATCHED_CHOLESKY BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED BINCOUNT = _cunumeric.CUNUMERIC_BINCOUNT @@ -452,6 +456,8 @@ class UnaryRedCode(IntEnum): NANSUM = _cunumeric.CUNUMERIC_RED_NANSUM PROD = _cunumeric.CUNUMERIC_RED_PROD SUM = _cunumeric.CUNUMERIC_RED_SUM + SUM_SQUARES = _cunumeric.CUNUMERIC_RED_SUM_SQUARES + VARIANCE = _cunumeric.CUNUMERIC_RED_VARIANCE # Match these to CuNumericBinaryOpCode in cunumeric_c.h diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index ae5c2ae57..fb288a205 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -159,6 +159,8 @@ def __init__( _UNARY_RED_TO_REDUCTION_OPS: Dict[int, int] = { UnaryRedCode.SUM: ReductionOp.ADD, + UnaryRedCode.SUM_SQUARES: ReductionOp.ADD, + UnaryRedCode.VARIANCE: ReductionOp.ADD, UnaryRedCode.PROD: ReductionOp.MUL, UnaryRedCode.MAX: ReductionOp.MAX, UnaryRedCode.MIN: ReductionOp.MIN, @@ -209,6 +211,8 @@ def min_identity( _UNARY_RED_IDENTITIES: Dict[UnaryRedCode, Callable[[Any], Any]] = { UnaryRedCode.SUM: lambda _: 0, + UnaryRedCode.SUM_SQUARES: lambda _: 0, + UnaryRedCode.VARIANCE: lambda _: 0, UnaryRedCode.PROD: lambda _: 1, UnaryRedCode.MIN: min_identity, UnaryRedCode.MAX: max_identity, diff --git a/cunumeric/eager.py b/cunumeric/eager.py index e9cce7db4..03929395d 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -1526,6 +1526,26 @@ def unary_reduction( else where.array, **kws, ) + elif op == UnaryRedCode.SUM_SQUARES: + squared = np.square(rhs.array) + np.sum( + squared, + out=self.array, + axis=orig_axis, + where=where, + keepdims=keepdims, + ) + elif op == UnaryRedCode.VARIANCE: + (mu,) = args + centered = np.subtract(rhs.array, mu) + squares = np.square(centered) + np.sum( + squares, + axis=orig_axis, + where=where, + keepdims=keepdims, + out=self.array, + ) elif op == UnaryRedCode.CONTAINS: self.array.fill(args[0] in rhs.array) elif op == UnaryRedCode.COUNT_NONZERO: @@ -1597,7 +1617,7 @@ def where(self, rhs1: Any, rhs2: Any, rhs3: Any) -> None: if self.deferred is not None: self.deferred.where(rhs1, rhs2, rhs3) else: - self.array[:] = np.where(rhs1.array, rhs2.array, rhs3.array) + self.array[...] 
= np.where(rhs1.array, rhs2.array, rhs3.array) def argwhere(self) -> NumPyThunk: if self.deferred is not None: diff --git a/cunumeric/linalg/cholesky.py b/cunumeric/linalg/cholesky.py index 9bba03361..eed4c3188 100644 --- a/cunumeric/linalg/cholesky.py +++ b/cunumeric/linalg/cholesky.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -202,11 +202,47 @@ def tril(context: Context, p_output: StorePartition, n: int) -> None: task.execute() +def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None: + # the only feasible implementation for right now is that + # each cholesky submatrix fits on a single proc. We will have + # wildly varying memory available depending on the system. + # Just use a fixed cutoff to provide some sensible warning. + # TODO: find a better way to inform the user dims are too big + context: Context = output.context # type: ignore + task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY) + task.add_input(input.base) + task.add_output(output.base) + ndim = input.base.ndim + task.add_broadcast(input.base, (ndim - 2, ndim - 1)) + task.add_broadcast(output.base, (ndim - 2, ndim - 1)) + task.add_alignment(input.base, output.base) + task.throws_exception(LinAlgError) + task.execute() + + def cholesky( output: DeferredArray, input: DeferredArray, no_tril: bool ) -> None: runtime = output.runtime - context = output.context + context: Context = output.context + if len(input.base.shape) > 2: + if no_tril: + raise NotImplementedError( + "batched cholesky expects to only " + "produce the lower triangular matrix" + ) + size = input.base.shape[-1] + # Choose 32768 as dimension cutoff for warning + # so that for float64 anything larger than + # 8 GiB produces a warning + if size > 32768: + runtime.warn( + "batched cholesky is only valid" + " when the square submatrices fit" + f" on a single proc, n > {size} may be too large", + category=UserWarning, + ) + return _batched_cholesky(output, input) if runtime.num_procs == 1: transpose_copy_single(context, input.base, output.base) diff --git a/cunumeric/linalg/linalg.py b/cunumeric/linalg/linalg.py index f3f7eb9fb..d1c0498b2 100644 --- a/cunumeric/linalg/linalg.py +++ b/cunumeric/linalg/linalg.py @@ -82,10 +82,6 @@ def cholesky(a: ndarray) -> ndarray: elif shape[-1] != shape[-2]: raise ValueError("Last 2 dimensions of the array must be square") - if len(shape) > 2: - raise NotImplementedError( - "cuNumeric needs to support stacked 2d arrays" - ) return _cholesky(a) diff --git a/cunumeric/ma/__init__.py b/cunumeric/ma/__init__.py new file mode 100644 index 000000000..14a9e0d46 --- /dev/null +++ b/cunumeric/ma/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +import numpy.ma as _ma + +from cunumeric.array import maybe_convert_to_np_ndarray +from cunumeric.coverage import clone_module +from cunumeric.ma._masked_array import MaskedArray + +masked_array = MaskedArray + +clone_module(_ma, globals(), maybe_convert_to_np_ndarray) + +del maybe_convert_to_np_ndarray +del clone_module +del _ma diff --git a/cunumeric/ma/_masked_array.py b/cunumeric/ma/_masked_array.py new file mode 100644 index 000000000..4420bdf1c --- /dev/null +++ b/cunumeric/ma/_masked_array.py @@ -0,0 +1,88 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Type, Union + +if TYPE_CHECKING: + import numpy.typing as npt + from ..types import NdShape + + +import numpy as _np + +from ..array import maybe_convert_to_np_ndarray +from ..coverage import clone_class + +NDARRAY_INTERNAL = { + "__array_finalize__", + "__array_function__", + "__array_interface__", + "__array_prepare__", + "__array_priority__", + "__array_struct__", + "__array_ufunc__", + "__array_wrap__", +} + +MaskType = _np.bool_ +nomask = MaskType(0) + + +@clone_class(_np.ma.MaskedArray, NDARRAY_INTERNAL, maybe_convert_to_np_ndarray) +class MaskedArray: + _internal_ma: _np.ma.MaskedArray[Any, Any] + + def __new__(cls: Type[Any], *args: Any, **kw: Any) -> MaskedArray: + return super().__new__(cls) + + def __init__( + self, + data: Any = None, + mask: _np.bool_ = nomask, + dtype: Union[npt.DTypeLike, None] = None, + copy: bool = False, + subok: bool = True, + ndmin: int = 0, + fill_value: Any = None, + keep_mask: Any = True, + hard_mask: Any = None, + shrink: bool = True, + order: Union[str, None] = None, + ) -> None: + self._internal_ma = _np.ma.MaskedArray( # type: ignore + data=maybe_convert_to_np_ndarray(data), + mask=maybe_convert_to_np_ndarray(mask), + dtype=dtype, + copy=copy, + subok=subok, + ndmin=ndmin, + fill_value=fill_value, + keep_mask=keep_mask, + hard_mask=hard_mask, + shrink=shrink, + order=order, + ) + + def __array__(self, _dtype: Any = None) -> _np.ma.MaskedArray[Any, Any]: + return self._internal_ma + + @property + def size(self) -> int: + return self._internal_ma.size + + @property + def shape(self) -> NdShape: + return self._internal_ma.shape diff --git a/cunumeric/module.py b/cunumeric/module.py index 5d75476ea..d37ea4183 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -2763,8 +2763,8 @@ def flip(m: ndarray, axis: Optional[NdShapeLike] = None) -> ndarray: Returns ------- out : array_like - A view of `m` with the entries of axis reversed. Since a view is - returned, this operation is done in constant time. + A new array that is constructed from `m` with the entries of axis + reversed. 
See Also -------- @@ -2773,10 +2773,85 @@ def flip(m: ndarray, axis: Optional[NdShapeLike] = None) -> ndarray: Availability -------- Single GPU, Single CPU + + Notes + ----- + cuNumeric implementation doesn't return a view, it returns a new array """ return m.flip(axis=axis) +@add_boilerplate("m") +def flipud(m: ndarray) -> ndarray: + """ + Reverse the order of elements along axis 0 (up/down). + + For a 2-D array, this flips the entries in each column in the up/down + direction. Rows are preserved, but appear in a different order than before. + + Parameters + ---------- + m : array_like + Input array. + + Returns + ------- + out : array_like + A new array that is constructed from `m` with rows reversed. + + See Also + -------- + numpy.flipud + + Availability + -------- + Single GPU, Single CPU + + Notes + ----- + cuNumeric implementation doesn't return a view, it returns a new array + """ + if m.ndim < 1: + raise ValueError("Input must be >= 1-d.") + return flip(m, axis=0) + + +@add_boilerplate("m") +def fliplr(m: ndarray) -> ndarray: + """ + Reverse the order of elements along axis 1 (left/right). + + For a 2-D array, this flips the entries in each row in the left/right + direction. Columns are preserved, but appear in a different order than + before. + + Parameters + ---------- + m : array_like + Input array, must be at least 2-D. + + Returns + ------- + f : ndarray + A new array that is constructed from `m` with the columns reversed. + + See Also + -------- + numpy.fliplr + + Availability + -------- + Single GPU, Single CPU + + Notes + ----- + cuNumeric implementation doesn't return a view, it returns a new array + """ + if m.ndim < 2: + raise ValueError("Input must be >= 2-d.") + return flip(m, axis=1) + + ################### # Binary operations ################### @@ -4591,7 +4666,7 @@ def einsum( out: Optional[ndarray] = None, dtype: Optional[np.dtype[Any]] = None, casting: CastingKind = "safe", - optimize: Union[bool, str] = False, + optimize: Union[bool, Literal["greedy", "optimal"]] = True, ) -> ndarray: """ Evaluates the Einstein summation convention on the operands. @@ -4632,9 +4707,10 @@ def einsum( Default is 'safe'. optimize : ``{False, True, 'greedy', 'optimal'}``, optional - Controls if intermediate optimization should occur. No optimization - will occur if False. Uses opt_einsum to find an optimized contraction - plan if True. + Controls if intermediate optimization should occur. If False then + arrays will be contracted in input order, one at a time. True (the + default) will use the 'greedy' algorithm. See ``cunumeric.einsum_path`` + for more information on the available optimization algorithms. Returns ------- @@ -4658,7 +4734,9 @@ def einsum( if out is not None: out = convert_to_cunumeric_ndarray(out, share=True) - if not optimize: + if optimize is True: + optimize = "greedy" + elif optimize is False: optimize = NullOptimizer() # This call normalizes the expression (adds the output part if it's @@ -7121,6 +7199,79 @@ def nanmean( ) +@add_boilerplate("a") +def var( + a: ndarray, + axis: Optional[Union[int, tuple[int, ...]]] = None, + dtype: Optional[np.dtype[Any]] = None, + out: Optional[ndarray] = None, + ddof: int = 0, + keepdims: bool = False, + *, + where: Union[ndarray, None] = None, +) -> ndarray: + """ + Compute the variance along the specified axis. + + Returns the variance of the array elements, a measure of the spread of + a distribution. The variance is computed for the flattened array + by default, otherwise over the specified axis. 
+ + Parameters + ---------- + a : array_like + Array containing numbers whose variance is desired. If `a` is not an + array, a conversion is attempted. + axis : None or int or tuple[int], optional + Axis or axes along which the variance is computed. The default is to + compute the variance of the flattened array. + + If this is a tuple of ints, a variance is performed over multiple axes, + instead of a single axis or all the axes as before. + dtype : data-type, optional + Type to use in computing the variance. For arrays of integer type + the default is float64; for arrays of float types + it is the same as the array type. + out : ndarray, optional + Alternate output array in which to place the result. It must have the + same shape as the expected output, but the type is cast if necessary. + ddof : int, optional + “Delta Degrees of Freedom”: the divisor used in the calculation is + N - ddof, where N represents the number of elements. By default + ddof is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + where : array_like of bool, optional + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. + + Returns + ------- + m : ndarray, see dtype parameter above + If `out=None`, returns a new array of the same dtype as above + containing the variance values, otherwise a reference to the output + array is returned. + + See Also + -------- + numpy.var + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + return a.var( + axis=axis, + dtype=dtype, + out=out, + ddof=ddof, + keepdims=keepdims, + where=where, + ) + + # Histograms @@ -7182,7 +7333,7 @@ def bincount( raise ValueError("weights must be convertible to float64") # Make sure the weights are float64 weights = weights.astype(np.float64) - if x.dtype.kind != "i": + if not np.issubdtype(x.dtype, np.integer): raise TypeError("input array for bincount must be integer type") if minlength < 0: raise ValueError("'minlength' must not be negative") diff --git a/cunumeric/utils.py b/cunumeric/utils.py index 93e45fb74..8c2d70140 100644 --- a/cunumeric/utils.py +++ b/cunumeric/utils.py @@ -18,7 +18,7 @@ from functools import reduce from string import ascii_lowercase, ascii_uppercase from types import FrameType -from typing import Any, Callable, List, Sequence, Tuple, Union +from typing import Any, Callable, List, Sequence, Tuple, TypeVar, Union import legate.core.types as ty import numpy as np @@ -108,6 +108,13 @@ def calculate_volume(shape: NdShape) -> int: return reduce(lambda x, y: x * y, shape) +T = TypeVar("T") + + +def tuple_pop(tup: Tuple[T, ...], index: int) -> Tuple[T, ...]: + return tup[:index] + tup[index + 1 :] + + Modes = Tuple[List[str], List[str], List[str]] diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake index 4270962ba..f7feee620 100644 --- a/cunumeric_cpp.cmake +++ b/cunumeric_cpp.cmake @@ -143,6 +143,7 @@ list(APPEND cunumeric_SOURCES src/cunumeric/index/putmask.cc src/cunumeric/item/read.cc src/cunumeric/item/write.cc + src/cunumeric/matrix/batched_cholesky.cc src/cunumeric/matrix/contract.cc src/cunumeric/matrix/diag.cc src/cunumeric/matrix/gemm.cc @@ -195,6 +196,7 @@ if(Legion_USE_OpenMP) src/cunumeric/index/repeat_omp.cc src/cunumeric/index/wrap_omp.cc src/cunumeric/index/zip_omp.cc + src/cunumeric/matrix/batched_cholesky_omp.cc src/cunumeric/matrix/contract_omp.cc 
src/cunumeric/matrix/diag_omp.cc src/cunumeric/matrix/gemm_omp.cc @@ -245,6 +247,7 @@ if(Legion_USE_CUDA) src/cunumeric/index/putmask.cu src/cunumeric/item/read.cu src/cunumeric/item/write.cu + src/cunumeric/matrix/batched_cholesky.cu src/cunumeric/matrix/contract.cu src/cunumeric/matrix/diag.cu src/cunumeric/matrix/gemm.cu diff --git a/docs/cunumeric/source/api/manipulation.rst b/docs/cunumeric/source/api/manipulation.rst index 18ffe7a16..86010e721 100644 --- a/docs/cunumeric/source/api/manipulation.rst +++ b/docs/cunumeric/source/api/manipulation.rst @@ -103,3 +103,5 @@ Rearranging elements :toctree: generated/ flip + fliplr + flipud diff --git a/docs/cunumeric/source/api/ndarray.rst b/docs/cunumeric/source/api/ndarray.rst index afdd1406f..aca3b9ce0 100644 --- a/docs/cunumeric/source/api/ndarray.rst +++ b/docs/cunumeric/source/api/ndarray.rst @@ -158,7 +158,7 @@ Calculation ndarray.sum ndarray.cumsum ndarray.mean - .. ndarray.var + ndarray.var .. ndarray.std ndarray.prod ndarray.cumprod diff --git a/docs/cunumeric/source/api/statistics.rst b/docs/cunumeric/source/api/statistics.rst index ef3056da8..48f10f19c 100644 --- a/docs/cunumeric/source/api/statistics.rst +++ b/docs/cunumeric/source/api/statistics.rst @@ -11,6 +11,7 @@ Averages and variances mean nanmean + var Histograms diff --git a/docs/cunumeric/source/versions.rst b/docs/cunumeric/source/versions.rst index 4a21cc9ef..1760786d8 100644 --- a/docs/cunumeric/source/versions.rst +++ b/docs/cunumeric/source/versions.rst @@ -11,3 +11,4 @@ Versions 23.03 23.07 23.09 + 23.11 diff --git a/scripts/util/build-caching.sh b/scripts/util/build-caching.sh index 70de985d3..9fb4c1b4a 100755 --- a/scripts/util/build-caching.sh +++ b/scripts/util/build-caching.sh @@ -7,9 +7,9 @@ if [[ -n "$(which sccache)" ]]; then CMAKE_CUDA_COMPILER_LAUNCHER="${CMAKE_CUDA_COMPILER_LAUNCHER:-$(which sccache)}"; elif [[ -n "$(which ccache)" ]]; then # Use ccache if installed - CMAKE_C_COMPILER_LAUNCHER="${CMAKE_C_COMPILER_LAUNCHER:-$(which cache)}"; - CMAKE_CXX_COMPILER_LAUNCHER="${CMAKE_CXX_COMPILER_LAUNCHER:-$(which cache)}"; - CMAKE_CUDA_COMPILER_LAUNCHER="${CMAKE_CUDA_COMPILER_LAUNCHER:-$(which cache)}"; + CMAKE_C_COMPILER_LAUNCHER="${CMAKE_C_COMPILER_LAUNCHER:-$(which ccache)}"; + CMAKE_CXX_COMPILER_LAUNCHER="${CMAKE_CXX_COMPILER_LAUNCHER:-$(which ccache)}"; + CMAKE_CUDA_COMPILER_LAUNCHER="${CMAKE_CUDA_COMPILER_LAUNCHER:-$(which ccache)}"; fi export CMAKE_C_COMPILER_LAUNCHER="$CMAKE_C_COMPILER_LAUNCHER" diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 74c05fcd2..99d9bea19 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -29,6 +29,7 @@ enum CuNumericOpCode { CUNUMERIC_ADVANCED_INDEXING, CUNUMERIC_ARANGE, CUNUMERIC_ARGWHERE, + CUNUMERIC_BATCHED_CHOLESKY, CUNUMERIC_BINARY_OP, CUNUMERIC_BINARY_RED, CUNUMERIC_BINCOUNT, @@ -150,6 +151,8 @@ enum CuNumericUnaryRedCode { CUNUMERIC_RED_NANSUM, CUNUMERIC_RED_PROD, CUNUMERIC_RED_SUM, + CUNUMERIC_RED_SUM_SQUARES, + CUNUMERIC_RED_VARIANCE }; // Match these to BinaryOpCode in config.py diff --git a/src/cunumeric/mapper.cc b/src/cunumeric/mapper.cc index 247ded4fd..ba7114e45 100644 --- a/src/cunumeric/mapper.cc +++ b/src/cunumeric/mapper.cc @@ -145,6 +145,25 @@ std::vector CuNumericMapper::store_mappings( } return std::move(mappings); } + // CHANGE: If this code is changed, make sure all layouts are + // consistent with those assumed in batched_cholesky.cu, etc + case CUNUMERIC_BATCHED_CHOLESKY: { + std::vector mappings; + auto& inputs = task.inputs(); + auto& outputs 
= task.outputs(); + mappings.reserve(inputs.size() + outputs.size()); + for (auto& input : inputs) { + mappings.push_back(StoreMapping::default_mapping(input, options.front())); + mappings.back().policy.exact = true; + mappings.back().policy.ordering.set_c_order(); + } + for (auto& output : outputs) { + mappings.push_back(StoreMapping::default_mapping(output, options.front())); + mappings.back().policy.exact = true; + mappings.back().policy.ordering.set_c_order(); + } + return std::move(mappings); + } case CUNUMERIC_TRILU: { if (task.scalars().size() == 2) return {}; // If we're here, this task was the post-processing for Cholesky. diff --git a/src/cunumeric/matrix/batched_cholesky.cc b/src/cunumeric/matrix/batched_cholesky.cc new file mode 100644 index 000000000..30dbe3c53 --- /dev/null +++ b/src/cunumeric/matrix/batched_cholesky.cc @@ -0,0 +1,85 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/matrix/batched_cholesky.h" +#include "cunumeric/cunumeric.h" +#include "cunumeric/matrix/batched_cholesky_template.inl" + +#include +#include +#include + +namespace cunumeric { + +using namespace legate; + +template <> +void CopyBlockImpl::operator()(void* dst, const void* src, size_t size) +{ + ::memcpy(dst, src, size); +} + +template +struct BatchedTransposeImplBody { + using VAL = legate_type_of; + + static constexpr int tile_size = 64; + + void operator()(VAL* out, int n) const + { + VAL tile[tile_size][tile_size]; + int nblocks = (n + tile_size - 1) / tile_size; + + for (int rb = 0; rb < nblocks; ++rb) { + for (int cb = 0; cb < nblocks; ++cb) { + int r_start = rb * tile_size; + int r_stop = std::min(r_start + tile_size, n); + int c_start = cb * tile_size; + int c_stop = std::min(c_start + tile_size, n); + for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) { + for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) { + if (r <= c) { + tile[tr][tc] = out[r * n + c]; + } else { + tile[tr][tc] = 0; + } + } + } + for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) { + for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; } + } + } + } + } +}; + +/*static*/ void BatchedCholeskyTask::cpu_variant(TaskContext& context) +{ +#ifdef LEGATE_USE_OPENMP + openblas_set_num_threads(1); // make sure this isn't overzealous +#endif + batched_cholesky_task_context_dispatch(context); +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) +{ + BatchedCholeskyTask::register_variants(); +} +} // namespace + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/batched_cholesky.cu b/src/cunumeric/matrix/batched_cholesky.cu new file mode 100644 index 000000000..26fe3058f --- /dev/null +++ b/src/cunumeric/matrix/batched_cholesky.cu @@ -0,0 +1,111 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/matrix/batched_cholesky.h" +#include "cunumeric/matrix/potrf.h" +#include "cunumeric/matrix/batched_cholesky_template.inl" + +#include "cunumeric/cuda_help.h" + +namespace cunumeric { + +using namespace legate; + +#define TILE_DIM 32 +#define BLOCK_ROWS 8 + +template <> +void CopyBlockImpl::operator()(void* dst, const void* src, size_t size) +{ + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, get_cached_stream()); +} + +template +__global__ static void __launch_bounds__((TILE_DIM * BLOCK_ROWS), MIN_CTAS_PER_SM) + transpose_2d_lower(VAL* out, int n) +{ + __shared__ VAL tile[TILE_DIM][TILE_DIM + 1 /*avoid bank conflicts*/]; + + // The y dim is fast-moving index for coalescing + auto r_block = blockIdx.x * TILE_DIM; + auto c_block = blockIdx.y * TILE_DIM; + auto r = blockIdx.x * TILE_DIM + threadIdx.x; + auto c = blockIdx.y * TILE_DIM + threadIdx.y; + auto stride = BLOCK_ROWS; + // The tile coordinates + auto tr = threadIdx.x; + auto tc = threadIdx.y; + auto offset = r * n + c; + + // only execute across the upper diagonal + // a single thread block will store the upper diagonal block into + // a temp shared memory then set the block to zeros + if (c_block >= r_block) { +#pragma unroll + for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) { + if (r < n && (c + i) < n) { + if (r <= (c + i)) { + tile[tr][tc + i] = out[offset]; + // clear the upper diagonal entry + out[offset] = 0; + } else { + tile[tr][tc + i] = 0; + } + } + } + + // Make sure all the data is in shared memory + __syncthreads(); + + // Transpose the global coordinates, keep y the fast-moving index + r = blockIdx.y * TILE_DIM + threadIdx.x; + c = blockIdx.x * TILE_DIM + threadIdx.y; + offset = r * n + c; + +#pragma unroll + for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) { + if (r < n && (c + i) < n) { + if (r >= (c + i)) { out[offset] = tile[tc + i][tr]; } + } + } + } +} + +template +struct BatchedTransposeImplBody { + using VAL = legate_type_of; + + void operator()(VAL* out, int n) const + { + const dim3 blocks((n + TILE_DIM - 1) / TILE_DIM, (n + TILE_DIM - 1) / TILE_DIM, 1); + const dim3 threads(TILE_DIM, BLOCK_ROWS, 1); + + auto stream = get_cached_stream(); + + // CUDA Potrf produces the full matrix, we only want + // the lower diagonal + transpose_2d_lower<<>>(out, n); + + CHECK_CUDA_STREAM(stream); + } +}; + +/*static*/ void BatchedCholeskyTask::gpu_variant(TaskContext& context) +{ + batched_cholesky_task_context_dispatch(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/batched_cholesky.h b/src/cunumeric/matrix/batched_cholesky.h new file mode 100644 index 000000000..fceba2a9f --- /dev/null +++ b/src/cunumeric/matrix/batched_cholesky.h @@ -0,0 +1,38 @@ +/* Copyright 2021-2022 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "cunumeric/cunumeric.h" +#include "cunumeric/cunumeric_c.h" + +namespace cunumeric { + +class BatchedCholeskyTask : public CuNumericTask { + public: + static const int TASK_ID = CUNUMERIC_BATCHED_CHOLESKY; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +#ifdef LEGATE_USE_CUDA + static void gpu_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/batched_cholesky_omp.cc b/src/cunumeric/matrix/batched_cholesky_omp.cc new file mode 100644 index 000000000..84b311ff2 --- /dev/null +++ b/src/cunumeric/matrix/batched_cholesky_omp.cc @@ -0,0 +1,83 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "cunumeric/cunumeric.h" +#include "cunumeric/matrix/batched_cholesky.h" +#include "cunumeric/matrix/batched_cholesky_template.inl" + +#include +#include +#include + +namespace cunumeric { + +using namespace legate; + +template <> +void CopyBlockImpl::operator()(void* dst, const void* src, size_t n) +{ + ::memcpy(dst, src, n); +} + +template +struct BatchedTransposeImplBody { + using VAL = legate_type_of; + + static constexpr int tile_size = 64; + + void operator()(VAL* out, int n) const + { + int nblocks = (n + tile_size - 1) / tile_size; + +#pragma omp parallel for + for (int rb = 0; rb < nblocks; ++rb) { + // only loop the upper diagonal + // transpose the elements that are there and + // zero out the elements after reading them + for (int cb = rb; cb < nblocks; ++cb) { + VAL tile[tile_size][tile_size]; + int r_start = rb * tile_size; + int r_stop = std::min(r_start + tile_size, n); + int c_start = cb * tile_size; + int c_stop = std::min(c_start + tile_size, n); + + for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) { + for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) { + if (r <= c) { + auto offset = r * n + c; + tile[tr][tc] = out[offset]; + out[offset] = 0; + } else { + tile[tr][tc] = 0; + } + } + } + + for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) { + for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; } + } + } + } + } +}; + +/*static*/ void BatchedCholeskyTask::omp_variant(TaskContext& context) +{ + openblas_set_num_threads(omp_get_max_threads()); + batched_cholesky_task_context_dispatch(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/batched_cholesky_template.inl b/src/cunumeric/matrix/batched_cholesky_template.inl new file mode 100644 index 000000000..8d266e3f0 --- /dev/null +++ b/src/cunumeric/matrix/batched_cholesky_template.inl @@ -0,0 +1,147 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#pragma once + +// Useful for IDEs +#include +#include "cunumeric/cunumeric.h" +#include "cunumeric/matrix/batched_cholesky.h" +#include "cunumeric/matrix/potrf_template.inl" +#include "cunumeric/matrix/transpose_template.inl" +#include "cunumeric/pitches.h" + +namespace cunumeric { + +using namespace legate; + +template +struct BatchedCholeskyImplBody { + template + void operator()(T* array, int32_t m, int32_t n) + { + PotrfImplBody()(array, m, n); + } +}; + +template +struct CopyBlockImpl { + void operator()(void* dst, const void* src, size_t n); +}; + +template +struct BatchedTransposeImplBody { + using VAL = legate_type_of; + + void operator()(VAL* array, int32_t n); +}; + +template +struct _cholesky_supported { + static constexpr bool value = CODE == Type::Code::FLOAT64 || CODE == Type::Code::FLOAT32 || + CODE == Type::Code::COMPLEX64 || CODE == Type::Code::COMPLEX128; +}; + +template +struct BatchedCholeskyImpl { + template + void operator()(Array& input_array, Array& output_array) const + { + using VAL = legate_type_of; + + auto shape = input_array.shape(); + if (shape != output_array.shape()) { + throw legate::TaskException( + "Batched cholesky is not supported when input/output shapes differ"); + } + + Pitches pitches; + size_t volume = pitches.flatten(shape); + + if (volume == 0) return; + + auto ncols = shape.hi[DIM - 1] - shape.lo[DIM - 1] + 1; + + size_t in_strides[DIM]; + size_t out_strides[DIM]; + + auto input = input_array.read_accessor(shape).ptr(shape, in_strides); + if (in_strides[DIM - 2] != ncols || in_strides[DIM - 1] != 1) { + throw legate::TaskException( + "Bad input accessor in batched cholesky, last two dimensions must be non-transformed and " + "dense with stride == 1"); + } + + auto output = output_array.write_accessor(shape).ptr(shape, out_strides); + if (out_strides[DIM - 2] != ncols || out_strides[DIM - 1] != 1) { + throw legate::TaskException( + "Bad output accessor in batched cholesky, last two dimensions must be non-transformed and " + "dense with stride == 1"); + } + + if (shape.empty()) return; + + int num_blocks = 1; + for (int i = 0; i < (DIM - 2); ++i) { num_blocks *= (shape.hi[i] - shape.lo[i] + 1); } + + auto m = static_cast(shape.hi[DIM - 2] - shape.lo[DIM - 2] + 1); + auto n = static_cast(shape.hi[DIM - 1] - shape.lo[DIM - 1] + 1); + assert(m > 0 && n > 0); + + auto block_stride = m * n; + + for (int i = 0; i < num_blocks; ++i) { + if constexpr (_cholesky_supported::value) { + CopyBlockImpl()(output, input, sizeof(VAL) * block_stride); + PotrfImplBody()(output, m, n); + // Implicit assumption here about the cholesky code created. + // We assume the output has C layout, but each subblock + // will be generated in Fortran layout. Transpose the Fortran + // subblock into C layout. + // CHANGE: If this code is changed, please make sure all changes + // are consistent with those found in mapper.cc. 
+ BatchedTransposeImplBody()(output, n); + input += block_stride; + output += block_stride; + } + } + } +}; + +template +static void batched_cholesky_task_context_dispatch(TaskContext& context) +{ + auto& batched_input = context.inputs()[0]; + auto& batched_output = context.outputs()[0]; + if (batched_input.code() != batched_output.code()) { + throw legate::TaskException( + "batched cholesky is not yet supported when input/output types differ"); + } + if (batched_input.dim() != batched_output.dim()) { + throw legate::TaskException("input/output have different dims in batched cholesky"); + } + if (batched_input.dim() <= 2) { + throw legate::TaskException( + "internal error: batched cholesky input does not have more than 2 dims"); + } + double_dispatch(batched_input.dim(), + batched_input.code(), + BatchedCholeskyImpl{}, + batched_input, + batched_output); +} + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/potrf.cc b/src/cunumeric/matrix/potrf.cc index 02ae06246..46ed58b6a 100644 --- a/src/cunumeric/matrix/potrf.cc +++ b/src/cunumeric/matrix/potrf.cc @@ -25,48 +25,48 @@ namespace cunumeric { using namespace legate; template <> -struct PotrfImplBody { - void operator()(float* array, int32_t m, int32_t n) - { - char uplo = 'L'; - int32_t info = 0; - LAPACK_spotrf(&uplo, &n, array, &m, &info); - if (info != 0) throw legate::TaskException("Matrix is not positive definite"); - } -}; +void PotrfImplBody::operator()(float* array, + int32_t m, + int32_t n) +{ + char uplo = 'L'; + int32_t info = 0; + LAPACK_spotrf(&uplo, &n, array, &m, &info); + if (info != 0) throw legate::TaskException("Matrix is not positive definite"); +} template <> -struct PotrfImplBody { - void operator()(double* array, int32_t m, int32_t n) - { - char uplo = 'L'; - int32_t info = 0; - LAPACK_dpotrf(&uplo, &n, array, &m, &info); - if (info != 0) throw legate::TaskException("Matrix is not positive definite"); - } -}; +void PotrfImplBody::operator()(double* array, + int32_t m, + int32_t n) +{ + char uplo = 'L'; + int32_t info = 0; + LAPACK_dpotrf(&uplo, &n, array, &m, &info); + if (info != 0) throw legate::TaskException("Matrix is not positive definite"); +} template <> -struct PotrfImplBody { - void operator()(complex* array, int32_t m, int32_t n) - { - char uplo = 'L'; - int32_t info = 0; - LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info); - if (info != 0) throw legate::TaskException("Matrix is not positive definite"); - } -}; +void PotrfImplBody::operator()(complex* array, + int32_t m, + int32_t n) +{ + char uplo = 'L'; + int32_t info = 0; + LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info); + if (info != 0) throw legate::TaskException("Matrix is not positive definite"); +} template <> -struct PotrfImplBody { - void operator()(complex* array, int32_t m, int32_t n) - { - char uplo = 'L'; - int32_t info = 0; - LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info); - if (info != 0) throw legate::TaskException("Matrix is not positive definite"); - } -}; +void PotrfImplBody::operator()(complex* array, + int32_t m, + int32_t n) +{ + char uplo = 'L'; + int32_t info = 0; + LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info); + if (info != 0) throw legate::TaskException("Matrix is not positive definite"); +} /*static*/ void PotrfTask::cpu_variant(TaskContext& context) { diff --git a/src/cunumeric/matrix/potrf.cu b/src/cunumeric/matrix/potrf.cu index 68616525f..8f13a5168 100644 --- 
a/src/cunumeric/matrix/potrf.cu
+++ b/src/cunumeric/matrix/potrf.cu
@@ -49,41 +49,38 @@ static inline void potrf_template(
 }
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32> {
-  void operator()(float* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnSpotrf_bufferSize, cusolverDnSpotrf, array, m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32>::operator()(float* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  potrf_template(cusolverDnSpotrf_bufferSize, cusolverDnSpotrf, array, m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64> {
-  void operator()(double* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnDpotrf_bufferSize, cusolverDnDpotrf, array, m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64>::operator()(double* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  potrf_template(cusolverDnDpotrf_bufferSize, cusolverDnDpotrf, array, m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64> {
-  void operator()(complex<float>* array, int32_t m, int32_t n)
-  {
-    potrf_template(
-      cusolverDnCpotrf_bufferSize, cusolverDnCpotrf, reinterpret_cast<cuComplex*>(array), m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64>::operator()(complex<float>* array,
+                                                                        int32_t m,
+                                                                        int32_t n)
+{
+  potrf_template(
+    cusolverDnCpotrf_bufferSize, cusolverDnCpotrf, reinterpret_cast<cuComplex*>(array), m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128> {
-  void operator()(complex<double>* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnZpotrf_bufferSize,
-                   cusolverDnZpotrf,
-                   reinterpret_cast<cuDoubleComplex*>(array),
-                   m,
-                   n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128>::operator()(complex<double>* array,
+                                                                         int32_t m,
+                                                                         int32_t n)
+{
+  potrf_template(
+    cusolverDnZpotrf_bufferSize, cusolverDnZpotrf, reinterpret_cast<cuDoubleComplex*>(array), m, n);
+}
 
 /*static*/ void PotrfTask::gpu_variant(TaskContext& context)
 {
diff --git a/src/cunumeric/matrix/potrf_omp.cc b/src/cunumeric/matrix/potrf_omp.cc
index d26143a6f..36b32968d 100644
--- a/src/cunumeric/matrix/potrf_omp.cc
+++ b/src/cunumeric/matrix/potrf_omp.cc
@@ -26,48 +26,48 @@ namespace cunumeric {
 using namespace legate;
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT32> {
-  void operator()(float* array, int32_t m, int32_t n)
-  {
-    char uplo = 'L';
-    int32_t info = 0;
-    LAPACK_spotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT32>::operator()(float* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo = 'L';
+  int32_t info = 0;
+  LAPACK_spotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT64> {
-  void operator()(double* array, int32_t m, int32_t n)
-  {
-    char uplo = 'L';
-    int32_t info = 0;
-    LAPACK_dpotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT64>::operator()(double* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo = 'L';
+  int32_t info = 0;
+  LAPACK_dpotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX64> {
-  void operator()(complex<float>* array, int32_t m, int32_t n)
-  {
-    char uplo = 'L';
-    int32_t info = 0;
-    LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX64>::operator()(complex<float>* array,
+                                                                        int32_t m,
+                                                                        int32_t n)
+{
+  char uplo = 'L';
+  int32_t info = 0;
+  LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX128> {
-  void operator()(complex<double>* array, int32_t m, int32_t n)
-  {
-    char uplo = 'L';
-    int32_t info = 0;
-    LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX128>::operator()(complex<double>* array,
+                                                                         int32_t m,
+                                                                         int32_t n)
+{
+  char uplo = 'L';
+  int32_t info = 0;
+  LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 /*static*/ void PotrfTask::omp_variant(TaskContext& context)
 {
diff --git a/src/cunumeric/matrix/potrf_template.inl b/src/cunumeric/matrix/potrf_template.inl
index 55c782ad0..7e4252189 100644
--- a/src/cunumeric/matrix/potrf_template.inl
+++ b/src/cunumeric/matrix/potrf_template.inl
@@ -26,6 +26,26 @@ using namespace legate;
 
 template <VariantKind KIND, Type::Code CODE>
 struct PotrfImplBody;
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::FLOAT32> {
+  void operator()(float* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::FLOAT64> {
+  void operator()(double* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::COMPLEX64> {
+  void operator()(complex<float>* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::COMPLEX128> {
+  void operator()(complex<double>* array, int32_t m, int32_t n);
+};
+
 template <Type::Code CODE>
 struct support_potrf : std::false_type {};
 template <>
diff --git a/src/cunumeric/unary/scalar_unary_red_template.inl b/src/cunumeric/unary/scalar_unary_red_template.inl
index 110a73e7a..35173abeb 100644
--- a/src/cunumeric/unary/scalar_unary_red_template.inl
+++ b/src/cunumeric/unary/scalar_unary_red_template.inl
@@ -47,6 +47,7 @@ struct ScalarUnaryRed {
   Point<DIM> origin;
   Point<DIM> shape;
   RHS to_find;
+  RHS mu;
   bool dense;
   WHERE where;
   const bool* whereptr;
@@ -64,6 +65,7 @@ struct ScalarUnaryRed {
 
     out = args.out.reduce_accessor<LG_OP, true, 1>();
     if constexpr (OP_CODE == UnaryRedCode::CONTAINS) { to_find = args.args[0].scalar<RHS>(); }
+    if constexpr (OP_CODE == UnaryRedCode::VARIANCE) { mu = args.args[0].scalar<RHS>(); }
     if constexpr (HAS_WHERE) where = args.where.read_accessor<bool, DIM>(rect);
 
 #ifndef LEGATE_BOUNDS_CHECKS
@@ -90,6 +92,8 @@ struct ScalarUnaryRed {
                    OP_CODE == UnaryRedCode::NANARGMAX || OP_CODE == UnaryRedCode::NANARGMIN) {
       auto p = pitches.unflatten(idx, origin);
       if (mask) OP::template fold<true>(lhs, OP::convert(p, shape, identity, inptr[idx]));
+    } else if constexpr (OP_CODE == UnaryRedCode::VARIANCE) {
+      if (mask) OP::template fold<true>(lhs, OP::convert(inptr[idx] - mu, identity));
     } else {
       if (mask) OP::template fold<true>(lhs, OP::convert(inptr[idx], identity));
     }
@@ -106,6 +110,8 @@ struct ScalarUnaryRed {
     } else if constexpr (OP_CODE == UnaryRedCode::ARGMAX || OP_CODE == UnaryRedCode::ARGMIN ||
                          OP_CODE == UnaryRedCode::NANARGMAX || OP_CODE == UnaryRedCode::NANARGMIN) {
       if (mask) OP::template fold<true>(lhs, OP::convert(p, shape, identity, in[p]));
+    } else if constexpr (OP_CODE == UnaryRedCode::VARIANCE) {
+      if (mask) OP::template fold<true>(lhs, OP::convert(in[p] - mu, identity));
     } else {
       if (mask) OP::template fold<true>(lhs, OP::convert(in[p], identity));
     }
diff --git a/src/cunumeric/unary/unary_red_util.h b/src/cunumeric/unary/unary_red_util.h
index 34d92710b..e822e40b4 100644
--- a/src/cunumeric/unary/unary_red_util.h
+++ b/src/cunumeric/unary/unary_red_util.h
@@ -40,6 +40,8 @@ enum class UnaryRedCode : int {
   NANSUM = CUNUMERIC_RED_NANSUM,
   PROD = CUNUMERIC_RED_PROD,
   SUM = CUNUMERIC_RED_SUM,
+  SUM_SQUARES = CUNUMERIC_RED_SUM_SQUARES,
+  VARIANCE = CUNUMERIC_RED_VARIANCE
 };
 
 template <UnaryRedCode OP_CODE>
@@ -89,6 +91,10 @@ constexpr decltype(auto) op_dispatch(UnaryRedCode op_code, Functor f, Fnargs&&... args)
       return f.template operator()<UnaryRedCode::PROD>(std::forward<Fnargs>(args)...);
     case UnaryRedCode::SUM:
       return f.template operator()<UnaryRedCode::SUM>(std::forward<Fnargs>(args)...);
+    case UnaryRedCode::SUM_SQUARES:
+      return f.template operator()<UnaryRedCode::SUM_SQUARES>(std::forward<Fnargs>(args)...);
+    case UnaryRedCode::VARIANCE:
+      return f.template operator()<UnaryRedCode::VARIANCE>(std::forward<Fnargs>(args)...);
     default: break;
   }
   assert(false);
@@ -264,6 +270,52 @@ struct UnaryRedOp {
 
   __CUDA_HD__ static VAL convert(const RHS& rhs, const VAL) { return rhs; }
 };
+template <legate::Type::Code TYPE_CODE>
+struct UnaryRedOp<UnaryRedCode::SUM_SQUARES, TYPE_CODE> {
+  static constexpr bool valid = true;
+
+  using RHS = legate::legate_type_of<TYPE_CODE>;
+  using VAL = RHS;
+  using OP  = Legion::SumReduction<VAL>;
+
+  template <bool EXCLUSIVE>
+  __CUDA_HD__ static void fold(VAL& a, VAL b)
+  {
+    OP::template fold<EXCLUSIVE>(a, b);
+  }
+
+  template <int32_t DIM>
+  __CUDA_HD__ static VAL convert(const Legion::Point<DIM>&, int32_t, const VAL, const RHS& rhs)
+  {
+    return rhs * rhs;
+  }
+
+  __CUDA_HD__ static VAL convert(const RHS& rhs, const VAL) { return rhs * rhs; }
+};
+
+template <legate::Type::Code TYPE_CODE>
+struct UnaryRedOp<UnaryRedCode::VARIANCE, TYPE_CODE> {
+  static constexpr bool valid = true;
+
+  using RHS = legate::legate_type_of<TYPE_CODE>;
+  using VAL = RHS;
+  using OP  = Legion::SumReduction<VAL>;
+
+  template <bool EXCLUSIVE>
+  __CUDA_HD__ static void fold(VAL& a, VAL b)
+  {
+    OP::template fold<EXCLUSIVE>(a, b);
+  }
+
+  template <int32_t DIM>
+  __CUDA_HD__ static VAL convert(const Legion::Point<DIM>&, int32_t, const VAL, const RHS& rhs)
+  {
+    return rhs * rhs;
+  }
+
+  __CUDA_HD__ static VAL convert(const RHS& rhs, const VAL) { return rhs * rhs; }
+};
+
 template
 struct UnaryRedOp {
   static constexpr bool valid = !legate::is_complex::value;
diff --git a/tests/integration/test_cholesky.py b/tests/integration/test_cholesky.py
index 91edbaa7e..c4b52754b 100644
--- a/tests/integration/test_cholesky.py
+++ b/tests/integration/test_cholesky.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2022 NVIDIA Corporation
+# Copyright 2023 NVIDIA Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
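For orientation between the kernel-side changes above and the test updates that follow, here is a rough sketch of how this functionality is exercised from Python. It is illustrative only — the names, sizes, and tolerances are invented for this example rather than taken from the test suite — and it assumes the batched `num.linalg.cholesky` and the variance reduction behind `num.var` that this changeset adds:

```
import numpy as np

import cunumeric as num

n, batch = 4, 3
a = num.random.rand(n, n)
spd = a + a.T + num.eye(n) * n  # symmetric positive definite
stack = num.einsum("i,jk->ijk", num.arange(batch) + 1.0, spd)

chol = num.linalg.cholesky(stack)  # leading axis treated as a batch dimension
assert chol.shape == (batch, n, n)
assert np.allclose(chol[1] @ chol[1].T, stack[1].__array__())

x = num.random.random((100,)).astype("f")
assert np.allclose(num.var(x, ddof=1), np.var(x.__array__(), ddof=1), rtol=1e-4)
```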
@@ -35,12 +35,6 @@ def test_array_negative_1dim():
         num.linalg.cholesky(arr)
 
 
-def test_array_negative_3dim():
-    arr = num.random.randint(0, 9, size=(3, 3, 3))
-    with pytest.raises(NotImplementedError):
-        num.linalg.cholesky(arr)
-
-
 def test_array_negative():
     arr = num.random.randint(0, 9, size=(3, 2, 3))
     expected_exc = ValueError
@@ -56,10 +50,14 @@ def test_diagonal():
     assert allclose(b**2.0, a)
 
 
+def _get_real_symm_posdef(n):
+    a = num.random.rand(n, n)
+    return a + a.T + num.eye(n) * n
+
+
 @pytest.mark.parametrize("n", SIZES)
 def test_real(n):
-    a = num.random.rand(n, n)
-    b = a + a.T + num.eye(n) * n
+    b = _get_real_symm_posdef(n)
     c = num.linalg.cholesky(b)
     c_np = np.linalg.cholesky(b.__array__())
     assert allclose(c, c_np)
@@ -80,6 +78,45 @@ def test_complex(n):
     assert allclose(c, c_np)
 
 
+@pytest.mark.parametrize("n", SIZES)
+def test_batched_3d(n):
+    batch = 4
+    a = _get_real_symm_posdef(n)
+    np_a = a.__array__()
+    a_batched = num.einsum("i,jk->ijk", np.arange(batch) + 1, a)
+    test_c = num.linalg.cholesky(a_batched)
+    for i in range(batch):
+        correct = np.linalg.cholesky(np_a * (i + 1))
+        test = test_c[i, :]
+        assert allclose(correct, test)
+
+
+def test_batched_empty():
+    batch = 4
+    a = _get_real_symm_posdef(8)
+    a_batched = num.einsum("i,jk->ijk", np.arange(batch) + 1, a)
+    a_sliced = a_batched[0:0, :, :]
+    empty = num.linalg.cholesky(a_sliced)
+    assert empty.shape == a_sliced.shape
+
+
+@pytest.mark.parametrize("n", SIZES)
+def test_batched_4d(n):
+    batch = 2
+    a = _get_real_symm_posdef(n)
+    np_a = a.__array__()
+
+    outer = np.einsum("i,j->ij", np.arange(batch) + 1, np.arange(batch) + 1)
+
+    a_batched = num.einsum("ij,kl->ijkl", outer, a)
+    test_c = num.linalg.cholesky(a_batched)
+    for i in range(batch):
+        for j in range(batch):
+            correct = np.linalg.cholesky(np_a * (i + 1) * (j + 1))
+            test = test_c[i, j, :]
+            assert allclose(correct, test)
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py
index e482e8cf0..4fcdd2402 100644
--- a/tests/integration/test_einsum.py
+++ b/tests/integration/test_einsum.py
@@ -272,7 +272,7 @@ def test_cast(expr, dtype):
         False,
         "optimal",
         "greedy",
-        pytest.param(True, marks=pytest.mark.xfail),
+        True,
     ],
 )
 def test_optimize(optimize):
@@ -282,8 +282,6 @@ def test_optimize(optimize):
     np_res = np.einsum("ik,kj->ij", a, b, optimize=optimize)
     num_res = num.einsum("ik,kj->ij", a, b, optimize=optimize)
     assert allclose(np_res, num_res)
-    # when optimize=True, cunumeric raises
-    # TypeError: 'bool' object is not iterable
 
 
 def test_expr_opposite():
diff --git a/tests/integration/test_flip.py b/tests/integration/test_flip.py
index 587b8c3b5..e4032174a 100644
--- a/tests/integration/test_flip.py
+++ b/tests/integration/test_flip.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
+from legate.core import LEGATE_MAX_DIM
 
 import cunumeric as num
 
 
@@ -101,6 +102,66 @@ def test_axis_2d(self, axis):
         assert num.array_equal(b, bnp)
 
 
+class TestFlipud:
+    def test_empty_array(self):
+        anp = []
+        b = num.flipud(anp)
+        bnp = np.flipud(anp)
+        assert num.array_equal(b, bnp)
+
+    def test_basic(self):
+        anp = a.__array__()
+        b = num.flipud(a)
+        bnp = np.flipud(anp)
+        assert num.array_equal(b, bnp)
+
+    def test_wrong_dim(self):
+        anp = 4
+        msg = r"Input must be >= 1-d"
+        with pytest.raises(ValueError, match=msg):
+            num.flipud(anp)
+
+
+class TestFliplr:
+    def test_empty_array(self):
+        arr = num.random.random((1, 0, 1))
+        anp = arr.__array__()
+        b = num.fliplr(anp)
+        bnp = np.fliplr(anp)
+        assert num.array_equal(b, bnp)
+
+    def test_basic(self):
+        anp = a.__array__()
+        b = num.fliplr(a)
+        bnp = np.fliplr(anp)
+        assert num.array_equal(b, bnp)
+
+    def test_wrong_dim(self):
+        anp = []
+        msg = r"Input must be >= 2-d."
+        with pytest.raises(ValueError, match=msg):
+            num.fliplr(anp)
+
+
+FLIP_FUNCS = ("flip", "fliplr", "flipud")
+
+
+@pytest.mark.parametrize("func_name", FLIP_FUNCS)
+@pytest.mark.parametrize("ndim", range(2, LEGATE_MAX_DIM + 1))
+def test_max_dims(func_name, ndim):
+    func_np = getattr(np, func_name)
+    func_num = getattr(num, func_name)
+
+    shape = (5,) * ndim
+    a_np = np.random.random(shape)
+    a_num = num.array(a_np)
+
+    out_np = func_np(a_np)
+    out_num = func_num(a_num)
+
+    assert np.array_equal(out_num, out_np)
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/tests/integration/test_set_item.py b/tests/integration/test_set_item.py
index 2314916f5..8f9b4a1ac 100644
--- a/tests/integration/test_set_item.py
+++ b/tests/integration/test_set_item.py
@@ -29,6 +29,14 @@ def test_basic():
     assert x[2] == 3
 
 
+def test_newaxis():
+    arr = num.ones((4,))
+    arr[None] = 1
+    assert np.array_equal(arr, [1, 1, 1, 1])
+    arr[None, :] = 2
+    assert np.array_equal(arr, [2, 2, 2, 2])
+
+
 ARRAYS_4_3_2_1_0 = [
     4 - num.arange(5),
     4 - np.arange(5),
diff --git a/tests/integration/test_stats.py b/tests/integration/test_stats.py
new file mode 100644
index 000000000..dfa1b0fa3
--- /dev/null
+++ b/tests/integration/test_stats.py
@@ -0,0 +1,205 @@
+# Copyright 2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+
+import numpy as np
+import pytest
+from utils.comparisons import allclose
+
+import cunumeric as num
+
+np.random.seed(143)
+
+
+def check_result(in_np, out_np, out_num, **isclose_kwargs):
+    if in_np.dtype == "e" or out_np.dtype == "e":
+        # The mantissa is only 10 bits, 2**-10 ~= 10^(-3).
+        # Gives 1e-2 as rtol to provide extra rounding error.
+        f16_rtol = 1e-2
+        rtol = isclose_kwargs.setdefault("rtol", f16_rtol)
+        # make sure we aren't trying to fp16 compare with less precision
+        assert rtol >= f16_rtol
+
+    if "negative_test" in isclose_kwargs:
+        is_negative_test = isclose_kwargs["negative_test"]
+    else:
+        is_negative_test = False
+
+    result = (
+        allclose(out_np, out_num, **isclose_kwargs)
+        and out_np.dtype == out_num.dtype
+    )
+    if not result and not is_negative_test:
+        print("cunumeric failed the test")
+        print("Input:")
+        print(in_np)
+        print(f"dtype: {in_np.dtype}")
+        print("NumPy output:")
+        print(out_np)
+        print(f"dtype: {out_np.dtype}")
+        print("cuNumeric output:")
+        print(out_num)
+        print(f"dtype: {out_num.dtype}")
+    return result
+
+
+def check_op(op_np, op_num, in_np, out_dtype, **check_kwargs):
+    in_num = num.array(in_np)
+
+    out_np = op_np(in_np)
+    out_num = op_num(in_num)
+
+    assert check_result(in_np, out_np, out_num, **check_kwargs)
+
+    out_np = np.empty(out_np.shape, dtype=out_dtype)
+    out_num = num.empty(out_num.shape, dtype=out_dtype)
+
+    op_np(in_np, out=out_np)
+    op_num(in_num, out=out_num)
+
+    assert check_result(in_np, out_np, out_num, **check_kwargs)
+
+
+def get_op_input(
+    shape=(4, 5),
+    a_min=None,
+    a_max=None,
+    randint=False,
+    offset=None,
+    astype=None,
+    out_dtype="d",
+    replace_zero=None,
+    **check_kwargs,
+):
+    if randint:
+        assert a_min is not None
+        assert a_max is not None
+        in_np = np.random.randint(a_min, a_max, size=shape)
+    else:
+        in_np = np.random.randn(*shape)
+        if offset is not None:
+            in_np = in_np + offset
+        if a_min is not None:
+            in_np = np.maximum(a_min, in_np)
+        if a_max is not None:
+            in_np = np.minimum(a_max, in_np)
+        if astype is not None:
+            in_np = in_np.astype(astype)
+
+    if replace_zero is not None:
+        in_np[in_np == 0] = replace_zero
+
+    # converts to a scalar if shape is (1,)
+    if in_np.ndim == 1 and in_np.shape[0] == 1:
+        in_np = in_np[0]
+
+    return in_np
+
+
+dtypes = (
+    "e",
+    "f",
+    "d",
+)
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("ddof", [0, 1])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+@pytest.mark.parametrize("keepdims", [False, True])
+def test_var_default_shape(dtype, ddof, axis, keepdims):
+    np_in = get_op_input(astype=dtype)
+
+    op_np = functools.partial(np.var, ddof=ddof, axis=axis, keepdims=keepdims)
+    op_num = functools.partial(
+        num.var, ddof=ddof, axis=axis, keepdims=keepdims
+    )
+
+    check_op(op_np, op_num, np_in, dtype)
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("ddof", [0, 1])
+@pytest.mark.parametrize("axis", [None, 0, 1, 2])
+@pytest.mark.parametrize("shape", [(10,), (4, 5), (2, 3, 4)])
+def test_var_w_shape(dtype, ddof, axis, shape):
+    np_in = get_op_input(astype=dtype, shape=shape)
+
+    if axis is not None and axis >= len(shape):
+        axis = None
+
+    op_np = functools.partial(np.var, ddof=ddof, axis=axis)
+    op_num = functools.partial(num.var, ddof=ddof, axis=axis)
+
+    check_op(op_np, op_num, np_in, dtype)
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("ddof", [0, 1])
+@pytest.mark.parametrize(
+    "axis",
+    [
+        None,
+    ],
+)
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (10, 1),
+    ],
+)
+def test_var_corners(dtype, ddof, axis, shape):
+    np_in = get_op_input(astype=dtype, shape=shape)
+
+    if axis is not None and axis >= len(shape):
+        axis = None
+
+    op_np = functools.partial(np.var, ddof=ddof, axis=axis)
+    op_num = functools.partial(num.var, ddof=ddof, axis=axis)
+
+    check_op(op_np, op_num, np_in, dtype)
+
+
+@pytest.mark.xfail
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("ddof", [0, 1])
+@pytest.mark.parametrize(
+    "axis",
+    [
+        None,
+    ],
+)
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1,),
+    ],
+)
+def test_var_xfail(dtype, ddof, axis, shape):
+    np_in = get_op_input(astype=dtype, shape=shape)
+
+    op_np = functools.partial(np.var, ddof=ddof, axis=axis)
+    op_num = functools.partial(num.var, ddof=ddof, axis=axis)
+
+    check_op(op_np, op_num, np_in, dtype, negative_test=True)
+
+
+if __name__ == "__main__":
+    import sys
+
+    np.random.seed(12345)
+
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/unit/cunumeric/test_config.py b/tests/unit/cunumeric/test_config.py
index 5e85ccfde..6f8f43df5 100644
--- a/tests/unit/cunumeric/test_config.py
+++ b/tests/unit/cunumeric/test_config.py
@@ -117,6 +117,7 @@ def test_CuNumericOpCode() -> None:
         "ADVANCED_INDEXING",
         "ARANGE",
         "ARGWHERE",
+        "BATCHED_CHOLESKY",
         "BINARY_OP",
         "BINARY_RED",
         "BINCOUNT",