GradICON and ICON loss without using network wrapper #79

Open
iyerkrithika21 opened this issue Aug 22, 2024 · 2 comments

I am interested in calculating the GradICON and ICON losses for a network I am working on. It does not follow the typical registration workflow of taking two input images and predicting deformation fields, so I cannot use InverseConsistentNet, GradientICON, or GradientICONSparse. Unfortunately, I cannot share my code.

Is it possible to provide a standalone loss function that calculates these losses, given two deformation field tensors (equivalent to your phi_AB.vectorfield and phi_BA.vectorfield variables)?

Also, could you briefly explain the format of the GradICON output? I am confused about the use of phi_AB as a function vs. the actual deformation field in image coordinates.
Thank you!

HastingsGreer self-assigned this Aug 29, 2024
iyerkrithika21 (Author) commented Sep 2, 2024

@HastingsGreer: Here is my attempt at it; could you check whether the logic makes sense and matches the original implementation? I have used the same helper functions.

import numpy as np
import torch

def create_identity_map(image_size, spacing=None):
    """
    Create an identity map for a given image size.
    
    :param image_size: Tuple specifying the size of the image (X, Y, Z) for 3D, (X, Y) for 2D, or (X,) for 1D.
    :param spacing: Tuple specifying the spacing between elements in each dimension. Defaults to (1, 1, 1).
    :return: Identity map as a torch tensor of shape [1, dim, X(, Y, Z)].
    """
    # Set default spacing if not provided
    if spacing is None:
        spacing = tuple([1.0] * len(image_size))
    
    # Generate a grid of coordinates
    coordinates = np.mgrid[[slice(0, s) for s in image_size]]
    
    # Convert the coordinates to float and apply spacing
    identity_map = np.array(coordinates, dtype=np.float32)
    for i in range(len(image_size)):
        identity_map[i] *= spacing[i]

    # Add a batch dimension so the result is [1, dim, X(, Y, Z)], matching the
    # [B, dim, ...] channel-first layout expected by the loss below.
    return torch.from_numpy(identity_map).unsqueeze(0)
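
# A quick shape check (sketch, assuming the layout above):
#   create_identity_map((64, 64)).shape      -> torch.Size([1, 2, 64, 64])
#   create_identity_map((32, 32, 32)).shape  -> torch.Size([1, 3, 32, 32, 32])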


def inverse_consistency_loss(phi_AB_tensor, phi_BA_tensor, identity_map):
    """
    Calculate the inverse consistency loss between two deformation fields tensors
    by converting them to function form.
    
    Parameters:
    - phi_AB_tensor: Deformation field from space A to space B (tensor of shape [B, 2, H, W]).
    - phi_BA_tensor: Deformation field from space B to space A (tensor of shape [B, 2, H, W]).
    - identity_map: The identity map (grid) (tensor of shape [B, C, H, W]
    
    
    
    Returns:
    - loss: The inverse consistency loss and gradient inverse consistency loss

    NOTE: 
    vectorfields = deformation field + identity_map
    Deformation field: output of the U-Net in the case of gradicon model

    """
    
    # Wrap the displacement-field tensors as coordinate-transform functions.
    phi_AB = tensor2function(phi_AB_tensor)
    phi_BA = tensor2function(phi_BA_tensor)
    
    # Estimate one-voxel spacing from the identity map, then jitter the sample
    # points by roughly one voxel of noise (as in the original implementation).
    voxel_spacing = (torch.max(identity_map) + 1) / identity_map.shape[-1]
    Iepsilon = identity_map + torch.randn(*identity_map.shape).to(identity_map.device) * voxel_spacing
    
    approximate_Iepsilon1 = phi_AB(phi_BA(Iepsilon))
    approximate_Iepsilon2 = phi_BA(phi_AB(Iepsilon))

    inverse_consistency_loss = torch.mean(
        (Iepsilon - approximate_Iepsilon1) ** 2
    ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2)
    # GradICON: penalize the Jacobian of the inverse consistency error,
    # approximated with finite differences along each coordinate axis.
    direction_losses = []

    inverse_consistency_error = Iepsilon - approximate_Iepsilon1

    # NOTE: the original ICON code works in [0, 1] coordinates, where
    # delta = 1e-6 sits safely above float32 precision; with pixel
    # coordinates a larger delta may be needed so the perturbation is
    # not rounded away.
    delta = 1e-6

    if len(identity_map.shape) == 4:
        dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(identity_map.device)
        dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(identity_map.device)
        direction_vectors = (dx, dy)

    elif len(identity_map.shape) == 5:
        dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(
            identity_map.device
        )
        dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(
            identity_map.device
        )
        dz = torch.Tensor([[[[[0.0]]], [[[0.0]]], [[[delta]]]]]).to(
            identity_map.device
        )
        direction_vectors = (dx, dy, dz)
    elif len(identity_map.shape) == 3:
        dx = torch.Tensor([[[delta]]]).to(identity_map.device)
        direction_vectors = (dx,)
    else:
        raise ValueError("identity_map must be a [B, dim, ...] tensor (1D, 2D, or 3D)")

    for d in direction_vectors:
        approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d))
        inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
        grad_d_icon_error = (
            inverse_consistency_error - inverse_consistency_error_d
        ) / delta
        direction_losses.append(torch.mean(grad_d_icon_error**2))

    grad_inverse_consistency_loss = sum(direction_losses)

    return inverse_consistency_loss, grad_inverse_consistency_loss


def scale_map(map, sz, spacing):
    """
    Scales the map to the [-1,1]^d format
    :param map: map in BxCxXxYxZ format
    :param sz: size of image being interpolated in XxYxZ format
    :param spacing: spacing of image in XxYxZ format
    :return: returns the scaled map
    """

    map_scaled = torch.zeros_like(map)
    ndim = len(spacing)

    # This is to compensate to get back to the [-1,1] mapping of the following form
    # id[d]*=2./(sz[d]-1)
    # id[d]-=1.

    for d in range(ndim):
        if sz[d + 2] > 1:
            map_scaled[:, d, ...] = (
                map[:, d, ...] * (2.0 / (sz[d + 2] - 1.0) / spacing[d])
                - 1.0
                # map[:, d, ...] * 2.0 - 1.0
            )
        else:
            map_scaled[:, d, ...] = map[:, d, ...]

    return map_scaled
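
# Worked example (sketch): for a 64-pixel axis (sz = 64) with spacing 1, pixel
# coordinate 0 maps to 0 * 2/63 - 1 = -1 and pixel 63 maps to 63 * 2/63 - 1 = +1,
# i.e. the [-1, 1] convention that grid_sample expects.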

class STNFunction_ND_BCXYZ:
    """
    Spatial transform function for 1D, 2D, and 3D. In BCXYZ format (this IS the format used in the current toolbox).
    """

    def __init__(
        self, spacing, zero_boundary=False, using_bilinear=True, using_01_input=True):
        """
        Constructor
        :param spacing: spacing of the image in each spatial dimension
        """
        self.spacing = spacing
        self.ndim = len(spacing)
        self.zero_boundary = "zeros" if zero_boundary else "border"
        self.mode = "bilinear" if using_bilinear else "nearest"
        self.using_01_input = using_01_input

    def forward_stn(self, input1, input2, ndim):
        if ndim == 1:
            # use 2D interpolation to mimic 1D interpolation
            phi_rs = input2.reshape(list(input2.size()) + [1])
            input1_rs = input1.reshape(list(input1.size()) + [1])

            phi_rs_size = list(phi_rs.size())
            phi_rs_size[1] = 2

            phi_rs_ordered = torch.zeros(
                phi_rs_size, dtype=phi_rs.dtype, device=phi_rs.device
            )
            # keep dimension 1 at zero
            phi_rs_ordered[:, 1, ...] = phi_rs[:, 0, ...]

            output_rs = torch.nn.functional.grid_sample(
                input1_rs,
                phi_rs_ordered.permute([0, 2, 3, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )
            output = output_rs[:, :, :, 0]

        if ndim == 2:
            # todo: double check; it seems no transpose is needed for 2D,
            # the map already being in the order grid_sample expects
            input2_ordered = input2

            if input2_ordered.shape[0] == 1 and input1.shape[0] != 1:
                input2_ordered = input2_ordered.expand(input1.shape[0], -1, -1, -1)
            # grid_sample shapes: input [N, C, H, W], grid [N, H, W, 2],
            # output [N, C, H, W]
            
            output = torch.nn.functional.grid_sample(
                input=input1,
                grid=input2_ordered.permute([0, 2, 3, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )
            
        if ndim == 3:
            input2_ordered = torch.zeros_like(input2)
            input2_ordered[:, 0, ...] = input2[:, 2, ...]
            input2_ordered[:, 1, ...] = input2[:, 1, ...]
            input2_ordered[:, 2, ...] = input2[:, 0, ...]
            if input2_ordered.shape[0] == 1 and input1.shape[0] != 1:
                input2_ordered = input2_ordered.expand(input1.shape[0], -1, -1, -1, -1)
            output = torch.nn.functional.grid_sample(
                input1,
                input2_ordered.permute([0, 2, 3, 4, 1]),
                mode=self.mode,
                padding_mode=self.zero_boundary,
                align_corners=True,
            )

        return output

    def __call__(self, input1, input2):
        """
        Perform the actual spatial transform
        :param input1: image in BCXYZ format
        :param input2: spatial transform in BdimXYZ format
        :return: spatially transformed image in BCXYZ format
        """
        
        
        assert len(self.spacing) + 2 == len(input2.size())
        if self.using_01_input:
            output = self.forward_stn(
                input1, scale_map(input2, input1.shape, self.spacing), self.ndim
            )
        else:
            output = self.forward_stn(input1, input2, self.ndim)
        
        return output






def compute_warped_image_multiNC(I0, phi, zero_boundary=False):
    """Warps an image.
    :param I0: image to warp, size BxCxXxYxZ
    :param phi: map for the warping in pixel coordinates, size BxdimxXxYxZ
    :return: the warped image, size BxCxXxYxZ
    """
    # The identity map above is built in pixel coordinates (default spacing 1),
    # so unit spacing is assumed here; scale_map then rescales pixels to [-1, 1].
    spacing = (1.0,) * (len(I0.shape) - 2)
    f = STNFunction_ND_BCXYZ(spacing, zero_boundary)
    return f(I0, phi)
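
# Sketch (hypothetical shapes): warping a [1, 1, 64, 64] image with a
# pixel-coordinate map of shape [1, 2, 64, 64]:
#   warped = compute_warped_image_multiNC(image, vector_field)  # -> [1, 1, 64, 64]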

def as_function(image):
    """Wrap an image tensor as a function mapping a tensor of coordinates
    [batch x N_dimensions x ...] to a tensor of interpolated intensities.
    """
    return lambda coordinates: compute_warped_image_multiNC(I0=image, phi=coordinates)


def tensor2function(tensor_of_displacements):
    """Convert a displacement-field tensor into a transform function mapping
    coordinates to coordinates: phi(x) = x + displacement(x).
    """
    displacement_field = as_function(tensor_of_displacements)

    def transform(coordinates, isIdentity=False):
        # Fast path: when evaluating exactly at the identity map, add the
        # displacement directly instead of interpolating.
        if isIdentity and coordinates.shape == tensor_of_displacements.shape:
            return coordinates + tensor_of_displacements
        return coordinates + displacement_field(coordinates)

    return transform
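
For reference, a minimal usage sketch (hypothetical shapes; random displacement fields rather than real network outputs):

# Build a 64x64 identity map and two random displacement fields, then
# evaluate both losses. Real use would take the fields from a network.
identity_map = create_identity_map((64, 64))        # [1, 2, 64, 64]
phi_AB_tensor = 0.1 * torch.randn(1, 2, 64, 64)     # displacement A -> B
phi_BA_tensor = 0.1 * torch.randn(1, 2, 64, 64)     # displacement B -> A
icon_loss, gradicon_loss = inverse_consistency_loss(
    phi_AB_tensor, phi_BA_tensor, identity_map
)
print(icon_loss.item(), gradicon_loss.item())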

@BailiangJ

Hi @iyerkrithika21 ,

I have created standalone ICON and GradICON losses in PR #80.

I hope it will also help. :)
