Skip to content

Commit

Permalink
feat: adapted nonstatconvolve1d to jax
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Jun 23, 2024
1 parent a99ff5d commit 11c7e1d
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 18 deletions.
8 changes: 8 additions & 0 deletions docs/source/gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ Signal processing:
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :class:`pylops.basicoperators.NonStationaryConvolve1D`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|
* - :class:`pylops.basicoperators.NonStationaryFilters1D`
- |:white_check_mark:|
- |:white_check_mark:|
- |:white_check_mark:|


.. warning::
Expand Down
119 changes: 101 additions & 18 deletions pylops/signalprocessing/nonstatconvolve1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.backend import get_array_module
from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

Expand Down Expand Up @@ -147,7 +147,8 @@ def _interpolate_h(hs, ix, oh, dh, nh):

@reshaped(swapaxis=True)
def _matvec(self, x: NDArray) -> NDArray:
y = np.zeros_like(x)
ncp = get_array_module(x)
y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
Expand All @@ -158,14 +159,25 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
y[..., xextremes[0] : xextremes[1]] += (
x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
# y[..., xextremes[0] : xextremes[1]] += (
# x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
# )
sl = tuple(
[
slice(None, None),
]
* (len(self.dimsd) - 1)
+ [
slice(xextremes[0], xextremes[1]),
]
)
y = inplace_add(x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl)
return y

@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
y = np.zeros_like(x)
ncp = get_array_module(x)
y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
Expand All @@ -176,17 +188,37 @@ def _rmatvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
y[..., ix] = np.sum(
h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
axis=-1,
# y[..., ix] = ncp.sum(
# h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
# axis=-1,
# )
sl = tuple(
[
slice(None, None),
]
* (len(self.dimsd) - 1)
+ [
ix,
]
)
y = inplace_set(
ncp.sum(
h[hextremes[0] : hextremes[1]]
* x[..., xextremes[0] : xextremes[1]],
axis=-1,
),
y,
sl,
)

return y

def todense(self):
ncp = get_array_module(self.hsinterp[0])
hs = self.hsinterp
H = np.array(
H = ncp.array(
[
np.roll(np.pad(h, (0, self.dims[self.axis])), ix)
ncp.roll(ncp.pad(h, (0, self.dims[self.axis])), ix)
for ix, h in enumerate(hs)
]
)
Expand Down Expand Up @@ -317,18 +349,55 @@ def _interpolate_hadj(htmp, hs, hextremes, ix, oh, dh, nh):
"""find closest filters and spread weighted psf"""
ih_closest = int(np.floor((ix - oh) / dh))
if ih_closest < 0:
hs[0, hextremes[0] : hextremes[1]] += htmp
# hs[0, hextremes[0] : hextremes[1]] += htmp
sl = tuple(
[
0,
]
+ [
slice(hextremes[0], hextremes[1]),
]
)
hs = inplace_add(htmp, hs, sl)
elif ih_closest >= nh - 1:
hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
# hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
sl = tuple(
[
nh - 1,
]
+ [
slice(hextremes[0], hextremes[1]),
]
)
hs = inplace_add(htmp, hs, sl)
else:
dh_closest = (ix - oh) / dh - ih_closest
hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
# hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
sl = tuple(
[
ih_closest,
]
+ [
slice(hextremes[0], hextremes[1]),
]
)
hs = inplace_add((1 - dh_closest) * htmp, hs, sl)
# hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
sl = tuple(
[
ih_closest + 1,
]
+ [
slice(hextremes[0], hextremes[1]),
]
)
hs = inplace_add(dh_closest * htmp, hs, sl)
return hs

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
y = np.zeros(self.dimsd, dtype=self.dtype)
ncp = get_array_module(x)
y = ncp.zeros(self.dimsd, dtype=self.dtype)
for ix in range(self.dimsd[0]):
h = self._interpolate_h(x, ix, self.oh, self.dh, self.nh)
xextremes = (
Expand All @@ -339,14 +408,28 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dimsd[0] - ix)),
)
y[..., xextremes[0] : xextremes[1]] += (
self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
# y[..., xextremes[0] : xextremes[1]] += (
# self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
# )
sl = tuple(
[
slice(None, None),
]
* (len(self.dimsd) - 1)
+ [
slice(xextremes[0], xextremes[1]),
]
)
y = inplace_add(
self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl
)

return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
hs = np.zeros(self.dims, dtype=self.dtype)
ncp = get_array_module(x)
hs = ncp.zeros(self.dims, dtype=self.dtype)
for ix in range(self.dimsd[0]):
xextremes = (
max(0, ix - self.hsize // 2),
Expand Down

0 comments on commit 11c7e1d

Please sign in to comment.