diff --git a/cunumeric/_sphinxext/_cunumeric_directive.py b/cunumeric/_sphinxext/_cunumeric_directive.py index ef6402f6c..71c33aa3a 100644 --- a/cunumeric/_sphinxext/_cunumeric_directive.py +++ b/cunumeric/_sphinxext/_cunumeric_directive.py @@ -22,10 +22,10 @@ class CunumericDirective(SphinxDirective): def parse(self, rst_text: str, annotation: str) -> list[nodes.Node]: - result = ViewList() + result = ViewList() # type: ignore for line in rst_text.split("\n"): result.append(line, annotation) node = nodes.paragraph() node.document = self.state.document - nested_parse_with_titles(self.state, result, node) + nested_parse_with_titles(self.state, result, node) # type: ignore return node.children diff --git a/cunumeric/module.py b/cunumeric/module.py index 424f89df4..37fbd03d3 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -14,6 +14,7 @@ # from __future__ import annotations +import collections.abc import math import operator import re @@ -1748,7 +1749,7 @@ def check_list_depth(arr: Any, prefix: NdShape = (0,)) -> int: "List depths are mismatched. First element was at depth " f"{first_depth}, but there is an element at" f" depth {other_depth}, " - f"arrays{convert_to_array_form(prefix+(idx+1,))}" + f"arrays{convert_to_array_form(prefix + (idx + 1,))}" ) return depths[0] + 1 @@ -6736,6 +6737,255 @@ def convolve(a: ndarray, v: ndarray, mode: ConvolveMode = "full") -> ndarray: return out +@add_boilerplate("f") +def gradient( + f: ndarray, *varargs: Any, axis: Any = None, edge_order: int = 1 +) -> Any: + """ + + Return the gradient of an N-dimensional array. + The gradient is computed using second order accurate central differences + in the interior points and either first or second order accurate one-sides + (forward or backwards) differences at the boundaries. + The returned gradient hence has the same shape as the input array. + + Parameters + ---------- + f : array_like + An N-dimensional array containing samples of a scalar function. + varargs : list of scalar or array, optional + Spacing between f values. Default unitary spacing for all dimensions. + Spacing can be specified using: + 1. single scalar to specify a sample distance for all dimensions. + 2. N scalars to specify a constant sample distance for each dimension. + i.e. `dx`, `dy`, `dz`, ... + 3. N arrays to specify the coordinates of the values along each + dimension of F. The length of the array must match the size of + the corresponding dimension + 4. Any combination of N scalars/arrays with the meaning of 2. and 3. + If `axis` is given, the number of varargs must equal the number of + axes. Default: 1. + edge_order : {1, 2}, optional + Gradient is calculated using N-th order accurate differences + at the boundaries. Default: 1. + .. versionadded:: 1.9.1 + axis : None or int or tuple of ints, optional + Gradient is calculated only along the given axis or axes + The default (axis = None) is to calculate the gradient for all the axes + of the input array. axis may be negative, in which case it counts from + the last to the first axis. + .. versionadded:: 1.11.0 + + Returns + ------- + gradient : ndarray or list of ndarray + A list of ndarrays (or a single ndarray if there is only one dimension) + corresponding to the derivatives of f with respect to each dimension. + Each derivative has the same shape as f. + + See Also + -------- + numpy.gradient + + Availability + -------- + Multiple GPUs, Multiple CPUs + + """ + N = f.ndim # number of dimensions + + if axis is None: + axes = tuple(range(N)) + elif isinstance(axis, collections.abc.Sequence): + axes = tuple(normalize_axis_index(a, N) for a in axis) + else: + axis = normalize_axis_index(axis, N) + axes = (axis,) + + len_axes = len(axes) + if not varargs: + n = 0 + else: + n = len(varargs) + + if n == 0: + # no spacing argument - use 1 in all axes + dx = [asarray(1.0)] * len_axes + elif n == 1 and np.ndim(varargs[0]) == 0: + # single scalar for all axes + dx = list(asarray(varargs)) * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(asarray(v) for v in varargs) + for i, distances in enumerate(dx): + if distances.ndim == 0: + continue + elif distances.ndim != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError( + "when 1d, distances must match " + "the length of the corresponding dimension" + ) + if np.issubdtype(distances.dtype, np.integer): + # Convert numpy integer types to float64 to avoid modular + # arithmetic in np.diff(distances). + distances = distances.astype(np.float64) + diffx = diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + if edge_order < 0: + raise ValueError(" invalid 'edge_order'") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype + if otype.type is np.datetime64: + raise TypeError("datetime64 is not supported by gradient in cuNumeric") + elif otype.type is np.timedelta64: + pass + elif np.issubdtype(otype, np.inexact): + pass + else: + # All other types convert to floating point. + # First check if f is a numpy integer type; if so, convert f to float64 + # to avoid modular arithmetic when computing the changes in f. + if np.issubdtype(otype, np.integer): + f = f.astype(np.float64) + otype = np.dtype(np.float64) + + for axis, ax_dx in zip(axes, dx): + if f.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical gradient, " + "at least (edge_order + 1) elements are required." + ) + # result allocation + out = empty_like(f, dtype=otype) + + # spacing for the current axis + uniform_spacing = np.ndim(ax_dx) == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / ( + 2.0 * ax_dx + ) + else: + dx1 = ax_dx[0:-1] + dx2 = ax_dx[1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = list(1 for i in range(N)) + shape[axis] = -1 + a = a.reshape(shape) + b = b.reshape(shape) + c = c.reshape(shape) + # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + + b * f[tuple(slice3)] + + c * f[tuple(slice4)] + ) + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 # type: ignore + slice2[axis] = 1 # type: ignore + slice3[axis] = 0 # type: ignore + dx_0 = ax_dx if uniform_spacing else ax_dx[0] + # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0 + + slice1[axis] = -1 # type: ignore + slice2[axis] = -1 # type: ignore + slice3[axis] = -2 # type: ignore + dx_n = ax_dx if uniform_spacing else ax_dx[-1] + # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 # type: ignore + slice2[axis] = 0 # type: ignore + slice3[axis] = 1 # type: ignore + slice4[axis] = 2 # type: ignore + if uniform_spacing: + a = -1.5 / ax_dx + b = 2.0 / ax_dx + c = -0.5 / ax_dx + else: + dx1 = ax_dx[0] + dx2 = ax_dx[1] + a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = -dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + + b * f[tuple(slice3)] + + c * f[tuple(slice4)] + ) + + slice1[axis] = -1 # type: ignore + slice2[axis] = -3 # type: ignore + slice3[axis] = -2 # type: ignore + slice4[axis] = -1 # type: ignore + if uniform_spacing: + a = 0.5 / ax_dx + b = -2.0 / ax_dx + c = 1.5 / ax_dx + else: + dx1 = ax_dx[-2] + dx2 = ax_dx[-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = -(dx2 + dx1) / (dx1 * dx2) + c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + + b * f[tuple(slice3)] + + c * f[tuple(slice4)] + ) + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals + + @add_boilerplate("a") def clip( a: ndarray, diff --git a/docs/cunumeric/source/api/math.rst b/docs/cunumeric/source/api/math.rst index 4dd574dd7..9ca504e81 100644 --- a/docs/cunumeric/source/api/math.rst +++ b/docs/cunumeric/source/api/math.rst @@ -174,3 +174,4 @@ Miscellaneous inner outer vdot + gradient diff --git a/install.py b/install.py index 92d16ad55..816ed5430 100755 --- a/install.py +++ b/install.py @@ -332,14 +332,13 @@ def validate_path(path): # Also use preexisting CMAKE_ARGS from conda if set cmake_flags = cmd_env.get("CMAKE_ARGS", "").split(" ") - if debug or verbose: cmake_flags += ["--log-level=%s" % ("DEBUG" if debug else "VERBOSE")] - + build_type = ( + "Debug" if debug else "RelWithDebInfo" if debug_release else "Release" + ) # noqa cmake_flags += f"""\ --DCMAKE_BUILD_TYPE={( - "Debug" if debug else "RelWithDebInfo" if debug_release else "Release" -)} +-DCMAKE_BUILD_TYPE=build_type -DBUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES={str(arch)} -DLegion_MAX_DIM={str(maxdim)} @@ -351,7 +350,7 @@ def validate_path(path): -DLegion_USE_LLVM={("ON" if llvm else "OFF")} -DLegion_NETWORKS={";".join(networks)} -DLegion_USE_HDF5={("ON" if hdf else "OFF")} -""".splitlines() +""".splitlines() # noqa if march: cmake_flags += [f"-DBUILD_MARCH={march}"] diff --git a/tests/integration/test_gradient.py b/tests/integration/test_gradient.py new file mode 100644 index 000000000..cbd8c5700 --- /dev/null +++ b/tests/integration/test_gradient.py @@ -0,0 +1,164 @@ +# Copyright 2024 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from math import prod + +import numpy as np +import pytest +from legate.core import LEGATE_MAX_DIM + +import cunumeric as cn + + +def test_gradient_with_scalar_dx(): + f_np = np.arange(1000, dtype=float) + f_cn = cn.array(f_np) + res_np = np.gradient(f_np) + res_cn = cn.gradient(f_cn) + assert np.allclose(res_np, res_cn) + + +def test_fradient_1d(): + a_np = np.array(np.random.random(size=1000), dtype=float) + f_np = np.sort(a_np) + f_cn = cn.array(f_np) + res_np = np.gradient(f_np) + res_cn = cn.gradient(f_cn) + assert np.allclose(res_np, res_cn) + + +@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) +@pytest.mark.parametrize("edge_order", [1, 2]) +def test_nd_arrays(ndim, edge_order): + shape = (5,) * ndim + size = prod(shape) + arr_np = np.random.randint(-100, 100, size, dtype=int) + in_np = np.sort(arr_np).reshape(shape).astype(float) + in_cn = cn.array(in_np) + + for a in range(0, ndim): + res_np = np.gradient(in_np, axis=a, edge_order=edge_order) + res_cn = np.gradient(in_cn, axis=a, edge_order=edge_order) + assert np.allclose(res_np, res_cn) + + +@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1)) +@pytest.mark.parametrize("varargs", [0.5, 1, 2, 0.3, 0]) +def test_scalar_varargs(ndim, varargs): + shape = (5,) * ndim + size = prod(shape) + arr_np = np.random.randint(-100, 100, size, dtype=int) + in_np = np.sort(arr_np).reshape(shape).astype(float) + in_cn = cn.array(in_np) + res_np = np.gradient(in_np, varargs) + res_cn = cn.gradient(in_cn, varargs) + assert np.allclose(res_np, res_cn, equal_nan=True) + + +@pytest.mark.parametrize("ndim", range(2, LEGATE_MAX_DIM + 1)) +def test_array_1d_varargs(ndim): + shape = (5,) * ndim + size = prod(shape) + arr_np = np.random.randint(-100, 100, size, dtype=int) + in_np = np.sort(arr_np).reshape(shape).astype(float) + in_cn = cn.array(in_np) + varargs = list(i * 0.5 for i in range(ndim)) + res_np = np.gradient(in_np, *varargs) + res_cn = cn.gradient(in_cn, *varargs) + assert np.allclose(res_np, res_cn) + + +@pytest.mark.parametrize("ndim", range(2, LEGATE_MAX_DIM + 1)) +def test_list_of_axes(ndim): + shape = (5,) * ndim + size = prod(shape) + arr_np = np.random.randint(-100, 100, size, dtype=int) + in_np = np.sort(arr_np).reshape(shape).astype(float) + in_cn = cn.array(in_np) + axes = tuple(i for i in range(ndim)) + res_np = np.gradient(in_np, axis=axes) + res_cn = cn.gradient(in_cn, axis=axes) + assert np.allclose(res_np, res_cn) + + +def test_varargs_coordinates(): + # Test gradient with varargs to specify coordinates + x_np = np.array([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + x_cn = cn.array(x_np) + y_coordinates_np = np.array( + [0.0, 1.0, 2.0] + ) # Coordinates along the first dimension + y_coordinates_cn = cn.array(y_coordinates_np) + x_coordinates_np = np.array( + [10.0, 20.0, 30.0] + ) # Coordinates along the second dimension + x_coordinates_cn = cn.array(x_coordinates_np) + res_np = np.gradient(x_np, y_coordinates_np, x_coordinates_np) + res_cn = np.gradient(x_cn, y_coordinates_cn, x_coordinates_cn) + assert np.allclose(res_np, res_cn) + + +def test_mixed_varargs(): + # Test gradient with varargs to specify coordinates + x_np = np.array([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + x_cn = cn.array(x_np) + y_coordinates_np = np.array( + [0.0, 1.0, 2.0] + ) # Coordinates along the first dimension + y_coordinates_cn = cn.array(y_coordinates_np) + res_np = np.gradient(x_np, y_coordinates_np, 0.5) + res_cn = np.gradient(x_cn, y_coordinates_cn, 0.5) + assert np.allclose(res_np, res_cn) + + +@pytest.mark.parametrize( + "in_arr", + [ + [], + [1], + ], +) +def test_corner_cases(in_arr): + in_cn = cn.array(in_arr) + with pytest.raises(ValueError): + cn.gradient(in_cn) # too small + + +@pytest.mark.parametrize("axis", [2, -4]) +def test_invalid_axis(axis): + # Test gradient with invalid axis + x = cn.array([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + with pytest.raises(ValueError): + cn.gradient(x, axis=axis) # Invalid axis for 2D array + + +@pytest.mark.parametrize("edge_order", [3, -1]) +def test_invalid_edge_order(edge_order): + # Test gradient with invalid edge_order + x = cn.array([1.0, 2.0, 4.0, 7.0]) + with pytest.raises(ValueError): + cn.gradient(x, edge_order=edge_order) # Invalid edge_order + + +def test_invalid_varargs(): + # Test gradient with varargs + x = cn.array([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + with pytest.raises(TypeError): + cn.gradient(x, 0.5, 1.0, 2.0) # Too many varargs provided + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(sys.argv)) diff --git a/tests/integration/utils/comparisons.py b/tests/integration/utils/comparisons.py index 65571b38c..9ab5247e9 100644 --- a/tests/integration/utils/comparisons.py +++ b/tests/integration/utils/comparisons.py @@ -50,7 +50,7 @@ def allclose( inds = islice(zip(*np.where(~close)), diff_limit) diffs = [f" index {i}: {a[i]} {b[i]}" for i in inds] N = len(diffs) - print(f"First {N} difference{'s' if N>1 else ''} for allclose:\n") + print(f"First {N} difference{'s' if N > 1 else ''} for allclose:\n") print("\n".join(diffs)) print(f"\nWith diff_limit={diff_limit}\n")