diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index f7b3a7f8..eb047eb6 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -136,6 +136,32 @@ def _do_iteration( regularizer_dict = { k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() } + + if self.ndim == 3 and "registration_model" in self.models: + # Perform registration and compute loss on registered image and displacement field + registered_image, displacement_field = self.do_registration(data, output_image) + + # If DL-based model calculate loss + if len(list(self.models["registration_model"].parameters())) > 0: + shape = data["reference_image"].shape + loss_dict = self.compute_loss_on_data( + loss_dict, + loss_fns, + data, + output_image=registered_image, + target_image=( + data["reference_image"] + if shape == registered_image.shape + else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:])))) + ), + ) + loss_dict = self.compute_loss_on_data( + loss_dict, + loss_fns, + data, + output_displacement_field=displacement_field, + target_displacement_field=data["displacement_field"], + ) loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace) regularizer_dict = self.compute_loss_on_data( regularizer_dict, regularizer_fns, data, output_image, output_kspace @@ -150,7 +176,11 @@ def _do_iteration( regularizer_dict = detach_dict(regularizer_dict) return DoIterationOutput( - output_image=output_image, + output_image=( + (output_image, registered_image) + if (self.ndim == 3 and "registration_model" in self.models) + else output_image + ), sensitivity_map=data["sensitivity_map"], sampling_mask=data["sampling_mask"], data_dict={**loss_dict, **regularizer_dict},