Skip to content

Commit

Permalink
Merge pull request #489 from rohanbabbar04/abc
Browse files Browse the repository at this point in the history
Abstraction to `_matvec` and `_rmatvec`
  • Loading branch information
mrava87 authored Feb 27, 2023
2 parents 5af878d + fee227f commit d5c6252
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 54 deletions.
4 changes: 2 additions & 2 deletions examples/plot_sliding.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
# This is because we have not inverted our operator but simply applied
# the adjoint to estimate the representation of the input data in the Radon
# domain. We can do better if we use the inverse instead.
radoninv = pylops.LinearOperator(Slid, explicit=False).div(data.ravel(), niter=10)
radoninv = Slid.div(data.ravel(), niter=10)
reconstructed_datainv = Slid * radoninv.ravel()

radoninv = radoninv.reshape(dims)
Expand Down Expand Up @@ -288,7 +288,7 @@

reconstructed_data = Slid * radon

radoninv = pylops.LinearOperator(Slid, explicit=False).div(data.ravel(), niter=10)
radoninv = Slid.div(data.ravel(), niter=10)
radoninv = radoninv.reshape(Slid.dims)
reconstructed_datainv = Slid * radoninv

Expand Down
4 changes: 2 additions & 2 deletions pylops/basicoperators/directionalderivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def FirstDirectionalDerivative(
else:
Dop = Diagonal(v.ravel(), dtype=dtype)
Sop = Sum(dims=[len(dims)] + list(dims), axis=0, dtype=dtype)
ddop = LinearOperator(Sop * Dop * Gop)
ddop = Sop * Dop * Gop
ddop.dims = ddop.dimsd = dims
ddop.sampling = sampling
return ddop
Expand Down Expand Up @@ -136,7 +136,7 @@ def SecondDirectionalDerivative(
in the literature.
"""
Dop = FirstDirectionalDerivative(dims, v, sampling=sampling, edge=edge, dtype=dtype)
ddop = LinearOperator(-Dop.H * Dop)
ddop = -Dop.H * Dop
ddop.dims = ddop.dimsd = dims
ddop.sampling = sampling
return ddop
26 changes: 16 additions & 10 deletions pylops/basicoperators/firstderivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,28 +108,34 @@ def _register_multiplications(
order: int,
) -> None:
# choose _matvec and _rmatvec kind
self._matvec: Callable
self._rmatvec: Callable
self._hmatvec: Callable
self._hrmatvec: Callable
if kind == "forward":
self._matvec = self._matvec_forward
self._rmatvec = self._rmatvec_forward
self._hmatvec = self._matvec_forward
self._hrmatvec = self._rmatvec_forward
elif kind == "centered":
if order == 3:
self._matvec = self._matvec_centered3
self._rmatvec = self._rmatvec_centered3
self._hmatvec = self._matvec_centered3
self._hrmatvec = self._rmatvec_centered3
elif order == 5:
self._matvec = self._matvec_centered5
self._rmatvec = self._rmatvec_centered5
self._hmatvec = self._matvec_centered5
self._hrmatvec = self._rmatvec_centered5
else:
raise NotImplementedError("'order' must be '3, or '5'")
elif kind == "backward":
self._matvec = self._matvec_backward
self._rmatvec = self._rmatvec_backward
self._hmatvec = self._matvec_backward
self._hrmatvec = self._rmatvec_backward
else:
raise NotImplementedError(
"'kind' must be 'forward', 'centered', or 'backward'"
)

def _matvec(self, x: NDArray) -> NDArray:
return self._hmatvec(x)

def _rmatvec(self, x: NDArray) -> NDArray:
return self._hrmatvec(x)

@reshaped(swapaxis=True)
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
Expand Down
22 changes: 14 additions & 8 deletions pylops/basicoperators/secondderivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,28 @@ def _register_multiplications(
kind: str,
) -> None:
# choose _matvec and _rmatvec kind
self._matvec: Callable
self._rmatvec: Callable
self._hmatvec: Callable
self._hrmatvec: Callable
if kind == "forward":
self._matvec = self._matvec_forward
self._rmatvec = self._rmatvec_forward
self._hmatvec = self._matvec_forward
self._hrmatvec = self._rmatvec_forward
elif kind == "centered":
self._matvec = self._matvec_centered
self._rmatvec = self._rmatvec_centered
self._hmatvec = self._matvec_centered
self._hrmatvec = self._rmatvec_centered
elif kind == "backward":
self._matvec = self._matvec_backward
self._rmatvec = self._rmatvec_backward
self._hmatvec = self._matvec_backward
self._hrmatvec = self._rmatvec_backward
else:
raise NotImplementedError(
"'kind' must be 'forward', 'centered' or 'backward'"
)

def _matvec(self, x: NDArray) -> NDArray:
return self._hmatvec(x)

def _rmatvec(self, x: NDArray) -> NDArray:
return self._hrmatvec(x)

@reshaped(swapaxis=True)
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
Expand Down
22 changes: 18 additions & 4 deletions pylops/linearoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"aslinearoperator",
]


import logging
from abc import ABC, abstractmethod

import numpy as np
import scipy as sp
Expand Down Expand Up @@ -40,7 +40,21 @@
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)


class LinearOperator:
class _LinearOperator(ABC):
"""Meta-class for Linear operator"""

@abstractmethod
def _matvec(self, x: NDArray) -> NDArray:
"""Matrix-vector multiplication handler."""
pass

@abstractmethod
def _rmatvec(self, x: NDArray) -> NDArray:
"""Matrix-vector adjoint multiplication handler."""
pass


class LinearOperator(_LinearOperator):
"""Common interface for performing matrix-vector products.
This class acts as an abstract interface between matrix-like
Expand Down Expand Up @@ -567,14 +581,14 @@ def dot(self, x: NDArray) -> NDArray:
# cast x to pylops linear operator if not already (this is done
# to allow mixing pylops and scipy operators)
Opx = aslinearoperator(x)
Op = LinearOperator(Op=_ProductLinearOperator(self, Opx))
Op = _ProductLinearOperator(self, Opx)
self._copy_attributes(Op, exclude=["dims", "explicit", "name"])
Op.clinear = Op.clinear and Opx.clinear
Op.explicit = False
Op.dims = Opx.dims
return Op
elif np.isscalar(x):
Op = LinearOperator(Op=_ScaledLinearOperator(self, x))
Op = _ScaledLinearOperator(self, x)
self._copy_attributes(
Op,
exclude=["explicit", "name"],
Expand Down
5 changes: 1 addition & 4 deletions pylops/optimization/cls_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,6 @@ def setup(
Display setup log
"""
self.Op = LinearOperator(self.Op)
self.y = y
self.niter_outer = niter_outer
self.niter_inner = niter_inner
Expand Down Expand Up @@ -1240,10 +1239,8 @@ def setup(
if alpha is not None:
self.alpha = alpha
elif not hasattr(self, "alpha"):
if not isinstance(self.Op, LinearOperator):
self.Op = LinearOperator(self.Op, explicit=False)
# compute largest eigenvalues of Op^H * Op
Op1 = LinearOperator(self.Op.H * self.Op, explicit=False)
Op1 = self.Op.H * self.Op
if get_module_name(self.ncp) == "numpy":
maxeig: float = np.abs(
Op1.eigs(
Expand Down
4 changes: 2 additions & 2 deletions pylops/signalprocessing/patch2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pylops import LinearOperator, aslinearoperator
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
from pylops.utils.tapers import taper2d
Expand Down Expand Up @@ -264,7 +264,7 @@ def Patch2D(
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
]
)
Pop = aslinearoperator(combining0 * combining1 * OOp)
Pop = LinearOperator(combining0 * combining1 * OOp)
Pop.dims, Pop.dimsd = (
nwins0,
nwins1,
Expand Down
5 changes: 2 additions & 3 deletions pylops/signalprocessing/patch3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pylops import LinearOperator, aslinearoperator
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
from pylops.utils.tapers import tapernd
Expand Down Expand Up @@ -443,7 +443,7 @@ def Patch3D(
]
)

Pop = aslinearoperator(combining0 * combining1 * combining2 * OOp)
Pop = LinearOperator(combining0 * combining1 * combining2 * OOp)
Pop.dims, Pop.dimsd = (
nwins0,
nwins1,
Expand All @@ -452,6 +452,5 @@ def Patch3D(
int(dims[1] // nwins1),
int(dims[2] // nwins2),
), dimsd

Pop.name = name
return Pop
4 changes: 2 additions & 2 deletions pylops/signalprocessing/sliding1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from typing import Tuple, Union

from pylops import LinearOperator, aslinearoperator
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
from pylops.utils._internal import _value_or_sized_to_tuple
Expand Down Expand Up @@ -180,7 +180,7 @@ def Sliding1D(
for win_in, win_end in zip(dwin_ins, dwin_ends)
]
)
Sop = aslinearoperator(combining * OOp)
Sop = LinearOperator(combining * OOp)
Sop.dims, Sop.dimsd = (nwins, int(dim[0] // nwins)), dimd
Sop.name = name
return Sop
4 changes: 2 additions & 2 deletions pylops/signalprocessing/sliding2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pylops import LinearOperator, aslinearoperator
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.utils.tapers import taper2d
from pylops.utils.typing import InputDimsLike, NDArray
Expand Down Expand Up @@ -214,7 +214,7 @@ def Sliding2D(
for win_in, win_end in zip(dwin_ins, dwin_ends)
]
)
Sop = aslinearoperator(combining * OOp)
Sop = LinearOperator(combining * OOp)
Sop.dims, Sop.dimsd = (nwins, int(dims[0] // nwins), dims[1]), dimsd
Sop.name = name
return Sop
6 changes: 3 additions & 3 deletions pylops/signalprocessing/sliding3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from typing import Tuple

from pylops import LinearOperator, aslinearoperator
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
from pylops.utils.tapers import taper3d
Expand Down Expand Up @@ -99,7 +99,7 @@ def Sliding3D(
tapertype: str = "hanning",
nproc: int = 1,
name: str = "P",
) -> None:
) -> LinearOperator:
"""3D Sliding transform operator.w
Apply a transform operator ``Op`` repeatedly to patches of the model
Expand Down Expand Up @@ -215,7 +215,7 @@ def Sliding3D(
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
]
)
Sop = aslinearoperator(combining0 * combining1 * OOp)
Sop = LinearOperator(combining0 * combining1 * OOp)
Sop.dims, Sop.dimsd = (
nwins0,
nwins1,
Expand Down
11 changes: 7 additions & 4 deletions pylops/torchoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pylops.utils.typing import TensorTypeLike


class TorchOperator(LinearOperator):
class TorchOperator:
"""Wrap a PyLops operator into a Torch function.
This class can be used to wrap a pylops operator into a
Expand Down Expand Up @@ -63,9 +63,9 @@ def __init__(
raise NotImplementedError(torch_message)
self.device = device
self.devicetorch = devicetorch
super().__init__(
dtype=np.dtype(Op.dtype), dims=Op.dims, dimsd=Op.dims, name=Op.name
)
self.dtype = np.dtype(Op.dtype)
self.dims, self.dimsd = Op.dims, Op.dimsd
self.name = Op.name
# define transpose indices to bring batch to last dimension before applying
# pylops forward and adjoint (this will call matmat and rmatmat)
self.transpf = np.roll(np.arange(2 if flatten else len(self.dims) + 1), -1)
Expand All @@ -82,6 +82,9 @@ def __init__(
)
self.Top = _TorchOperator.apply

def __call__(self, x):
return self.apply(x)

def apply(self, x: TensorTypeLike) -> TensorTypeLike:
"""Apply forward pass to input vector
Expand Down
7 changes: 3 additions & 4 deletions pytests/test_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest
from numpy.testing import assert_array_almost_equal

from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
from pylops.signalprocessing import Patch2D, Patch3D
from pylops.signalprocessing.patch2d import patch2d_design
Expand Down Expand Up @@ -111,7 +110,7 @@ def test_Patch2D(par):
x = np.ones((par["ny"] * nwins[0], par["nt"] * nwins[1]))
y = Pop * x.ravel()

xinv = LinearOperator(Pop) / y
xinv = Pop / y
assert_array_almost_equal(x.ravel(), xinv)


Expand Down Expand Up @@ -145,7 +144,7 @@ def test_Patch2D_scalings(par):
x = np.ones((par["ny"] * nwins[0], par["nt"] * nwins[1]))
y = Pop * x.ravel()

xinv = LinearOperator(Pop) / y
xinv = Pop / y
assert_array_almost_equal(x.ravel(), xinv)


Expand Down Expand Up @@ -189,5 +188,5 @@ def test_Patch3D(par):
x = np.ones((par["ny"] * nwins[0], par["nx"] * nwins[1], par["nt"] * nwins[2]))
y = Pop * x.ravel()

xinv = LinearOperator(Pop) / y
xinv = Pop / y
assert_array_almost_equal(x.ravel(), xinv)
7 changes: 3 additions & 4 deletions pytests/test_sliding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest
from numpy.testing import assert_array_almost_equal

from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
from pylops.signalprocessing import Sliding1D, Sliding2D, Sliding3D
from pylops.signalprocessing.sliding1d import sliding1d_design
Expand Down Expand Up @@ -89,7 +88,7 @@ def test_Sliding1D(par):
x = np.ones(par["ny"] * nwins)
y = Slid * x.ravel()

xinv = LinearOperator(Slid) / y
xinv = Slid / y
assert_array_almost_equal(x.ravel(), xinv)


Expand All @@ -113,7 +112,7 @@ def test_Sliding2D(par):
x = np.ones((par["ny"] * nwins, par["nt"]))
y = Slid * x.ravel()

xinv = LinearOperator(Slid) / y
xinv = Slid / y
assert_array_almost_equal(x.ravel(), xinv)


Expand Down Expand Up @@ -150,5 +149,5 @@ def test_Sliding3D(par):
x = np.ones((par["ny"] * par["nx"] * nwins[0] * nwins[1], par["nt"]))
y = Slid * x.ravel()

xinv = LinearOperator(Slid) / y
xinv = Slid / y
assert_array_almost_equal(x.ravel(), xinv)

0 comments on commit d5c6252

Please sign in to comment.