feat: enabled jax in waveeqprocessing
mrava87 committed Jul 3, 2024
1 parent 6067fa2 commit ca81494
Showing 3 changed files with 71 additions and 18 deletions.
33 changes: 22 additions & 11 deletions pylops/waveeqprocessing/blending.py
@@ -9,7 +9,7 @@
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, HStack, Pad
from pylops.signalprocessing import Shift
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray

@@ -76,12 +76,7 @@ def __init__(
self.dt = dt
self.times = times
self.shiftall = shiftall
-if np.max(self.times) // dt == np.max(self.times) / dt:
-    # do not add extra sample as no shift will be applied
-    self.nttot = int(np.max(self.times) / self.dt + self.nt)
-else:
-    # add 1 extra sample at the end
-    self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
+self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
if not self.shiftall:
# original implementation, where each source is shifted independently
self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype)
@@ -143,15 +138,23 @@ def _matvec_smallrecs(self, x: NDArray) -> NDArray:
self.ns, self.nr, self.nt + 1
)
for i, shift_int in enumerate(self.shifts):
-blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data[i]
+blended_data = inplace_add(
+    shifted_data[i],
+    blended_data,
+    (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+)
return blended_data
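
Note on the change above: JAX arrays are immutable, so an in-place slice update such as ``blended_data[sl] += x`` is not allowed and the functional update API has to be used instead; ``inplace_add`` hides this difference behind a single call. A minimal standalone illustration of the two idioms (plain NumPy/JAX, not PyLops code):

    import numpy as np
    import jax.numpy as jnp

    sl = (slice(None), slice(2, 5))  # index tuple, like the one passed to inplace_add

    # NumPy: the buffer is mutated in place
    y_np = np.zeros((3, 8))
    y_np[sl] += np.ones((3, 3))

    # JAX: arrays are immutable, so the update returns a new array
    y_jx = jnp.zeros((3, 8))
    y_jx = y_jx.at[sl].add(jnp.ones((3, 3)))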

@reshaped
def _rmatvec_smallrecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
shifted_data = ncp.zeros((self.ns, self.nr, self.nt + 1), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
-shifted_data[i, :, :] = x[:, shift_int : shift_int + self.nt + 1]
+shifted_data = inplace_set(
+    x[:, shift_int : shift_int + self.nt + 1],
+    shifted_data,
+    (i, slice(None, None), slice(None, None)),
+)
deblended_data = self.PadOp._rmatvec(
self.ShiftOp._rmatvec(shifted_data.ravel())
).reshape(self.dims)
@@ -170,7 +173,11 @@ def _matvec_largerecs(self, x: NDArray) -> NDArray:
.matvec(self.PadOp.matvec(x[i, :, :].ravel()))
.reshape(self.ShiftOps[i].dimsd)
)
-blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data
+blended_data = inplace_add(
+    shifted_data,
+    blended_data,
+    (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+)
return blended_data

@reshaped
@@ -186,7 +193,11 @@ def _rmatvec_largerecs(self, x: NDArray) -> NDArray:
x[:, shift_int : shift_int + self.nt + 1].ravel()
)
).reshape(self.PadOp.dims)
-deblended_data[i, :, :] = shifted_data
+deblended_data = inplace_set(
+    shifted_data,
+    deblended_data,
+    (i, slice(None, None), slice(None, None)),
+)
return deblended_data

def _register_multiplications(self) -> None:
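
For reference, a backend-dispatching helper of the kind imported above can be sketched as below. This is a simplified stand-in, not necessarily identical to pylops.utils.backend.inplace_add / inplace_set, which also have to cover CuPy arrays:

    def inplace_add_sketch(x, y, idx):
        """Add x into y[idx] for both mutable (NumPy/CuPy) and immutable (JAX) arrays."""
        if hasattr(y, "at"):
            # JAX arrays expose the functional .at[] update API and must not be mutated
            return y.at[idx].add(x)
        y[idx] += x
        return y

    def inplace_set_sketch(x, y, idx):
        """Set y[idx] = x in the same backend-agnostic fashion."""
        if hasattr(y, "at"):
            return y.at[idx].set(x)
        y[idx] = x
        return y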
39 changes: 33 additions & 6 deletions pylops/waveeqprocessing/kirchhoff.py
@@ -288,8 +288,11 @@ def __init__(
)
self.rix = np.tile((recs[0] - x[0]) // dx, (ns, 1)).astype(int).ravel()
elif self.ndims == 3:
-# TODO: 3D normalized distances
-raise NotImplementedError("dynamic=True currently not available in 3D")
+# TODO: compute 3D indices for aperture filter
+# currently no aperture filter in 3D... just set indices to 0
+# so that the check always passes
+self.six = np.zeros(nr * ns)
+self.rix = np.zeros(nr * ns)

# compute traveltime and distances
self.travsrcrec = True # use separate tables for src and rec traveltimes
@@ -362,8 +365,26 @@ def __init__(
trav_recs_grad[0], trav_recs_grad[1]
).reshape(np.prod(dims), nr)
else:
-# TODO: 3D
-raise NotImplementedError("dynamic=True currently not available in 3D")
+trav_srcs_grad = np.concatenate(
+    [trav_srcs_grad[i][np.newaxis] for i in range(3)]
+)
+trav_recs_grad = np.concatenate(
+    [trav_recs_grad[i][np.newaxis] for i in range(3)]
+)
+self.angle_srcs = (
+    np.sign(trav_srcs_grad[1])
+    * np.arccos(
+        trav_srcs_grad[-1]
+        / np.sqrt(np.sum(trav_srcs_grad**2, axis=0))
+    )
+).reshape(np.prod(dims), ns)
+self.angle_recs = (
+    np.sign(trav_recs_grad[1])
+    * np.arccos(
+        trav_recs_grad[-1]
+        / np.sqrt(np.sum(trav_recs_grad**2, axis=0))
+    )
+).reshape(np.prod(dims), nr)
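
The new 3D branch mirrors the 2D arctan2-based angles above: the propagation angle at each image point is recovered from the traveltime gradient as the angle with respect to the vertical (last) axis, signed by the second gradient component. A quick standalone check of the formula (axis ordering assumed from the code above, not PyLops itself):

    import numpy as np

    # unit gradient tilted 30 degrees away from the depth (last) axis
    g = np.array([0.0, np.sin(np.pi / 6), np.cos(np.pi / 6)])

    angle = np.sign(g[1]) * np.arccos(g[-1] / np.sqrt(np.sum(g**2)))
    print(np.degrees(angle))  # ~30.0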

# pre-compute traveltime indices if total traveltime is used
if not self.travsrcrec:
@@ -386,6 +407,12 @@ def __init__(

# define aperture
# if aperture=None, we want to ensure the check is always matched (no aperture limits...)
+# if aperture!=None in 3D, force it to None as aperture checks are not yet implemented
+if aperture is not None and self.ndims == 3:
+    aperture = None
+    warnings.warn(
+        "Aperture is forced to None as it is currently not implemented in 3D"
+    )
if aperture is not None:
warnings.warn(
"Aperture is currently defined as ratio of offset over depth, "
@@ -608,10 +635,10 @@ def _traveltime_table(

# compute traveltime gradients at image points
trav_srcs_grad = np.gradient(
-    trav_srcs.reshape(*dims, ns), axis=np.arange(ndims)
+    trav_srcs.reshape(*dims, ns), *dsamp, axis=np.arange(ndims)
)
trav_recs_grad = np.gradient(
-    trav_recs.reshape(*dims, nr), axis=np.arange(ndims)
+    trav_recs.reshape(*dims, nr), *dsamp, axis=np.arange(ndims)
)
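
Passing the sampling intervals ``*dsamp`` to np.gradient makes the traveltime gradients dimensionally consistent (seconds per meter rather than seconds per sample), which matters now that they are converted into angles. A small 1D illustration of the difference:

    import numpy as np

    dz = 4.0                   # grid spacing in meters
    t = np.arange(5) * 0.002   # traveltime growing by 2 ms per sample

    print(np.gradient(t))      # [0.002 0.002 ...]  -> s/sample
    print(np.gradient(t, dz))  # [0.0005 0.0005 ...] -> s/m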

return (
17 changes: 16 additions & 1 deletion pylops/waveeqprocessing/wavedecomposition.py
@@ -156,6 +156,7 @@ def _obliquity3D(
critical: float = 100.0,
ntaper: int = 10,
composition: bool = True,
+fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
) -> Tuple[LinearOperator, LinearOperator]:
@@ -187,6 +188,9 @@
composition : :obj:`bool`, optional
Create obliquity factor for composition (``True``) or
decomposition (``False``)
+fftengine : :obj:`str`, optional
+    Engine used for fft computation (``numpy`` or ``scipy``). Choose
+    ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -203,7 +207,11 @@
"""
# create Fourier operator
FFTop = FFTND(
-    dims=[nr[0], nr[1], nt], nffts=nffts, sampling=[dr[0], dr[1], dt], dtype=dtype
+    dims=[nr[0], nr[1], nt],
+    nffts=nffts,
+    sampling=[dr[0], dr[1], dt],
+    engine=fftengine,
+    dtype=dtype,
)

# create obliquity operator
@@ -547,6 +555,7 @@ def UpDownComposition3D(
critical: float = 100.0,
ntaper: int = 10,
scaling: float = 1.0,
+fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
name: str = "U",
Expand Down Expand Up @@ -588,6 +597,11 @@ def UpDownComposition3D(
angle
scaling : :obj:`float`, optional
Scaling to apply to the operator (see Notes for more details)
+fftengine : :obj:`str`, optional
+    .. versionadded:: 2.3.0
+
+    Engine used for fft computation (``numpy`` or ``scipy``). Choose
+    ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -638,6 +652,7 @@
critical=critical,
ntaper=ntaper,
composition=True,
+fftengine=fftengine,
backend=backend,
dtype=dtype,
)
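
With the new ``fftengine`` argument, the composition operator can be built so that its internal FFTND uses the numpy engine, which is the one that also works on cupy and jax arrays. A usage sketch (argument names and values are illustrative and assume the existing 3D signature):

    import numpy as np
    from pylops.waveeqprocessing import UpDownComposition3D

    nt, dt = 200, 0.004              # time samples and sampling
    nr, dr = (30, 40), (10.0, 10.0)  # receivers and spacing along the two axes

    UDop = UpDownComposition3D(
        nt, nr, dt, dr,
        rho=1000.0, vel=1500.0,
        fftengine="numpy",  # needed when the operator is applied to cupy/jax arrays
        dtype="complex128",
    )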
