Skip to content

Commit

Permalink
feat: enable jax in FirstDerivative and SecondDerivative
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Jun 23, 2024
1 parent 61b5807 commit cfb09e3
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 202 deletions.
26 changes: 23 additions & 3 deletions docs/source/gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ GPU with CuPy, and GPU with JAX.
* - Operator/method
- CPU
- GPU with CuPy
- GPU with JAX
- GPU/TPU with JAX
* - :meth:`pylops.LinearOperator.eigs`
- |:white_check_mark:|
- |:red_circle:|
Expand All @@ -62,7 +62,7 @@ Basic operators:
* - Operator/method
- CPU
- GPU with CuPy
- GPU with JAX
- GPU/TPU with JAX
* - :meth:`pylops.basicoperators.MatrixMult`
- |:white_check_mark:|
- |:white_check_mark:|
Expand Down Expand Up @@ -150,11 +150,31 @@ Smoothing and derivatives:
* - Operator/method
- CPU
- GPU with CuPy
- GPU with JAX
- GPU/TPU with JAX
* - :meth:`pylops.basicoperators.FirstDerivative`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :meth:`pylops.basicoperators.SecondDerivative`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :meth:`pylops.basicoperators.Laplacian`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :meth:`pylops.basicoperators.Gradient`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :meth:`pylops.basicoperators.FirstDirectionalDerivative`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :meth:`pylops.basicoperators.SecondDirectionalDerivative`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|


Example
Expand Down
239 changes: 61 additions & 178 deletions pylops/basicoperators/firstderivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,152 +100,33 @@ def __init__(
self.kind = kind
self.edge = edge
self.order = order
self.slice_1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(None, -1),
]
)
self.slice1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(1, None),
]
)
self.slice1_1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(1, -1),
]
)
self.slice1_3 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(1, -3),
]
)
self.slice_2 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(None, -2),
]
)
self.slice2 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(2, None),
]
)
self.slice2_2 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(2, -2),
]
)
self.slice3_1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(3, -1),
]
)
self.slice4 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(4, None),
]
)
self.slice_4 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(None, -4),
]
)

self.sample0 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(0, 1),
]
)
self.sample1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(1, 2),
]
)
self.sample2 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(2, 3),
]
)
self.sample3 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(3, 4),
]
)
self.sample_2 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(-2, -1),
]
)

self.sample_1 = tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(-1, None),
]
)
self.slice = {
i: {
j: tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
slice(i, j),
]
)
for j in (None, -1, -2, -3, -4)
}
for i in (None, 1, 2, 3, 4)
}
self.sample = {
i: tuple(
[
slice(None, None),
]
* (len(dims) - 1)
+ [
i,
]
)
for i in range(-3, 4)
}
self._register_multiplications(self.kind, self.order)

def _register_multiplications(
Expand Down Expand Up @@ -287,17 +168,19 @@ def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling
y = inplace_set((x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice_1)
y = inplace_set(
(x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[None][-1]
)
return y

@reshaped(swapaxis=True)
def _rmatvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., :-1] -= x[..., :-1]
y = inplace_add(-x[..., :-1], y, self.slice[None][-1])
# y[..., 1:] += x[..., :-1]
y = inplace_add(-x[..., :-1], y, self.slice_1)
y = inplace_add(x[..., :-1], y, self.slice1)
y = inplace_add(x[..., :-1], y, self.slice[1][None])
y /= self.sampling
return y

Expand All @@ -306,13 +189,12 @@ def _matvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2])
y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice1_1)

y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice[1][-1])
if self.edge:
# y[..., 0] = x[..., 1] - x[..., 0]
y = inplace_set(x[..., 1] - x[..., 0], y, self.sample0)
y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
# y[..., -1] = x[..., -1] - x[..., -2]
y = inplace_set(x[..., -1] - x[..., -2], y, self.sample_1)
y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y

