diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index c08b9aec..77467615 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -117,7 +117,7 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -143,7 +143,7 @@ def get_convolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -168,7 +168,7 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -193,7 +193,7 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -222,7 +222,7 @@ def get_correlate(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -247,7 +247,7 @@ def get_add_at(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -270,7 +270,7 @@ def get_sliding_window_view(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -293,7 +293,7 @@ def get_block_diag(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns @@ -506,21 +506,21 @@ def to_cupy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray: return y -def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: +def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: """Perform inplace set based on input Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`jax.Array` Array to sum - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array idx : :obj:`list` Indices to sum at Returns ------- - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array """ @@ -532,21 +532,21 @@ def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: return y -def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: +def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: """Perform inplace add based on input Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`jax.Array` Array to sum - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array idx : :obj:`list` Indices to sum at Returns ------- - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array """ @@ -558,21 +558,21 @@ def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: return y -def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: +def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: """Perform inplace multiplication based on input Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`jax.Array` Array to sum - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array idx : :obj:`list` Indices to multiply at Returns ------- - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array """ @@ -584,21 +584,21 @@ def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: return y -def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> Callable: +def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: """Perform inplace division based on input Parameters ---------- - x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`jax.Array` Array to sum - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array idx : :obj:`list` Indices to divide at Returns ------- - y : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + y : :obj:`numpy.ndarray` or :obj:`jax.Array` Output array """