From 1eebe50e2879225d4a756288bae1a4eace1125f0 Mon Sep 17 00:00:00 2001 From: Stef-Martin Date: Thu, 14 Nov 2024 14:18:44 +0100 Subject: [PATCH 1/3] Repr FourierOp WIP --- src/mrpro/operators/FastFourierOp.py | 5 +++++ src/mrpro/operators/FourierOp.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 4ffe7e3a6..6163d0080 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -139,3 +139,8 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: dim=self._dim, ), ) + + def __repr__(self): + """Representation method for FastFourierOperator.""" + out = f'Dimension along which FFT/NUFFT is applied: {self._dim!s} \n' + return out diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index cacdda1dc..753b796c0 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -125,6 +125,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._fwd_nufft_op = None self._adj_nufft_op = None self._kshape = traj.broadcasted_shape + self._traj = traj @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -223,3 +224,18 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.permute(*unpermute) return (x,) + + def __repr__(self): + """Representation method for Fourier Operator.""" + device = str(self._traj.device) + if self._nufft_dims: + dims = self._nufft_dims + else: + dims = self._fft_dims + out = ( + f'{type(self).__name__} on device: {device}\n' + f'{self._traj}\n' + f'Dimension along which FFT/NUFFT is applied: {dims}\n' + f'Dimension which is ignored in FFT/NUFFT: {self._ignore_dims}\n' + ) + return out From 4a69dd21c336b490e12d70cb3c6ed7e9b14ab587 Mon Sep 17 00:00:00 2001 From: Stef-Martin Date: Thu, 14 Nov 2024 16:20:40 +0100 Subject: [PATCH 2/3] Account for both FFT and NUFFT --- src/mrpro/operators/CartesianSamplingOp.py | 9 +++++++++ src/mrpro/operators/FastFourierOp.py | 2 +- src/mrpro/operators/FourierOp.py | 23 +++++++++++++--------- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 07f8aba65..bb280183e 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -212,3 +212,12 @@ def _broadcast_and_scatter_along_last_dim( ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) return data_scattered + + def __repr__(self): + """Representation method for CartesianSamplingOperator.""" + out = ( + f'Needs indexing: {self._needs_indexing}\n' + f'Sorted grid shape: {self._sorted_grid_shape}\n' + f'Inside encoding matrix index: {self._inside_encoding_matrix_idx}' + ) + return out diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 6163d0080..163cd32c1 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -142,5 +142,5 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: def __repr__(self): """Representation method for FastFourierOperator.""" - out = f'Dimension along which FFT/NUFFT is applied: {self._dim!s} \n' + out = f'Dimension along which FFT is applied: {list(self._dim)!s}' return out diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 753b796c0..56158b06d 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -227,15 +227,20 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: def __repr__(self): """Representation method for Fourier Operator.""" - device = str(self._traj.device) if self._nufft_dims: - dims = self._nufft_dims + string = ( + f'Dimension along which NUFFT is applied: {self._nufft_dims}\n' + f'Dimension which is ignored in NUFFT: {self._ignore_dims!s}\n' + ) else: - dims = self._fft_dims - out = ( - f'{type(self).__name__} on device: {device}\n' - f'{self._traj}\n' - f'Dimension along which FFT/NUFFT is applied: {dims}\n' - f'Dimension which is ignored in FFT/NUFFT: {self._ignore_dims}\n' - ) + string = '' + if self._fft_dims: + string += ( + f'{self._fast_fourier_op}\n' + f'Dimension which is ignored in FFT: {self._ignore_dims!s}\n' + f'{self._cart_sampling_op}\n' + ) + device = str(self._traj.device) + traj = self._traj + out = f'{type(self).__name__} on device: {device}\n' f'{traj}\n' f'{string}' return out From 86efe0b4795917fae5bc3b344ef8543816676386 Mon Sep 17 00:00:00 2001 From: Stef-Martin Date: Wed, 20 Nov 2024 14:45:06 +0100 Subject: [PATCH 3/3] Solving merge issues, repr for adjoint and other ops still to do --- src/mrpro/operators/CartesianSamplingOp.py | 13 +++++--- src/mrpro/operators/FourierOp.py | 39 ++++++++++++---------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index f5476e690..bf63c4b50 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -89,8 +89,6 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> else: self._inside_encoding_matrix_idx: torch.Tensor | None = None - self.register_buffer('_fft_idx', kidx) - # we can skip the indexing if the data is already sorted self._needs_indexing = ( not torch.all(torch.diff(kidx) == 1) @@ -98,6 +96,11 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> or self._inside_encoding_matrix_idx is not None ) + if self._needs_indexing: + self.register_buffer('_fft_idx', kidx) + else: + self._fft_idx: torch.Tensor + self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape @@ -222,10 +225,12 @@ def gram(self) -> 'CartesianSamplingGramOp': Gram operator for this Cartesian Sampling Operator """ return CartesianSamplingGramOp(self) - - def __repr__(self): + + def __repr__(self) -> str: """Representation method for CartesianSamplingOperator.""" + device = self._fft_idx.device if self._fft_idx is not None else 'none' out = ( + f'{type(self).__name__} on device: {device}\n' f'Needs indexing: {self._needs_indexing}\n' f'Sorted grid shape: {self._sorted_grid_shape}\n' f'Inside encoding matrix index: {self._inside_encoding_matrix_idx}' diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 039a94b32..2a1dc6625 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -126,7 +126,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._fwd_nufft_op = None self._adj_nufft_op = None self._kshape = traj.broadcasted_shape - self._traj = traj + self._traj_out = repr(traj) @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -225,30 +225,33 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.permute(*unpermute) return (x,) - + @property def gram(self) -> LinearOperator: """Return the gram operator.""" return FourierGramOp(self) - - def __repr__(self): + + def __repr__(self) -> str: """Representation method for Fourier Operator.""" + ignore = f'Dimension(s) which are ignored: {self._ignore_dims!s}\n' + + string = '' + device_omega = None + device_cart = None + if self._nufft_dims: - string = ( - f'Dimension along which NUFFT is applied: {self._nufft_dims}\n' - f'Dimension which is ignored in NUFFT: {self._ignore_dims!s}\n' - ) - else: - string = '' + string += f'Dimension(s) along which NUFFT is applied: {self._nufft_dims}\n{ignore}' + device_omega = self._omega.device if self._omega is not None else None if self._fft_dims: - string += ( - f'{self._fast_fourier_op}\n' - f'Dimension which is ignored in FFT: {self._ignore_dims!s}\n' - f'{self._cart_sampling_op}\n' - ) - device = str(self._traj.device) - traj = self._traj - out = f'{type(self).__name__} on device: {device}\n' f'{traj}\n' f'{string}' + string += f'{self._fast_fourier_op}\n{ignore}\n{self._cart_sampling_op}\n' + device_cart = self._cart_sampling_op._fft_idx.device if self._cart_sampling_op is not None else None + + if device_omega and device_cart: + device = device_omega if device_omega == device_cart else 'Different devices' + else: + device = device_omega or device_cart or 'None' + + out = f'{type(self).__name__} on device: {device}\n' f'{string}\n' f'{self._traj_out}' return out