Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Dec 28, 2024
1 parent 1133d3a commit 594f139
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
29 changes: 15 additions & 14 deletions src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
"""Spiral trajectory calculator."""

import torch

from mrpro.data import SpatialDimension
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator


class KTrajectorySpiral(KTrajectoryCalculator):
class KTrajectorySpiral2D(KTrajectoryCalculator):
"""A Spiral variable density trajectory.
Implements the spiral trajectory calculation as described in
Simple Analytic Variable Density Spiral Design by Kim et al., MRM 2003"""
Simple Analytic Variable Density Spiral Design by Kim et al., MRM 2003
"""

def __init__(
self,
max_gradient: float,
max_slewrate: float,
fov: SpatialDimension | float,
angle: float,
acceleration_per_interleave: float = 1.0,
fov: SpatialDimension | float = 0.5,
angle: float = 2.39996,
acceleration_per_interleave: float = 20.0,
density_factor: float = 1.0,
gamma: float = 42577478,
max_gradient: float = 0.1,
max_slewrate: float = 100,
):
"""Create a spiral trajectory calculator.
Expand Down Expand Up @@ -105,9 +108,7 @@ def __call__(
) # eq. 10
max_angle = 2 * torch.pi * n_turns
end_time_amplitude = (lam * max_angle) / (self.max_gradient_gamma * (self.density_factor + 1)) # eq. 5, Tes
end_time_slew = torch.sqrt(lam * max_angle**2 / (self.max_slewrate_gamma)) / (
self.density_factor / 2 + 1
) # eq. 8, Tea
end_time_slew = (lam / self.max_slewrate_gamma) ** 0.5 * max_angle / (self.density_factor / 2 + 1) # eq. 8, Tea

transition_time_slew_to_amplitude = (
end_time_slew ** ((self.density_factor + 1) / (self.density_factor / 2 + 1))
Expand All @@ -120,7 +121,7 @@ def __call__(
end_time = end_time_amplitude if has_amplitude_phase else end_time_slew

def tau(t: torch.Tensor) -> torch.Tensor:
"""Normalized time function."""
"""Convert to normalized time."""
# eq. 11
slew_phase = (t / end_time_slew) ** (1 / (self.density_factor / 2 + 1))
slew_phase = slew_phase * ((t >= 0) * (t <= transition_time_slew_to_amplitude))
Expand All @@ -133,7 +134,7 @@ def tau(t: torch.Tensor) -> torch.Tensor:
t = torch.linspace(0, end_time, n_k0)
tau_t = tau(t)
k = lam * tau_t**self.density_factor * torch.exp(1j * max_angle * tau_t) # eq. 2
phase_rotation = torch.exp(self.angle * k1_idx)
k = k[None, :] * phase_rotation[:, None]
phase_rotation = torch.exp(1j * self.angle * k1_idx)
k = k[None, :] * phase_rotation[..., None]
trajectory = KTrajectory(kx=k.real, ky=k.imag, kz=torch.zeros_like(k.real))
return trajectory
4 changes: 3 additions & 1 deletion src/mrpro/data/traj_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd
from mrpro.data.traj_calculators.KTrajectoryPulseq import KTrajectoryPulseq
from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian
from mrpro.data.traj_calculators.KTrajectorySpiral2D import KTrajectorySpiral2D
__all__ = [
"KTrajectoryCalculator",
"KTrajectoryCartesian",
"KTrajectoryIsmrmrd",
"KTrajectoryPulseq",
"KTrajectoryRadial2D",
"KTrajectoryRpe",
"KTrajectorySpiral2D",
"KTrajectorySunflowerGoldenRpe"
]
]
12 changes: 12 additions & 0 deletions tests/data/test_traj_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
KTrajectoryPulseq,
KTrajectoryRadial2D,
KTrajectoryRpe,
KTrajectorySpiral2D,
KTrajectorySunflowerGoldenRpe,
)

from tests.data import IsmrmrdRawTestData, PulseqRadialTestSeq
from mrpro.data import SpatialDimension


def test_KTrajectoryRadial2D():
Expand Down Expand Up @@ -210,3 +212,13 @@ def test_KTrajectoryPulseq(pulseq_example_rad_seq):

torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3)
torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3)


def test_KTrajectorySpiral():
trajectory_calculator = KTrajectorySpiral2D()
trajectory = trajectory_calculator(
n_k0=1024, k1_idx=torch.arange(4)[None, None, :], encoding_matrix=SpatialDimension(1, 256, 256)
)
assert trajectory.kz.shape == (1, 1, 1, 1)
assert trajectory.ky.shape == (1, 1, 4, 1024)
assert trajectory.kx.shape == (1, 1, 4, 1024)

0 comments on commit 594f139

Please sign in to comment.