From cfb09e3f0a05ba277d7b9e44c8e539039326e1ad Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 23 Jun 2024 20:17:35 +0300 Subject: [PATCH] feat: enable jax in FirstDerivative and SecondDerivative --- docs/source/gpu.rst | 26 ++- pylops/basicoperators/firstderivative.py | 239 ++++++---------------- pylops/basicoperators/secondderivative.py | 97 +++++++-- 3 files changed, 160 insertions(+), 202 deletions(-) diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 44cc7b4e..f891a8b6 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -39,7 +39,7 @@ GPU with CuPy, and GPU with JAX. * - Operator/method - CPU - GPU with CuPy - - GPU with JAX + - GPU/TPU with JAX * - :meth:`pylops.LinearOperator.eigs` - |:white_check_mark:| - |:red_circle:| @@ -62,7 +62,7 @@ Basic operators: * - Operator/method - CPU - GPU with CuPy - - GPU with JAX + - GPU/TPU with JAX * - :meth:`pylops.basicoperators.MatrixMult` - |:white_check_mark:| - |:white_check_mark:| @@ -150,11 +150,31 @@ Smoothing and derivatives: * - Operator/method - CPU - GPU with CuPy - - GPU with JAX + - GPU/TPU with JAX * - :meth:`pylops.basicoperators.FirstDerivative` - |:white_check_mark:| - |:white_check_mark:| - |:white_check_mark:| + * - :meth:`pylops.basicoperators.SecondDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.basicoperators.Laplacian` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.basicoperators.Gradient` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.basicoperators.FirstDirectionalDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.basicoperators.SecondDirectionalDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| Example diff --git a/pylops/basicoperators/firstderivative.py b/pylops/basicoperators/firstderivative.py index 6ec821f3..41084867 100644 --- a/pylops/basicoperators/firstderivative.py +++ b/pylops/basicoperators/firstderivative.py @@ -100,152 +100,33 @@ def __init__( self.kind = kind self.edge = edge self.order = order - self.slice_1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(None, -1), - ] - ) - self.slice1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(1, None), - ] - ) - self.slice1_1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(1, -1), - ] - ) - self.slice1_3 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(1, -3), - ] - ) - self.slice_2 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(None, -2), - ] - ) - self.slice2 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(2, None), - ] - ) - self.slice2_2 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(2, -2), - ] - ) - self.slice3_1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(3, -1), - ] - ) - self.slice4 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(4, None), - ] - ) - self.slice_4 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(None, -4), - ] - ) - - self.sample0 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(0, 1), - ] - ) - self.sample1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(1, 2), - ] - ) - self.sample2 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(2, 3), - ] - ) - self.sample3 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(3, 4), - ] - ) - self.sample_2 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(-2, -1), - ] - ) - - self.sample_1 = tuple( - [ - slice(None, None), - ] - * (len(dims) - 1) - + [ - slice(-1, None), - ] - ) + self.slice = { + i: { + j: tuple( + [ + slice(None, None), + ] + * (len(dims) - 1) + + [ + slice(i, j), + ] + ) + for j in (None, -1, -2, -3, -4) + } + for i in (None, 1, 2, 3, 4) + } + self.sample = { + i: tuple( + [ + slice(None, None), + ] + * (len(dims) - 1) + + [ + i, + ] + ) + for i in range(-3, 4) + } self._register_multiplications(self.kind, self.order) def _register_multiplications( @@ -287,7 +168,9 @@ def _matvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling - y = inplace_set((x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice_1) + y = inplace_set( + (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[None][-1] + ) return y @reshaped(swapaxis=True) @@ -295,9 +178,9 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., :-1] -= x[..., :-1] + y = inplace_add(-x[..., :-1], y, self.slice[None][-1]) # y[..., 1:] += x[..., :-1] - y = inplace_add(-x[..., :-1], y, self.slice_1) - y = inplace_add(x[..., :-1], y, self.slice1) + y = inplace_add(x[..., :-1], y, self.slice[1][None]) y /= self.sampling return y @@ -306,13 +189,12 @@ def _matvec_centered3(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2]) - y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice1_1) - + y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice[1][-1]) if self.edge: # y[..., 0] = x[..., 1] - x[..., 0] - y = inplace_set(x[..., 1] - x[..., 0], y, self.sample0) + y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0]) # y[..., -1] = x[..., -1] - x[..., -2] - y = inplace_set(x[..., -1] - x[..., -2], y, self.sample_1) + y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1]) y /= self.sampling return y @@ -321,18 +203,18 @@ def _rmatvec_centered3(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., :-2] -= 0.5 * x[..., 1:-1] - y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice_2) + y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice[None][-2]) # y[..., 2:] += 0.5 * x[..., 1:-1] - y = inplace_add(0.5 * x[..., 1:-1], y, self.slice2) + y = inplace_add(0.5 * x[..., 1:-1], y, self.slice[2][None]) if self.edge: # y[..., 0] -= x[..., 0] - y = inplace_add(-x[..., 0], y, self.sample0) + y = inplace_add(-x[..., 0], y, self.sample[0]) # y[..., 1] += x[..., 0] - y = inplace_add(x[..., 0], y, self.sample1) + y = inplace_add(x[..., 0], y, self.sample[1]) # y[..., -2] -= x[..., -1] - y = inplace_add(-x[..., -1], y, self.sample_2) + y = inplace_add(-x[..., -1], y, self.sample[-2]) # y[..., -1] += x[..., -1] - y = inplace_add(x[..., -1], y, self.sample_1) + y = inplace_add(x[..., -1], y, self.sample[-1]) y /= self.sampling return y @@ -354,18 +236,17 @@ def _matvec_centered5(self, x: NDArray) -> NDArray: - x[..., 4:] / 12.0 ), y, - self.slice2_2, + self.slice[2][-2], ) - if self.edge: # y[..., 0] = x[..., 1] - x[..., 0] - y = inplace_set(x[..., 1] - x[..., 0], y, self.sample0) + y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0]) # y[..., 1] = 0.5 * (x[..., 2] - x[..., 0]) - y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample1) + y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample[1]) # y[..., -2] = 0.5 * (x[..., -1] - x[..., -3]) - y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample_2) + y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample[-2]) # y[..., -1] = x[..., -1] - x[..., -2] - y = inplace_set(x[..., -1] - x[..., -2], y, self.sample_1) + y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1]) y /= self.sampling return y @@ -374,26 +255,26 @@ def _rmatvec_centered5(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., :-4] += x[..., 2:-2] / 12.0 - y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice_4) + y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice[None][-4]) # y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0 - y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice1_3) + y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice[1][-3]) # y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0 - y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice3_1) + y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice[3][-1]) # y[..., 4:] -= x[..., 2:-2] / 12.0 - y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice4) + y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice[4][None]) if self.edge: # y[..., 0] -= x[..., 0] + 0.5 * x[..., 1] - y = inplace_add(-x[..., 0] + 0.5 * x[..., 1], y, self.sample0) + y = inplace_add(-(x[..., 0] + 0.5 * x[..., 1]), y, self.sample[0]) # y[..., 1] += x[..., 0] - y = inplace_add(x[..., 0], y, self.sample1) + y = inplace_add(x[..., 0], y, self.sample[1]) # y[..., 2] += 0.5 * x[..., 1] - y = inplace_add(0.5 * x[..., 1], y, self.sample2) + y = inplace_add(0.5 * x[..., 1], y, self.sample[2]) # y[..., -3] -= 0.5 * x[..., -2] - y = inplace_add(-0.5 * x[..., -2], y, self.sample_3) + y = inplace_add(-0.5 * x[..., -2], y, self.sample[-3]) # y[..., -2] -= x[..., -1] - y = inplace_add(-x[..., -1], y, self.sample_2) + y = inplace_add(-x[..., -1], y, self.sample[-2]) # y[..., -1] += 0.5 * x[..., -2] + x[..., -1] - y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample_1) + y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample[-1]) y /= self.sampling return y @@ -402,7 +283,9 @@ def _matvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling - y = inplace_set((x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice1) + y = inplace_set( + (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[1][None] + ) return y @reshaped(swapaxis=True) @@ -410,8 +293,8 @@ def _rmatvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) # y[..., :-1] -= x[..., 1:] + y = inplace_add(-x[..., 1:], y, self.slice[None][-1]) # y[..., 1:] += x[..., 1:] - y = inplace_add(-x[..., 1:], y, self.slice_1) - y = inplace_add(x[..., 1:], y, self.slice1) + y = inplace_add(x[..., 1:], y, self.slice[1][None]) y /= self.sampling return y diff --git a/pylops/basicoperators/secondderivative.py b/pylops/basicoperators/secondderivative.py index 744d067a..b541c676 100644 --- a/pylops/basicoperators/secondderivative.py +++ b/pylops/basicoperators/secondderivative.py @@ -7,7 +7,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -90,6 +90,33 @@ def __init__( self.sampling = sampling self.kind = kind self.edge = edge + self.slice = { + i: { + j: tuple( + [ + slice(None, None), + ] + * (len(dims) - 1) + + [ + slice(i, j), + ] + ) + for j in (None, -1, -2, -3, -4) + } + for i in (None, 1, 2, 3, 4) + } + self.sample = { + i: tuple( + [ + slice(None, None), + ] + * (len(dims) - 1) + + [ + i, + ] + ) + for i in range(-3, 4) + } self._register_multiplications(self.kind) def _register_multiplications( @@ -123,7 +150,10 @@ def _rmatvec(self, x: NDArray) -> NDArray: def _matvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[None][-2] + ) y /= self.sampling**2 return y @@ -131,9 +161,12 @@ def _matvec_forward(self, x: NDArray) -> NDArray: def _rmatvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., :-2] - y[..., 1:-1] -= 2 * x[..., :-2] - y[..., 2:] += x[..., :-2] + # y[..., :-2] += x[..., :-2] + y = inplace_add(x[..., :-2], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., :-2] + y = inplace_add(-2 * x[..., :-2], y, self.slice[1][-1]) + # y[..., 2:] += x[..., :-2] + y = inplace_add(x[..., :-2], y, self.slice[2][None]) y /= self.sampling**2 return y @@ -141,10 +174,17 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray: def _matvec_centered(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[1][-1] + ) if self.edge: - y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2] - y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1] + # y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2] + y = inplace_set(x[..., 0] - 2 * x[..., 1] + x[..., 2], y, self.sample[0]) + # y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1] + y = inplace_set( + x[..., -3] - 2 * x[..., -2] + x[..., -1], y, self.sample[-1] + ) y /= self.sampling**2 return y @@ -152,16 +192,25 @@ def _matvec_centered(self, x: NDArray) -> NDArray: def _rmatvec_centered(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., 1:-1] - y[..., 1:-1] -= 2 * x[..., 1:-1] - y[..., 2:] += x[..., 1:-1] + # y[..., :-2] += x[..., 1:-1] + y = inplace_add(x[..., 1:-1], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., 1:-1] + y = inplace_add(-2 * x[..., 1:-1], y, self.slice[1][-1]) + # y[..., 2:] += x[..., 1:-1] + y = inplace_add(x[..., 1:-1], y, self.slice[2][None]) if self.edge: - y[..., 0] += x[..., 0] - y[..., 1] -= 2 * x[..., 0] - y[..., 2] += x[..., 0] - y[..., -3] += x[..., -1] - y[..., -2] -= 2 * x[..., -1] - y[..., -1] += x[..., -1] + # y[..., 0] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[0]) + # y[..., 1] -= 2 * x[..., 0] + y = inplace_add(-2 * x[..., 0], y, self.sample[1]) + # y[..., 2] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[2]) + # y[..., -3] += x[..., -1] + y = inplace_add(x[..., -1], y, self.sample[-3]) + # y[..., -2] -= 2 * x[..., -1] + y = inplace_add(-2 * x[..., -1], y, self.sample[-2]) + # y[..., -1] += x[..., -1] + y = inplace_add(x[..., -1], y, self.sample[-1]) y /= self.sampling**2 return y @@ -169,7 +218,10 @@ def _rmatvec_centered(self, x: NDArray) -> NDArray: def _matvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[2][None] + ) y /= self.sampling**2 return y @@ -177,8 +229,11 @@ def _matvec_backward(self, x: NDArray) -> NDArray: def _rmatvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., 2:] - y[..., 1:-1] -= 2 * x[..., 2:] - y[..., 2:] += x[..., 2:] + # y[..., :-2] += x[..., 2:] + y = inplace_add(x[..., 2:], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., 2:] + y = inplace_add(-2 * x[..., 2:], y, self.slice[1][-1]) + # y[..., 2:] += x[..., 2:] + y = inplace_add(x[..., 2:], y, self.slice[2][None]) y /= self.sampling**2 return y