Expand All @@ -321,18 +203,18 @@ def _rmatvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., :-2] -= 0.5 * x[..., 1:-1]
y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice_2)
y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice[None][-2])
# y[..., 2:] += 0.5 * x[..., 1:-1]
y = inplace_add(0.5 * x[..., 1:-1], y, self.slice2)
y = inplace_add(0.5 * x[..., 1:-1], y, self.slice[2][None])
if self.edge:
# y[..., 0] -= x[..., 0]
y = inplace_add(-x[..., 0], y, self.sample0)
y = inplace_add(-x[..., 0], y, self.sample[0])
# y[..., 1] += x[..., 0]
y = inplace_add(x[..., 0], y, self.sample1)
y = inplace_add(x[..., 0], y, self.sample[1])
# y[..., -2] -= x[..., -1]
y = inplace_add(-x[..., -1], y, self.sample_2)
y = inplace_add(-x[..., -1], y, self.sample[-2])
# y[..., -1] += x[..., -1]
y = inplace_add(x[..., -1], y, self.sample_1)
y = inplace_add(x[..., -1], y, self.sample[-1])
y /= self.sampling
return y

Expand All @@ -354,18 +236,17 @@ def _matvec_centered5(self, x: NDArray) -> NDArray:
- x[..., 4:] / 12.0
),
y,
self.slice2_2,
self.slice[2][-2],
)

if self.edge:
# y[..., 0] = x[..., 1] - x[..., 0]
y = inplace_set(x[..., 1] - x[..., 0], y, self.sample0)
y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
# y[..., 1] = 0.5 * (x[..., 2] - x[..., 0])
y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample1)
y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample[1])
# y[..., -2] = 0.5 * (x[..., -1] - x[..., -3])
y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample_2)
y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample[-2])
# y[..., -1] = x[..., -1] - x[..., -2]
y = inplace_set(x[..., -1] - x[..., -2], y, self.sample_1)
y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y

Expand All @@ -374,26 +255,26 @@ def _rmatvec_centered5(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., :-4] += x[..., 2:-2] / 12.0
y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice_4)
y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice[None][-4])
# y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0
y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice1_3)
y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice[1][-3])
# y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0
y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice3_1)
y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice[3][-1])
# y[..., 4:] -= x[..., 2:-2] / 12.0
y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice4)
y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice[4][None])
if self.edge:
# y[..., 0] -= x[..., 0] + 0.5 * x[..., 1]
y = inplace_add(-x[..., 0] + 0.5 * x[..., 1], y, self.sample0)
y = inplace_add(-(x[..., 0] + 0.5 * x[..., 1]), y, self.sample[0])
# y[..., 1] += x[..., 0]
y = inplace_add(x[..., 0], y, self.sample1)
y = inplace_add(x[..., 0], y, self.sample[1])
# y[..., 2] += 0.5 * x[..., 1]
y = inplace_add(0.5 * x[..., 1], y, self.sample2)
y = inplace_add(0.5 * x[..., 1], y, self.sample[2])
# y[..., -3] -= 0.5 * x[..., -2]
y = inplace_add(-0.5 * x[..., -2], y, self.sample_3)
y = inplace_add(-0.5 * x[..., -2], y, self.sample[-3])
# y[..., -2] -= x[..., -1]
y = inplace_add(-x[..., -1], y, self.sample_2)
y = inplace_add(-x[..., -1], y, self.sample[-2])
# y[..., -1] += 0.5 * x[..., -2] + x[..., -1]
y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample_1)
y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample[-1])
y /= self.sampling
return y

Expand All @@ -402,16 +283,18 @@ def _matvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling
y = inplace_set((x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice1)
y = inplace_set(
(x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[1][None]
)
return y

@reshaped(swapaxis=True)
def _rmatvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
# y[..., :-1] -= x[..., 1:]
y = inplace_add(-x[..., 1:], y, self.slice[None][-1])
# y[..., 1:] += x[..., 1:]
y = inplace_add(-x[..., 1:], y, self.slice_1)
y = inplace_add(x[..., 1:], y, self.slice1)
y = inplace_add(x[..., 1:], y, self.slice[1][None])
y /= self.sampling
return y
Loading

0 comments on commit cfb09e3

Please sign in to comment.