Skip to content

Commit

Permalink
Update for registration
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Sep 21, 2024
1 parent 8a7d731 commit 6f52a4b
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,32 @@ def _do_iteration(
regularizer_dict = {
k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
}

if self.ndim == 3 and "registration_model" in self.models:
# Perform registration and compute loss on registered image and displacement field
registered_image, displacement_field = self.do_registration(data, output_image)

# If DL-based model calculate loss
if len(list(self.models["registration_model"].parameters())) > 0:
shape = data["reference_image"].shape
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_image=registered_image,
target_image=(
data["reference_image"]
if shape == registered_image.shape
else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:]))))
),
)
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_displacement_field=displacement_field,
target_displacement_field=data["displacement_field"],
)
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_image, output_kspace
Expand All @@ -150,7 +176,11 @@ def _do_iteration(
regularizer_dict = detach_dict(regularizer_dict)

return DoIterationOutput(
output_image=output_image,
output_image=(
(output_image, registered_image)
if (self.ndim == 3 and "registration_model" in self.models)
else output_image
),
sensitivity_map=data["sensitivity_map"],
sampling_mask=data["sampling_mask"],
data_dict={**loss_dict, **regularizer_dict},
Expand Down

0 comments on commit 6f52a4b

Please sign in to comment.