Skip to content

Commit

Permalink
Refactor AcqInfo and Header information
Browse files Browse the repository at this point in the history
ghstack-source-id: 25b1329dab363823f7d4aa2322bc172a976f93b9
ghstack-comment-id: 2501006700
Pull Request resolved: #560
  • Loading branch information
fzimmermann89 committed Dec 28, 2024
1 parent 07811a9 commit 7b963aa
Show file tree
Hide file tree
Showing 19 changed files with 689 additions and 579 deletions.
256 changes: 166 additions & 90 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Acquisition information dataclass."""

from collections.abc import Sequence
from dataclasses import dataclass
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Literal, TypeAlias, overload

import ismrmrd
import numpy as np
Expand All @@ -26,140 +27,199 @@ def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[
return field


_convert_time_stamp_type: TypeAlias = Callable[
[
torch.Tensor,
Literal[
'acquisition_time_stamp', 'physiology_time_stamp_1', 'physiology_time_stamp_2', 'physiology_time_stamp_3'
],
],
torch.Tensor,
]


def convert_time_stamp_siemens(
timestamp: torch.Tensor,
_: str,
) -> torch.Tensor:
"""Convert Siemens time stamp to seconds."""
return timestamp.double() * 2.5e-3


def _int_factory() -> torch.Tensor:
# TODO: check dtype
return torch.zeros(1, 1, 1, 1, dtype=torch.int64)


def _float_factory() -> torch.Tensor:
return torch.zeros(1, 1, 1, 1, dtype=torch.float)


def _position_factory() -> SpatialDimension[torch.Tensor]:
return SpatialDimension(
torch.zeros(1, 1, 1, 1, dtype=torch.float),
torch.zeros(1, 1, 1, 1, dtype=torch.float),
torch.zeros(1, 1, 1, 1, dtype=torch.float),
)


@dataclass(slots=True)
class AcqIdx(MoveDataMixin):
"""Acquisition index for each readout."""

k1: torch.Tensor
k1: torch.Tensor = field(default_factory=_int_factory)
"""First phase encoding."""

k2: torch.Tensor
k2: torch.Tensor = field(default_factory=_int_factory)
"""Second phase encoding."""

average: torch.Tensor
average: torch.Tensor = field(default_factory=_int_factory)
"""Signal average."""

slice: torch.Tensor
slice: torch.Tensor = field(default_factory=_int_factory)
"""Slice number (multi-slice 2D)."""

contrast: torch.Tensor
contrast: torch.Tensor = field(default_factory=_int_factory)
"""Echo number in multi-echo."""

phase: torch.Tensor
phase: torch.Tensor = field(default_factory=_int_factory)
"""Cardiac phase."""

repetition: torch.Tensor
repetition: torch.Tensor = field(default_factory=_int_factory)
"""Counter in repeated/dynamic acquisitions."""

set: torch.Tensor
set: torch.Tensor = field(default_factory=_int_factory)
"""Sets of different preparation, e.g. flow encoding, diffusion weighting."""

segment: torch.Tensor
segment: torch.Tensor = field(default_factory=_int_factory)
"""Counter for segmented acquisitions."""

user0: torch.Tensor
user0: torch.Tensor = field(default_factory=_int_factory)
"""User index 0."""

user1: torch.Tensor
user1: torch.Tensor = field(default_factory=_int_factory)
"""User index 1."""

user2: torch.Tensor
user2: torch.Tensor = field(default_factory=_int_factory)
"""User index 2."""

user3: torch.Tensor
user3: torch.Tensor = field(default_factory=_int_factory)
"""User index 3."""

user4: torch.Tensor
user4: torch.Tensor = field(default_factory=_int_factory)
"""User index 4."""

user5: torch.Tensor
user5: torch.Tensor = field(default_factory=_int_factory)
"""User index 5."""

user6: torch.Tensor
user6: torch.Tensor = field(default_factory=_int_factory)
"""User index 6."""

user7: torch.Tensor
user7: torch.Tensor = field(default_factory=_int_factory)
"""User index 7."""


@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""

idx: AcqIdx
"""Indices describing acquisitions (i.e. readouts)."""

acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)"""
class UserValues(MoveDataMixin):
"""User Values used in AcqInfo."""

