Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

__repr__ for operators #533

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,18 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
else:
self._inside_encoding_matrix_idx: torch.Tensor | None = None

self.register_buffer('_fft_idx', kidx)

# we can skip the indexing if the data is already sorted
self._needs_indexing = (
not torch.all(torch.diff(kidx) == 1)
or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
or self._inside_encoding_matrix_idx is not None
)

if self._needs_indexing:
self.register_buffer('_fft_idx', kidx)
else:
self._fft_idx: torch.Tensor

self._trajectory_shape = traj.broadcasted_shape
self._sorted_grid_shape = sorted_grid_shape

Expand Down Expand Up @@ -223,6 +226,17 @@ def gram(self) -> 'CartesianSamplingGramOp':
"""
return CartesianSamplingGramOp(self)

def __repr__(self) -> str:
    """Return a multi-line summary of the CartesianSamplingOperator.

    Reports the device of the FFT index buffer (or ``'none'`` when the
    operator skips indexing), whether indexing is needed, the sorted grid
    shape, and the inside-encoding-matrix index tensor.
    """
    # `_fft_idx` is only registered as a buffer when indexing is needed
    # (see __init__); otherwise the bare annotation leaves the attribute
    # nonexistent on the nn.Module, so `self._fft_idx` would raise
    # AttributeError. Use getattr with a None default instead.
    fft_idx = getattr(self, '_fft_idx', None)
    device = fft_idx.device if fft_idx is not None else 'none'
    out = (
        f'{type(self).__name__} on device: {device}\n'
        f'Needs indexing: {self._needs_indexing}\n'
        f'Sorted grid shape: {self._sorted_grid_shape}\n'
        f'Inside encoding matrix index: {self._inside_encoding_matrix_idx}'
    )
    return out


class CartesianSamplingGramOp(LinearOperator):
"""Gram operator for Cartesian Sampling Operator.
Expand Down
5 changes: 5 additions & 0 deletions src/mrpro/operators/FastFourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
dim=self._dim,
),
)

def __repr__(self):
    """Return a short description of the dimensions the FFT acts on."""
    fft_dims = list(self._dim)
    description = 'Dimension along which FFT is applied: ' + str(fft_dims)
    return description
24 changes: 24 additions & 0 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
self._fwd_nufft_op = None
self._adj_nufft_op = None
self._kshape = traj.broadcasted_shape
self._traj_out = repr(traj)

@classmethod
def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self:
Expand Down Expand Up @@ -230,6 +231,29 @@ def gram(self) -> LinearOperator:
"""Return the gram operator."""
return FourierGramOp(self)

def __repr__(self) -> str:
    """Return a multi-line summary of the Fourier Operator.

    Describes the NUFFT and/or FFT sub-operators (whichever are active),
    the device the operator's buffers live on (or ``'Different devices'``
    when the NUFFT and Cartesian-sampling buffers disagree), and the
    trajectory representation captured at construction time.
    """
    ignore = f'Dimension(s) which are ignored: {self._ignore_dims!s}\n'

    string = ''
    device_omega = None
    device_cart = None

    if self._nufft_dims:
        string += f'Dimension(s) along which NUFFT is applied: {self._nufft_dims}\n{ignore}'
        device_omega = self._omega.device if self._omega is not None else None
    if self._fft_dims:
        string += f'{self._fast_fourier_op}\n{ignore}\n{self._cart_sampling_op}\n'
        # `_fft_idx` is only registered on the sampling op when it needs
        # indexing (see CartesianSamplingOp.__init__); a direct attribute
        # access would raise AttributeError otherwise, so use getattr.
        fft_idx = (
            getattr(self._cart_sampling_op, '_fft_idx', None)
            if self._cart_sampling_op is not None
            else None
        )
        device_cart = fft_idx.device if fft_idx is not None else None

    if device_omega and device_cart:
        device = device_omega if device_omega == device_cart else 'Different devices'
    else:
        device = device_omega or device_cart or 'None'

    out = f'{type(self).__name__} on device: {device}\n' f'{string}\n' f'{self._traj_out}'
    return out


def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor:
"""Enforce hermitian symmetry on the kernel. Returns only half of the kernel."""
Expand Down