diff --git a/docs/source/_static/cupy_diagram.png b/docs/source/_static/cupy_diagram.png new file mode 100755 index 00000000..7eafad98 Binary files /dev/null and b/docs/source/_static/cupy_diagram.png differ diff --git a/docs/source/_static/numpy_cupy_bd_diagram.png b/docs/source/_static/numpy_cupy_bd_diagram.png new file mode 100755 index 00000000..00ae4a9a Binary files /dev/null and b/docs/source/_static/numpy_cupy_bd_diagram.png differ diff --git a/docs/source/_static/numpy_cupy_vs_diagram.png b/docs/source/_static/numpy_cupy_vs_diagram.png new file mode 100755 index 00000000..96e3b470 Binary files /dev/null and b/docs/source/_static/numpy_cupy_vs_diagram.png differ diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 59e40e51..da9beb36 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -62,6 +62,8 @@ Basic operators Real Imag Conj + ToCupy + Smoothing and derivatives ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index a37597e9..842770b2 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -29,6 +29,172 @@ provide data vectors to the solvers, e.g., when using For JAX, apart from following the same procedure described for CuPy, the PyLops operator must be also wrapped into a :class:`pylops.JaxOperator`. +See below for a comphrensive list of supported operators and additional functionalities for both the +``cupy`` and ``jax`` backends. + + +Examples +-------- + +Let's now briefly look at some use cases. + +End-to-end GPU powered inverse problems +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First we consider the most common scenario when both the model and data +vectors fit onto the GPU memory. We can therefore simply replace all our +``numpy`` arrays with ``cupy`` arrays and solve the inverse problem of +interest end-to-end on the GPU. + +.. image:: _static/cupy_diagram.png + :width: 600 + :align: center + +Let's first write a code snippet using ``numpy`` arrays, which PyLops +will run on your CPU: + +.. code-block:: python + + ny, nx = 400, 400 + G = np.random.normal(0, 1, (ny, nx)).astype(np.float32) + x = np.ones(nx, dtype=np.float32) + + # Create operator + Gop = MatrixMult(G, dtype='float32') + + # Create data and invert + y = Gop @ x + xest = Gop / y + +Now we write a code snippet using ``cupy`` arrays, which PyLops will run on +your GPU: + +.. code-block:: python + + ny, nx = 400, 400 + G = cp.random.normal(0, 1, (ny, nx)).astype(np.float32) + x = cp.ones(nx, dtype=np.float32) + + # Create operator + Gop = MatrixMult(G, dtype='float32') + + # Create data and invert + y = Gop @ x + xest = Gop / y + +The code is almost unchanged apart from the fact that we now use ``cupy`` arrays, +PyLops will figure this out. + +Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on +your GPU/TPU: + +.. code-block:: python + + ny, nx = 400, 400 + G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32)) + x = jnp.ones(nx, dtype=np.float32) + + # Create operator + Gop = JaxOperator(MatrixMult(G, dtype='float32')) + + # Create data and invert + y = Gop @ x + xest = Gop / y + + # Adjoint via AD + xadj = Gop.rmatvecad(x, y) + +Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays. + + +Mixed CPU-GPU powered inverse problems +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let us now consider a more intricate scenario where we have access to +a GPU-powered operator, however the model and/or data vectors are too large +to fit onto the GPU memory (or VRAM). + +For the sake of clarity, we consider a problem where +the operator can be written as a :class:`pylops.BlockDiag` of +PyLops operators. Note how, by simply sandwiching any of the GPU-powered +operator within two :class:`pylops.ToCupy` operators, we are +able to tell PyLops to transfer to the GPU only the part of the model vector +required by a given operator and transfer back the output to the CPU before +forming the combine output vector (i.e., the output vector of the +:class:`pylops.BlockDiag`). + +.. image:: _static/numpy_cupy_bd_diagram.png + :width: 1000 + :align: center + +.. code-block:: python + + nops, n = 5, 4 + Ms = [np.diag((i + 1) * np.ones(n, dtype=dtype)) \ + for i in range(nops)] + Ms = [M.T @ M for M in Ms] + + # Create operator + Mops = [] + for iop in range(nops): + Mop = MatrixMult(cp.asarray(Ms[iop], dtype=dtype)) + Top = ToCupy(Mop.dims, dtype=dtype) + Top1 = ToCupy(Mop.dimsd, dtype=dtype) + Mop = Top1.H @ Mop @ Top + Mops.append(Mop) + Mops = BlockDiag(Mops, forceflat=True) + + # Create data and invert + x = np.ones(n * nops, dtype=dtype) + y = Mops @ x.ravel() + xest = Mops / y + + +Finally, let us consider a problem where +the operator can be written as a :class:`pylops.VStack` of +PyLops operators and the model vector can be fully transferred to the GPU. +We can use again the :class:`pylops.ToCupy` operator, however this +time we will only use it to move the output of each operator to the CPU. +Since we are now in a special scenario, where the input of the overall +operator sits on the GPU and the output on the +CPU, we need to inform the :class:`pylops.VStack` operator about this. +This can be easily done using the additional ``inoutengine`` parameter. Let's +see this with an example. + +.. image:: _static/numpy_cupy_vs_diagram.png + :width: 1000 + :align: center + +.. code-block:: python + + nops, n, m = 3, 4, 5 + Ms = [np.random.normal(0, 1, (n, m)) for _ in range(nops)] + + # Create operator + Mops = [] + for iop in range(nops): + Mop = MatrixMult(cp.asarray(Ms[iop]), dtype=dtype) + Top1 = ToCupy(Mop.dimsd, dtype=dtype) + Mop = Top1.H @ Mop + Mops.append(Mop) + Mops = VStack(Mops, inoutengine=("numpy", "cupy")) + + # Create data and invert + x = cp.ones(m, dtype=dtype) + y = Mops @ x.ravel() + xest = pylops_cgls(Mops, y, x0=cp.zeros_like(x))[0] + +These features are currently not available for ``jax`` arrays. + + +.. note:: + + More examples for the CuPy and JAX backends be found at `link1 `_ + and `link2 `_. + + +Supported Operators +------------------- In the following, we provide a list of methods in :class:`pylops.LinearOperator` with their current status (available on CPU, GPU with CuPy, and GPU with JAX): @@ -195,6 +361,7 @@ Smoothing and derivatives: - |:white_check_mark:| - |:white_check_mark:| + Signal processing: .. list-table:: @@ -322,6 +489,7 @@ Signal processing: - |:white_check_mark:| - |:white_check_mark:| + Wave-Equation processing .. list-table:: @@ -369,6 +537,7 @@ Wave-Equation processing - |:red_circle:| - |:red_circle:| + Geophysical subsurface characterization: .. list-table:: @@ -407,60 +576,3 @@ Geophysical subsurface characterization: operator currently works only with ``explicit=True`` due to the same issue as in point 1 for the :class:`pylops.signalprocessing.Convolve1D` operator employed when ``explicit=False``. - - -Example -------- - -Finally, let's briefly look at an example. First we write a code snippet using -``numpy`` arrays which PyLops will run on your CPU: - -.. code-block:: python - - ny, nx = 400, 400 - G = np.random.normal(0, 1, (ny, nx)).astype(np.float32) - x = np.ones(nx, dtype=np.float32) - - Gop = MatrixMult(G, dtype='float32') - y = Gop * x - xest = Gop / y - -Now we write a code snippet using ``cupy`` arrays which PyLops will run on -your GPU: - -.. code-block:: python - - ny, nx = 400, 400 - G = cp.random.normal(0, 1, (ny, nx)).astype(np.float32) - x = cp.ones(nx, dtype=np.float32) - - Gop = MatrixMult(G, dtype='float32') - y = Gop * x - xest = Gop / y - -The code is almost unchanged apart from the fact that we now use ``cupy`` arrays, -PyLops will figure this out. - -Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on -your GPU/TPU: - -.. code-block:: python - - ny, nx = 400, 400 - G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32)) - x = jnp.ones(nx, dtype=np.float32) - - Gop = JaxOperator(MatrixMult(G, dtype='float32')) - y = Gop * x - xest = Gop / y - - # Adjoint via AD - xadj = Gop.rmatvecad(x, y) - - -Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays, - -.. note:: - - More examples for the CuPy and JAX backends be found `here `__ - and `here `__. \ No newline at end of file diff --git a/pylops/basicoperators/__init__.py b/pylops/basicoperators/__init__.py index d654e50d..5d3a3d3e 100755 --- a/pylops/basicoperators/__init__.py +++ b/pylops/basicoperators/__init__.py @@ -38,6 +38,7 @@ Gradient Gradient. FirstDirectionalDerivative First Directional derivative. SecondDirectionalDerivative Second Directional derivative. + ToCupy Convert to CuPy. """ from .functionoperator import * @@ -72,6 +73,8 @@ from .laplacian import * from .gradient import * from .directionalderivative import * +from .tocupy import * + __all__ = [ "FunctionOperator", @@ -107,4 +110,5 @@ "Gradient", "FirstDirectionalDerivative", "SecondDirectionalDerivative", + "ToCupy", ] diff --git a/pylops/basicoperators/blockdiag.py b/pylops/basicoperators/blockdiag.py index 166ae137..4d9fbc36 100644 --- a/pylops/basicoperators/blockdiag.py +++ b/pylops/basicoperators/blockdiag.py @@ -21,7 +21,7 @@ from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module, inplace_set +from pylops.utils.backend import get_array_module, get_module, inplace_set from pylops.utils.typing import DTypeLike, NDArray @@ -48,6 +48,12 @@ class BlockDiag(LinearOperator): .. versionadded:: 2.2.0 Force an array to be flattened after matvec and rmatvec. + inoutengine : :obj:`tuple`, optional + .. versionadded:: 2.4.0 + + Type of output vectors of `matvec` and `rmatvec. If ``None``, this is + inferred directly from the input vectors. Note that this is ignored + if ``nproc>1``. dtype : :obj:`str`, optional Type of elements in input array. @@ -113,6 +119,7 @@ def __init__( ops: Sequence[LinearOperator], nproc: int = 1, forceflat: bool = None, + inoutengine: Optional[tuple] = None, dtype: Optional[DTypeLike] = None, ) -> None: self.ops = ops @@ -149,6 +156,7 @@ def __init__( if self.nproc > 1: self.pool = mp.Pool(processes=nproc) + self.inoutengine = inoutengine dtype = _get_dtype(ops) if dtype is None else np.dtype(dtype) clinear = all([getattr(oper, "clinear", True) for oper in self.ops]) super().__init__( @@ -172,7 +180,11 @@ def nproc(self, nprocnew: int) -> None: self._nproc = nprocnew def _matvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[0]) + ) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_set( @@ -183,7 +195,11 @@ def _matvec_serial(self, x: NDArray) -> NDArray: return y def _rmatvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[1]) + ) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_set( diff --git a/pylops/basicoperators/hstack.py b/pylops/basicoperators/hstack.py index b71e8723..4cd450d5 100644 --- a/pylops/basicoperators/hstack.py +++ b/pylops/basicoperators/hstack.py @@ -21,7 +21,7 @@ from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module, inplace_add, inplace_set +from pylops.utils.backend import get_array_module, get_module, inplace_add, inplace_set from pylops.utils.typing import NDArray @@ -48,6 +48,12 @@ class HStack(LinearOperator): .. versionadded:: 2.2.0 Force an array to be flattened after matvec. + inoutengine : :obj:`tuple`, optional + .. versionadded:: 2.4.0 + + Type of output vectors of `matvec` and `rmatvec. If ``None``, this is + inferred directly from the input vectors. Note that this is ignored + if ``nproc>1``. dtype : :obj:`str`, optional Type of elements in input array. @@ -112,6 +118,7 @@ def __init__( ops: Sequence[LinearOperator], nproc: int = 1, forceflat: bool = None, + inoutengine: Optional[tuple] = None, dtype: Optional[str] = None, ) -> None: self.ops = ops @@ -139,6 +146,8 @@ def __init__( self.pool = None if self.nproc > 1: self.pool = mp.Pool(processes=nproc) + + self.inoutengine = inoutengine dtype = _get_dtype(self.ops) if dtype is None else np.dtype(dtype) clinear = all([getattr(oper, "clinear", True) for oper in self.ops]) super().__init__( @@ -162,7 +171,11 @@ def nproc(self, nprocnew: int): self._nproc = nprocnew def _matvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[0]) + ) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_add( @@ -173,7 +186,11 @@ def _matvec_serial(self, x: NDArray) -> NDArray: return y def _rmatvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[1]) + ) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_set( diff --git a/pylops/basicoperators/tocupy.py b/pylops/basicoperators/tocupy.py new file mode 100644 index 00000000..8a3a7c6b --- /dev/null +++ b/pylops/basicoperators/tocupy.py @@ -0,0 +1,60 @@ +__all__ = ["ToCupy"] + +from typing import Union + +import numpy as np + +from pylops import LinearOperator +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.backend import to_cupy, to_numpy +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + + +class ToCupy(LinearOperator): + r"""Convert to CuPy. + + Convert an input array to CuPy in forward mode, + and convert back to NumPy in adjoint mode. + + Parameters + ---------- + dims : :obj:`list` or :obj:`int` + Number of samples for each dimension + dtype : :obj:`str`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + explicit : :obj:`bool` + Operator contains a matrix that can be solved explicitly + (``True``) or not (``False``) + + Notes + ----- + The ToCupy operator is a special operator that does not perform + any transformation on the input arrays other than converting + them from NumPy to CuPy. This operator can be used when one + is interested to create a chain of operators where only one + (or some of them) act on CuPy arrays, whilst other operate + on NumPy arrays. + + """ + + def __init__( + self, + dims: Union[int, InputDimsLike], + dtype: DTypeLike = "float64", + name: str = "C", + ) -> None: + dims = _value_or_sized_to_tuple(dims) + super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dims, name=name) + + def _matvec(self, x: NDArray) -> NDArray: + return to_cupy(x) + + def _rmatvec(self, x: NDArray) -> NDArray: + return to_numpy(x) diff --git a/pylops/basicoperators/vstack.py b/pylops/basicoperators/vstack.py index 0d66642e..55341411 100644 --- a/pylops/basicoperators/vstack.py +++ b/pylops/basicoperators/vstack.py @@ -21,7 +21,7 @@ from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module, inplace_add, inplace_set +from pylops.utils.backend import get_array_module, get_module, inplace_add, inplace_set from pylops.utils.typing import DTypeLike, NDArray @@ -48,6 +48,12 @@ class VStack(LinearOperator): .. versionadded:: 2.2.0 Force an array to be flattened after rmatvec. + inoutengine : :obj:`tuple`, optional + .. versionadded:: 2.4.0 + + Type of output vectors of `matvec` and `rmatvec. If ``None``, this is + inferred directly from the input vectors. Note that this is ignored + if ``nproc>1``. dtype : :obj:`str`, optional Type of elements in input array. @@ -112,6 +118,7 @@ def __init__( ops: Sequence[LinearOperator], nproc: int = 1, forceflat: bool = None, + inoutengine: Optional[tuple] = None, dtype: Optional[DTypeLike] = None, ) -> None: self.ops = ops @@ -139,6 +146,8 @@ def __init__( self.pool = None if self.nproc > 1: self.pool = mp.Pool(processes=nproc) + + self.inoutengine = inoutengine dtype = _get_dtype(self.ops) if dtype is None else np.dtype(dtype) clinear = all([getattr(oper, "clinear", True) for oper in self.ops]) super().__init__( @@ -162,7 +171,11 @@ def nproc(self, nprocnew: int): self._nproc = nprocnew def _matvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[0]) + ) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_set( @@ -171,7 +184,11 @@ def _matvec_serial(self, x: NDArray) -> NDArray: return y def _rmatvec_serial(self, x: NDArray) -> NDArray: - ncp = get_array_module(x) + ncp = ( + get_array_module(x) + if self.inoutengine is None + else get_module(self.inoutengine[1]) + ) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): y = inplace_add( diff --git a/pylops/optimization/cls_basic.py b/pylops/optimization/cls_basic.py index 2e848fcf..b7c98e0a 100644 --- a/pylops/optimization/cls_basic.py +++ b/pylops/optimization/cls_basic.py @@ -10,7 +10,12 @@ import numpy as np from pylops.optimization.basesolver import Solver -from pylops.utils.backend import get_array_module, to_numpy +from pylops.utils.backend import ( + get_array_module, + to_cupy_conditional, + to_numpy, + to_numpy_conditional, +) from pylops.utils.typing import NDArray if TYPE_CHECKING: @@ -131,10 +136,10 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: Updated model vector """ - Opc = self.Op.matvec(self.c) + Opc = self.Op.matvec(to_cupy_conditional(x, self.c)) cOpc = self.ncp.abs(self.c.dot(Opc.conj())) a = self.kold / cOpc - x += a * self.c + x += to_cupy_conditional(x, a) * to_cupy_conditional(x, self.c) self.r -= a * Opc k = self.ncp.abs(self.r.dot(self.r.conj())) b = k / self.kold @@ -386,7 +391,7 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj()) ) x = x + a * self.c - self.s = self.s - a * self.q + self.s = self.s - to_numpy_conditional(self.q, a) * self.q r = self.Op.rmatvec(self.s) - self.damp * x k = self.ncp.abs(r.dot(r.conj())) b = k / self.kold @@ -773,7 +778,9 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: # next beta, u, alfa, v. These satisfy the relations # beta*u = Op*v - alfa*u, # alfa*v = Op'*u - beta*v' - self.u = self.Op.matvec(self.v) - self.alfa * self.u + self.u = ( + self.Op.matvec(self.v) - to_numpy_conditional(self.u, self.alfa) * self.u + ) self.beta = self.ncp.linalg.norm(self.u) if self.beta > 0: self.u = self.u / self.beta @@ -812,7 +819,9 @@ def step(self, x: NDArray, show: bool = False) -> NDArray: self.w = self.v + self.t2 * self.w self.ddnorm = self.ddnorm + self.ncp.linalg.norm(self.dk) ** 2 if self.calc_var: - self.var = self.var + self.ncp.dot(self.dk, self.dk) + self.var = self.var + to_numpy_conditional( + self.var, self.ncp.dot(self.dk, self.dk) + ) # use a plane rotation on the right to eliminate the # super-diagonal element (theta) of the upper-bidiagonal matrix. diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index 9b7d3b2d..31836265 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -17,8 +17,10 @@ "get_sp_fft", "get_complex_dtype", "get_real_dtype", + "to_cupy", "to_numpy", "to_cupy_conditional", + "to_numpy_conditional", "inplace_set", "inplace_add", "inplace_multiply", @@ -484,6 +486,26 @@ def get_real_dtype(dtype: DTypeLike) -> DTypeLike: return np.real(np.ones(1, dtype)).dtype +def to_cupy(x: NDArray) -> NDArray: + """Convert x to cupy array + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + Array to evaluate + + Returns + ------- + x : :obj:`numpy.ndarray` + Converted array + + """ + if deps.cupy_enabled: + if cp.get_array_module(x) == np: + x = cp.asarray(x) + return x + + def to_numpy(x: NDArray) -> NDArray: """Convert x to numpy array @@ -527,6 +549,28 @@ def to_cupy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray: return y +def to_numpy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray: + """Convert y to numpy array conditional to x being a numpy array + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + Array to evaluate + y : :obj:`numpy.ndarray` + Array to convert + + Returns + ------- + y : :obj:`cupy.ndarray` + Converted array + + """ + if deps.cupy_enabled: + if cp.get_array_module(x) == np and cp.get_array_module(y) == cp: + y = cp.asnumpy(y) + return y + + def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: """Perform inplace set based on input diff --git a/pytests/test_basicoperators.py b/pytests/test_basicoperators.py index 8f3f0528..f12c8c68 100755 --- a/pytests/test_basicoperators.py +++ b/pytests/test_basicoperators.py @@ -16,6 +16,7 @@ Roll, Sum, Symmetrize, + ToCupy, Zero, ) from pylops.utils import dottest @@ -601,3 +602,18 @@ def test_Conj(par): assert_array_equal(x, xadj) assert_array_equal(y, np.conj(x)) assert_array_equal(xadj, np.conj(y)) + + +@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j), (par3)]) +def test_ToCupy(par): + """Forward and adjoint for ToCupy operator (checking that it works also + when cupy is not available) + """ + Top = ToCupy(par["nx"], dtype=par["dtype"]) + + np.random.seed(10) + x = np.random.randn(par["nx"]) + par["imag"] * np.random.randn(par["nx"]) + y = Top * x + xadj = Top.H * y + assert_array_equal(x, xadj) + assert_array_equal(y, x)