Skip to content

Commit

Permalink
feature: remove LinearOperator inheritance from TorchOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Feb 25, 2023
1 parent 363873a commit 4428dda
Showing 1 changed file with 15 additions and 24 deletions.
39 changes: 15 additions & 24 deletions pylops/torchoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
"TorchOperator",
]

from typing import Optional, Callable
from typing import Optional

import numpy as np

from pylops import LinearOperator
from pylops.utils import deps, NDArray
from pylops.utils import deps

if deps.torch_enabled:
from pylops._torchoperator import _TorchOperator
Expand All @@ -20,7 +20,7 @@
from pylops.utils.typing import TensorTypeLike


class TorchOperator(LinearOperator):
class TorchOperator:
"""Wrap a PyLops operator into a Torch function.
This class can be used to wrap a pylops operator into a
Expand Down Expand Up @@ -63,33 +63,24 @@ def __init__(
raise NotImplementedError(torch_message)
self.device = device
self.devicetorch = devicetorch
super().__init__(
dtype=np.dtype(Op.dtype), dims=Op.dims, dimsd=Op.dims, name=Op.name
)
self.dtype = np.dtype(Op.dtype)
self.dims, self.dimsd = Op.dims, Op.dimsd
self.name = Op.name
# define transpose indices to bring batch to last dimension before applying
# pylops forward and adjoint (this will call matmat and rmatmat)
self.transpf = np.roll(np.arange(2 if flatten else len(self.dims) + 1), -1)
self.transpb = np.roll(np.arange(2 if flatten else len(self.dims) + 1), 1)
self.Op = Op
self._register_torchop(batch)
self.Top = _TorchOperator.apply

def _register_torchop(self, batch: bool):
# choose _matvec and _rmatvec
self.matvec: Callable
self.rmatvec: Callable
if not batch:
self.matvec = lambda x: self.Op @ x
self.rmatvec = lambda x: self.Op.H @ x
self.matvec = lambda x: Op @ x
self.rmatvec = lambda x: Op.H @ x
else:
self.matvec = lambda x: (self.Op @ x.transpose(self.transpf)).transpose(self.transpb)
self.rmatvec = lambda x: (self.Op.H @ x.transpose(self.transpf)).transpose(self.transpb)

def _matvec(self, x: NDArray) -> NDArray:
return self.matvec(x)

def _rmatvec(self, x: NDArray) -> NDArray:
return self.rmatvec(x)
self.matvec = lambda x: (Op @ x.transpose(self.transpf)).transpose(
self.transpb
)
self.rmatvec = lambda x: (Op.H @ x.transpose(self.transpf)).transpose(
self.transpb
)
self.Top = _TorchOperator.apply

def apply(self, x: TensorTypeLike) -> TensorTypeLike:
"""Apply forward pass to input vector
Expand Down

0 comments on commit 4428dda

Please sign in to comment.