Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/PyLops/pylops-mpi into 2ddistr
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Oct 27, 2024
2 parents 0e5464d + 57b793e commit 746428f
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ class DistributedArray:
Axis along which distribution occurs. Defaults to ``0``.
local_shapes : :obj:`list`, optional
List of tuples or integers representing local shapes at each rank.
<<<<<<< HEAD
mask : :obj:`list`, optional
Mask defining subsets of ranks to consider when performing 'global'
operations on the distributed array such as dot product or norm.
=======
>>>>>>> 57b793e8ce4c150d90866d1e41c0bd9e88cae985
engine : :obj:`str`, optional
Engine used to store array (``numpy`` or ``cupy``)
dtype : :obj:`str`, optional
Expand All @@ -93,7 +96,10 @@ def __init__(self, global_shape: Union[Tuple, Integral],
base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
partition: Partition = Partition.SCATTER, axis: int = 0,
local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
<<<<<<< HEAD
mask: Optional[List[Integral]] = None,
=======
>>>>>>> 57b793e8ce4c150d90866d1e41c0bd9e88cae985
engine: Optional[str] = "numpy",
dtype: Optional[DTypeLike] = np.float64):
if isinstance(global_shape, Integral):
Expand All @@ -109,8 +115,11 @@ def __init__(self, global_shape: Union[Tuple, Integral],
self._base_comm = base_comm
self._partition = partition
self._axis = axis
<<<<<<< HEAD
self._mask = mask
self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank)
=======
>>>>>>> 57b793e8ce4c150d90866d1e41c0bd9e88cae985