float1: torch.Tensor = field(default_factory=_float_factory)
float2: torch.Tensor = field(default_factory=_float_factory)
float3: torch.Tensor = field(default_factory=_float_factory)
float4: torch.Tensor = field(default_factory=_float_factory)
float5: torch.Tensor = field(default_factory=_float_factory)
float6: torch.Tensor = field(default_factory=_float_factory)
float7: torch.Tensor = field(default_factory=_float_factory)
float8: torch.Tensor = field(default_factory=_float_factory)
int1: torch.Tensor = field(default_factory=_int_factory)
int2: torch.Tensor = field(default_factory=_int_factory)
int3: torch.Tensor = field(default_factory=_int_factory)
int4: torch.Tensor = field(default_factory=_int_factory)
int5: torch.Tensor = field(default_factory=_int_factory)
int6: torch.Tensor = field(default_factory=_int_factory)
int7: torch.Tensor = field(default_factory=_int_factory)
int8: torch.Tensor = field(default_factory=_int_factory)

active_channels: torch.Tensor
"""Number of active receiver coil elements."""

available_channels: torch.Tensor
"""Number of available receiver coil elements."""

center_sample: torch.Tensor
"""Index of the readout sample corresponding to k-space center (zero indexed)."""
@dataclass(slots=True)
class PhysiologyTimestamps:
"""Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units."""

channel_mask: torch.Tensor
"""Bit mask indicating active coils (64*16 = 1024 bits)."""
timestamp1: torch.Tensor = field(default_factory=_float_factory)
timestamp2: torch.Tensor = field(default_factory=_float_factory)
timestamp3: torch.Tensor = field(default_factory=_float_factory)

discard_post: torch.Tensor
"""Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events)."""

discard_pre: torch.Tensor
"""Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)"""
@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""

encoding_space_ref: torch.Tensor
"""Indexed reference to the encoding spaces enumerated in the MRD (xml) header."""
idx: AcqIdx = field(default_factory=AcqIdx)
"""Indices describing acquisitions (i.e. readouts)."""

flags: torch.Tensor
acquisition_time_stamp: torch.Tensor = field(default_factory=_float_factory)
"""Clock time stamp. Usually in seconds (Siemens: seconds since midnight)"""
# TODO: check dtype
flags: torch.Tensor = field(default_factory=_int_factory)
"""A bit mask of common attributes applicable to individual acquisition readouts."""

measurement_uid: torch.Tensor
"""Unique ID corresponding to the readout."""

number_of_samples: torch.Tensor
"""Number of sample points per readout (readouts may have different number of sample points)."""

orientation: Rotation
orientation: Rotation = field(default_factory=lambda: Rotation.identity((1, 1, 1, 1)))
"""Rotation describing the orientation of the readout, phase and slice encoding direction."""

patient_table_position: SpatialDimension[torch.Tensor]
patient_table_position: SpatialDimension[torch.Tensor] = field(default_factory=_position_factory)
"""Offset position of the patient table, in LPS coordinates [m]."""

physiology_time_stamp: torch.Tensor
physiology_time_stamps: PhysiologyTimestamps = field(default_factory=PhysiologyTimestamps)
"""Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units"""

position: SpatialDimension[torch.Tensor]
position: SpatialDimension[torch.Tensor] = field(default_factory=_position_factory)
"""Center of the excited volume, in LPS coordinates relative to isocenter [m]."""

sample_time_us: torch.Tensor
sample_time_us: torch.Tensor = field(default_factory=_float_factory)
"""Readout bandwidth, as time between samples [us]."""

scan_counter: torch.Tensor
"""Zero-indexed incrementing counter for readouts."""

trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""

user_float: torch.Tensor
"""User-defined float parameters."""
user: UserValues = field(default_factory=UserValues)
"""User defined float or int values"""

user_int: torch.Tensor
"""User-defined int parameters."""

version: torch.Tensor
"""Major version number."""
@overload
@classmethod
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: None,
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> Self: ...

@overload
@classmethod
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: Sequence[str],
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> tuple[Self, tuple[torch.Tensor, ...]]: ...

