Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
fzimmermann89 committed Jan 3, 2025
2 parents 70f41a4 + caf7eb0 commit 0b52848
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 376 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
hooks:
- id: check-added-large-files
- id: check-merge-conflict
exclude_types: [rst]
- id: check-yaml
- id: check-toml
- id: check-json
Expand Down
288 changes: 281 additions & 7 deletions src/mrpro/data/KData.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
"""MR raw data / k-space data class."""

import copy
import dataclasses
import datetime
import warnings
from collections.abc import Callable, Sequence
from pathlib import Path
from types import EllipsisType
from typing import Literal, cast

import h5py
import ismrmrd
import numpy as np
import torch
from einops import rearrange
from typing_extensions import Self
from einops import rearrange, repeat
from typing_extensions import Self, TypeVar

from mrpro.data._kdata.KDataRearrangeMixin import KDataRearrangeMixin
from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin
from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin
from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin
from mrpro.data.acq_filters import has_n_coils, is_image_acquisition
from mrpro.data.AcqInfo import AcqInfo, convert_time_stamp_siemens, rearrange_acq_info_fields
from mrpro.data.EncodingLimits import EncodingLimits
Expand All @@ -30,6 +28,8 @@
from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator
from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd

RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation)

KDIM_SORT_LABELS = (
'k1',
'k2',
Expand Down Expand Up @@ -64,7 +64,9 @@


@dataclasses.dataclass(slots=True, frozen=True)
class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin, MoveDataMixin):
class KData(
MoveDataMixin,
):
"""MR raw data / k-space data class."""

header: KHeader
Expand Down Expand Up @@ -424,3 +426,275 @@ def compress_coils(
kdata_coil_compressed_flattened, [*kdata_permuted.shape[:-1], n_compressed_coils]
).permute(*np.argsort(permute_order))
return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone())

def rearrange_k2_k1_into_k1(self: Self) -> Self:
"""Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...).
Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
Returns
-------
K-space data (other coils 1 (k2 k1) k0)
"""
# Rearrange data
kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0')

# Rearrange trajectory
ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0')

# Create new header with correct shape
kheader = copy.deepcopy(self.header)

# Update shape of acquisition info index
kheader.acq_info.apply_(
lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...')
)

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def remove_readout_os(self: Self) -> Self:
"""Remove any oversampling along the readout (k0) direction [GAD]_.
Returns a copy of the data.
Parameters
----------
kdata
K-space data
Returns
-------
Copy of K-space data with oversampling removed.
Raises
------
ValueError
If the recon matrix along x is larger than the encoding matrix along x.
References
----------
.. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python
"""
from mrpro.operators.FastFourierOp import FastFourierOp

# Ratio of k0/x between encoded and recon space
x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x
if x_ratio == 1:
# If the encoded and recon space is the same we don't have to do anything
return self
elif x_ratio > 1:
raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.')

# Starting and end point of image after removing oversampling
start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2
end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x

def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor:
# returns a cropped copy
return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone()

# Transform to image space along readout, crop to reconstruction matrix size and transform back
fourier_k0_op = FastFourierOp(dim=(-1,))
(cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data)))

# Adapt trajectory
ks = [self.traj.kz, self.traj.ky, self.traj.kx]
# only cropped ks that are not broadcasted/singleton along k0
cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks]
cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2])

# Adapt header parameters
header = copy.deepcopy(self.header)
header.encoding_matrix.x = cropped_data.shape[-1]

return type(self)(header, cropped_data, cropped_traj)

def select_other_subset(
self: Self,
subset_idx: torch.Tensor,
subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Select a subset from the other dimension of KData.
Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
subset_idx
Index which elements of the other subset to use, e.g. phase 0,1,2 and 5
subset_label
Name of the other label, e.g. phase
Returns
-------
K-space data (other_subset coils k2 k1 k0)
Raises
------
ValueError
If the subset indices are not available in the data
"""
# Make a copy such that the original kdata.header remains the same
kheader = copy.deepcopy(self.header)
ktraj = self.traj.as_tensor()

# Verify that the subset_idx is available
label_idx = getattr(kheader.acq_info.idx, subset_label)
if not all(el in torch.unique(label_idx) for el in subset_idx):
raise ValueError('Subset indices are outside of the available index range')

# Find subset index in acq_info index
other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0)

# Adapt header
kheader.acq_info.apply_(
lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field
)

# Select data
kdat = self.data[other_idx, ...]

# Select ktraj
if ktraj.shape[1] > 1:
ktraj = ktraj[:, other_idx, ...]

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def _split_k2_or_k1_into_other(
self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
split_dir: Literal['k2', 'k1'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.
Parameters
----------
split_idx
2D index describing the k2 or k1 points in each block to be moved to the other dimension
(other_split, k1_per_split) or (other_split, k2_per_split)
other_label
Label of other dimension, e.g. repetition, phase
split_dir
Dimension to split, either 'k1' or 'k2'
Returns
-------
K-space data with new shape
((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0)
Raises
------
ValueError
Already existing "other_label" can only be of length 1
"""
# Number of other
n_other = split_idx.shape[0]

# Set-up splitting
if split_dir == 'k1':
# Split along k1 dimensions
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, :, split_idx, :]

def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
# cast due to https://github.com/python/mypy/issues/10817
return cast(RotationOrTensor, acq_info[:, :, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0'
rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0'
rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...'

elif split_dir == 'k2':
# Split along k2 dimensions
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, split_idx, :, :]

def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
return cast(RotationOrTensor, acq_info[:, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0'
rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0'
rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...'

else:
raise ValueError('split_dir has to be "k1" or "k2"')

# Split data
kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data)

# First we need to make sure the other dimension is the same as data then we can split the trajectory
ktraj = self.traj.as_tensor()
# Verify that other dimension of trajectory is 1 or matches data
if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]:
raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})')
elif ktraj.shape[1] == 1 and self.data.shape[0] > 1:
ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0])
ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj)

# Create new header with correct shape
kheader = self.header.clone()

# Update shape of acquisition info index
kheader.acq_info.apply_(
lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info)
if isinstance(field, Rotation | torch.Tensor)
else field
)

# acq_info for new other dimensions
acq_info_other_split = repeat(
torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2]
)
setattr(kheader.acq_info.idx, other_label, acq_info_other_split)

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def split_k1_into_other(
self: Self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.
Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
split_idx
2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split)
other_label
Label of other dimension, e.g. repetition, phase
Returns
-------
K-space data with new shape ((other other_split) coils k2 k1_per_split k0)
"""
return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1')

def split_k2_into_other(
self: Self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.
Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
split_idx
2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split)
other_label
Label of other dimension, e.g. repetition, phase
Returns
-------
K-space data with new shape ((other other_split) coils k2_per_split k1 k0)
"""
return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2')
41 changes: 0 additions & 41 deletions src/mrpro/data/_kdata/KDataProtocol.py

This file was deleted.

Loading

0 comments on commit 0b52848

Please sign in to comment.