local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
self._check_local_shapes(local_shapes)
Expand Down
234 changes: 234 additions & 0 deletions tutorials/mdd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""
Multi-Dimensional Deconvolution
===============================
This example shows how to set-up and run a Multi-Dimensional Deconvolution
problem in a distributed fashion, leveraging the :py:class:`pylops_mpi.waveeqprocessing.MDC`
class.
More precisely, compared to its counterpart in the PyLops documentation, this example distributes
the frequency slices of the kernel of the MDC operator across multiple processes. Whilst both the
entire model and data sit on all processes, within the MDC operator, and more precisely when the
:py:class:`pylops_mpi.signalprocessing.Fredholm1` is called, different groups of frequencies are
processed by the different ranks.
"""

import numpy as np
from scipy.signal import filtfilt
from matplotlib import pyplot as plt
from mpi4py import MPI

from pylops.utils.seismicevents import hyperbolic2d, makeaxis
from pylops.utils.tapers import taper3d
from pylops.utils.wavelets import ricker

import pylops_mpi
from pylops_mpi.DistributedArray import local_split, Partition

plt.close("all")
rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()
dtype = np.float32
cdtype = np.complex64

###############################################################################
# Let's start by creating a set of hyperbolic events to be used as
# our MDC kernel as well as the model

# Input parameters
par = {
"ox": -300,
"dx": 10,
"nx": 61,
"oy": -500,
"dy": 10,
"ny": 101,
"ot": 0,
"dt": 0.004,
"nt": 400,
"f0": 20,
"nfmax": 200,
}

t0_m = 0.2
vrms_m = 1100.0
amp_m = 1.0

t0_G = (0.2, 0.5, 0.7)
vrms_G = (1200.0, 1500.0, 2000.0)
amp_G = (1.0, 0.6, 0.5)

# Taper
tap = taper3d(par["nt"], (par["ny"], par["nx"]), (5, 5), tapertype="hanning")

# Create axis
t, t2, x, y = makeaxis(par)

# Create wavelet
wav = ricker(t[:41], f0=par["f0"])[0]

# Generate model
mrefl, mwav = hyperbolic2d(x, t, t0_m, vrms_m, amp_m, wav)

# Generate operator
G, Gwav = np.zeros((par["ny"], par["nx"], par["nt"])), np.zeros(
(par["ny"], par["nx"], par["nt"])
)
for iy, y0 in enumerate(y):
G[iy], Gwav[iy] = hyperbolic2d(x - y0, t, t0_G, vrms_G, amp_G, wav)
G, Gwav = G * tap, Gwav * tap

# Add negative part to data and model
mrefl = np.concatenate((np.zeros((par["nx"], par["nt"] - 1)), mrefl), axis=-1)
mwav = np.concatenate((np.zeros((par["nx"], par["nt"] - 1)), mwav), axis=-1)
Gwav2 = np.concatenate((np.zeros((par["ny"], par["nx"], par["nt"] - 1)), Gwav), axis=-1)

# Move to frequency
Gwav_fft = np.fft.rfft(Gwav2, 2 * par["nt"] - 1, axis=-1)
Gwav_fft = (Gwav_fft[..., : par["nfmax"]])

# Move frequency/time to first axis
mrefl, mwav = mrefl.T, mwav.T
Gwav_fft = Gwav_fft.transpose(2, 0, 1)

# Choose how to split frequencies to ranks
nf = par["nfmax"]
nf_rank = local_split((nf,), MPI.COMM_WORLD, Partition.SCATTER, 0)
nf_ranks = np.concatenate(MPI.COMM_WORLD.allgather(nf_rank))
ifin_rank = np.insert(np.cumsum(nf_ranks)[:-1], 0, 0)[rank]
ifend_rank = np.cumsum(nf_ranks)[rank]

# Extract batch of frequency slices (in practice, this will be directly read from input file)
G = Gwav_fft[ifin_rank:ifend_rank].astype(cdtype)

###############################################################################
# Let's now define the distributed operator and model as well as compute the
# data

# Define operator
MDCop = pylops_mpi.waveeqprocessing.MPIMDC((1.0 * par["dt"] * np.sqrt(par["nt"])) * G,
nt=2 * par["nt"] - 1, nv=1, nfreq=nf,
dt=par["dt"], dr=1.0, twosided=True,
fftengine="scipy", prescaled=True)

# Create model
m = pylops_mpi.DistributedArray(global_shape=(2 * par["nt"] - 1) * par["nx"] * 1,
partition=Partition.BROADCAST,
dtype=dtype)
m[:] = mrefl.astype(dtype).ravel()

# Create data
d = MDCop @ m
dloc = d.asarray().real.reshape(2 * par["nt"] - 1, par["ny"])

###############################################################################
# Let's display what we have so far: operator, input model, and data

if rank == 0:
fig, axs = plt.subplots(1, 2, figsize=(8, 6))
axs[0].imshow(
Gwav2[int(par["ny"] / 2)].T,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(Gwav2.max()),
vmax=np.abs(Gwav2.max()),
extent=(x.min(), x.max(), t2.max(), t2.min()),
)
axs[0].set_title("G - inline view", fontsize=15)
axs[0].set_xlabel(r"$x_R$")
axs[1].set_ylabel(r"$t$")
axs[1].imshow(
Gwav2[:, int(par["nx"] / 2)].T,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(Gwav2.max()),
vmax=np.abs(Gwav2.max()),
extent=(y.min(), y.max(), t2.max(), t2.min()),
)
axs[1].set_title("G - inline view", fontsize=15)
axs[1].set_xlabel(r"$x_S$")
axs[1].set_ylabel(r"$t$")
fig.tight_layout()

fig, axs = plt.subplots(1, 2, figsize=(8, 6))
axs[0].imshow(
mwav,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(mwav.max()),
vmax=np.abs(mwav.max()),
extent=(x.min(), x.max(), t2.max(), t2.min()),
)
axs[0].set_title(r"$m$", fontsize=15)
axs[0].set_xlabel(r"$x_R$")
axs[0].set_ylabel(r"$t$")
axs[1].imshow(
dloc,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(dloc.max()),
vmax=np.abs(dloc.max()),
extent=(x.min(), x.max(), t2.max(), t2.min()),
)
axs[1].set_title(r"$d$", fontsize=15)
axs[1].set_xlabel(r"$x_S$")
axs[1].set_ylabel(r"$t$")
fig.tight_layout()

###############################################################################
# We are now ready to compute the adjoint (i.e., cross-correlation) and invert
# back for our input model

# Adjoint
madj = MDCop.H @ d
madjloc = madj.asarray().real.reshape(2 * par["nt"] - 1, par["nx"])

# Inverse
m0 = pylops_mpi.DistributedArray(global_shape=(2 * par["nt"] - 1) * par["nx"] * 1,
partition=Partition.BROADCAST,
dtype=cdtype)
m0[:] = 0
minv = pylops_mpi.cgls(MDCop, d, x0=m0, niter=50, show=True if rank == 0 else False)[0]
minvloc = minv.asarray().real.reshape(2 * par["nt"] - 1, par["nx"])

if rank == 0:
fig = plt.figure(figsize=(8, 6))
ax1 = plt.subplot2grid((1, 5), (0, 0), colspan=2)
ax2 = plt.subplot2grid((1, 5), (0, 2), colspan=2)
ax3 = plt.subplot2grid((1, 5), (0, 4))
ax1.imshow(
madjloc,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(madjloc.max()),
vmax=np.abs(madjloc.max()),
extent=(x.min(), x.max(), t2.max(), t2.min()),
)
ax1.set_title("Adjoint m", fontsize=15)
ax1.set_xlabel(r"$x_V$")
ax1.set_ylabel(r"$t$")
ax2.imshow(
minvloc,
aspect="auto",
interpolation="nearest",
cmap="gray",
vmin=-np.abs(minvloc.max()),
vmax=np.abs(minvloc.max()),
extent=(x.min(), x.max(), t2.max(), t2.min()),
)
ax2.set_title("Inverted m", fontsize=15)
ax2.set_xlabel(r"$x_V$")
ax2.set_ylabel(r"$t$")
ax3.plot(
madjloc[:, int(par["nx"] / 2)] / np.abs(madjloc[:, int(par["nx"] / 2)]).max(), t2, "r", lw=5
)
ax3.plot(
minvloc[:, int(par["nx"] / 2)] / np.abs(minvloc[:, int(par["nx"] / 2)]).max(), t2, "k", lw=3
)
ax3.set_ylim([t2[-1], t2[0]])
fig.tight_layout()

0 comments on commit 746428f

Please sign in to comment.