@classmethod
def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self:
def from_ismrmrd_acquisitions(
cls,
acquisitions: Sequence[ismrmrd.acquisition.Acquisition],
*,
additional_fields: Sequence[str] | None = None,
convert_time_stamp: _convert_time_stamp_type = convert_time_stamp_siemens,
) -> Self | tuple[Self, tuple[torch.Tensor, ...]]:
"""Read the header of a list of acquisition and store information.
Parameters
----------
acquisitions:
acquisitions
list of ismrmrd acquisistions to read from. Needs at least one acquisition.
additional_fields
if supplied, additional information from fields with these names will be extracted from the
ismrmrd acquisitions and returned as tensors.
convert_time_stamp
function used to convert the raw time stamps to seconds.
"""
# Idea: create array of structs, then a struct of arrays,
# convert it into tensors to store in our dataclass.
Expand All @@ -169,9 +229,9 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition])
raise ValueError('Acquisition list must not be empty.')

# Creating the dtype first and casting to bytes
# is a workaround for a bug in cpython > 3.12 causing a warning
# is np.array(AcquisitionHeader) is called directly.
# also, this needs to check the dtyoe only once.
# is a workaround for a bug in cpython causing a warning
# if np.array(AcquisitionHeader) is called directly.
# also, this needs to check the dtype only once.
acquisition_head_dtype = np.dtype(ismrmrd.AcquisitionHeader)
headers = np.frombuffer(
np.array([memoryview(a._head).cast('B') for a in acquisitions]),
Expand Down Expand Up @@ -228,33 +288,49 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
user6=tensor(idx['user'][:, 6]),
user7=tensor(idx['user'][:, 7]),
)

user = UserValues(
tensor_2d(headers['user_float'][:, 0]),
tensor_2d(headers['user_float'][:, 1]),
tensor_2d(headers['user_float'][:, 2]),
tensor_2d(headers['user_float'][:, 3]),
tensor_2d(headers['user_float'][:, 4]),
tensor_2d(headers['user_float'][:, 5]),
tensor_2d(headers['user_float'][:, 6]),
tensor_2d(headers['user_float'][:, 7]),
tensor_2d(headers['user_int'][:, 0]),
tensor_2d(headers['user_int'][:, 1]),
tensor_2d(headers['user_int'][:, 2]),
tensor_2d(headers['user_int'][:, 3]),
tensor_2d(headers['user_int'][:, 4]),
tensor_2d(headers['user_int'][:, 5]),
tensor_2d(headers['user_int'][:, 6]),
tensor_2d(headers['user_int'][:, 7]),
)
physiology_time_stamps = PhysiologyTimestamps(
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 0]), 'physiology_time_stamp_1'),
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 1]), 'physiology_time_stamp_2'),
convert_time_stamp(tensor_2d(headers['physiology_time_stamp'][:, 2]), 'physiology_time_stamp_3'),
)
acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
active_channels=tensor_2d(headers['active_channels']),
available_channels=tensor_2d(headers['available_channels']),
center_sample=tensor_2d(headers['center_sample']),
channel_mask=tensor_2d(headers['channel_mask']),
discard_post=tensor_2d(headers['discard_post']),
discard_pre=tensor_2d(headers['discard_pre']),
encoding_space_ref=tensor_2d(headers['encoding_space_ref']),
acquisition_time_stamp=convert_time_stamp(
tensor_2d(headers['acquisition_time_stamp']), 'acquisition_time_stamp'
),
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
orientation=Rotation.from_directions(
spatialdimension_2d(headers['slice_dir']),
spatialdimension_2d(headers['phase_dir']),
spatialdimension_2d(headers['read_dir']),
),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
user_float=tensor_2d(headers['user_float']),
user_int=tensor_2d(headers['user_int']),
version=tensor_2d(headers['version']),
user=user,
physiology_time_stamps=physiology_time_stamps,
)
return acq_info

if additional_fields is None:
return acq_info
else:
additional_values = tuple(tensor_2d(headers[field]) for field in additional_fields)
return acq_info, additional_values
Loading

0 comments on commit 7b963aa

Please sign in to comment.