diff --git a/pyproject.toml b/pyproject.toml index 32ced4a..8a9fc71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ 'torchvision', 'numpy', 'tqdm', - 'cellpose' + 'cellpose', + 'ml-collections', ] [project.optional-dependencies] diff --git a/src/cellmap_models/__init__.py b/src/cellmap_models/__init__.py index bbf2be9..fcbffb4 100644 --- a/src/cellmap_models/__init__.py +++ b/src/cellmap_models/__init__.py @@ -3,4 +3,4 @@ """ from .utils import download_url_to_file -from .pytorch import cosem, cellpose +from .pytorch import cosem, cellpose, untrained_models diff --git a/src/cellmap_models/pytorch/untrained_models/README.md b/src/cellmap_models/pytorch/untrained_models/README.md new file mode 100644 index 0000000..9205fef --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/README.md @@ -0,0 +1,22 @@ +CellMap logo + +## This directory contains various untrained PyTorch model architectures. + +## Models + +***ResNet***: Parameterizable 2D and 3D ResNet models with a variable number of layers and channels. This model is based on the original ResNet architecture with the addition of a decoding path, which mirrors the encoder, after the bottleneck, to produce an image output. + +***UNet2D***: A simple 2D UNet model with a variable number of output channels. + +***UNet3D***: A simple 3D UNet model with a variable number of output channels. + +***ViTVNet***: A 3D VNet model with a Vision Transformer (ViT) encoder. This model is based on the original VNet architecture with the addition of a ViT encoder in place of the original convolutional encoder. + +## Usage + +To use these models, you can import them directly from the `cellmap_models.pytorch.untrained_models` module. For example, to import the ResNet model, you can use the following code: + +```python +from cellmap_models.pytorch.untrained_models import ResNet +model = ResNet(ndim=2, input_nc=1, output_nc=3, n_blocks=18) +``` diff --git a/src/cellmap_models/pytorch/untrained_models/__init__.py b/src/cellmap_models/pytorch/untrained_models/__init__.py new file mode 100644 index 0000000..5bedbad --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/__init__.py @@ -0,0 +1,4 @@ +from .resnet import ResNet +from .vitnet import ViTVNet +from .unet_2D import UNet2D +from .unet_3D import UNet3D diff --git a/src/cellmap_models/pytorch/untrained_models/resnet.py b/src/cellmap_models/pytorch/untrained_models/resnet.py new file mode 100644 index 0000000..3b5a5b7 --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/resnet.py @@ -0,0 +1,458 @@ +import functools +import torch + + +class Resnet2D(torch.nn.Module): + """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations. 
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__( + self, + input_nc=1, + output_nc=None, + ngf=64, + norm_layer=torch.nn.InstanceNorm2d, + use_dropout=False, + n_blocks=6, + padding_type="reflect", + activation=torch.nn.ReLU, + n_downsampling=2, + ): + """Construct a Resnet + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images (default is ngf) + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + n_downsampling -- number of times to downsample data before ResBlocks + """ + assert n_blocks >= 0 + super().__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == torch.nn.InstanceNorm2d + else: + use_bias = norm_layer == torch.nn.InstanceNorm2d + + if output_nc is None: + output_nc = ngf + + p = 0 + updown_p = 1 + padder = [] + if padding_type.lower() == "reflect" or padding_type.lower() == "same": + padder = [torch.nn.ReflectionPad2d(3)] + elif padding_type.lower() == "replicate": + padder = [torch.nn.ReplicationPad2d(3)] + elif padding_type.lower() == "zeros": + p = 3 + elif padding_type.lower() == "valid": + p = "valid" + updown_p = 0 + + model = [] + model += padder.copy() + model += [ + torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), + norm_layer(ngf), + activation(), + ] + + for i in range(n_downsampling): # add downsampling layers + mult = 2**i + model += [ + torch.nn.Conv2d( + ngf * mult, + ngf * mult * 2, + kernel_size=3, + stride=2, + padding=updown_p, + bias=use_bias, + ), + norm_layer(ngf * mult * 2), + activation(), + ] + + mult = 2**n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ + ResnetBlock2D( + ngf * mult, + padding_type=padding_type.lower(), + norm_layer=norm_layer, + use_dropout=use_dropout, + use_bias=use_bias, + activation=activation, + ) + ] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [ + torch.nn.ConvTranspose2d( + ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=updown_p, + output_padding=updown_p, + bias=use_bias, + ), + norm_layer(int(ngf * mult / 2)), + activation(), + ] + model += padder.copy() + model += [torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)] + + self.model = torch.nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock2D(torch.nn.Module): + """Define a Resnet block""" + + def __init__( + self, + dim, + padding_type, + norm_layer, + use_dropout, + use_bias, + activation=torch.nn.ReLU, + ): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. 
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super().__init__() + self.conv_block = self.build_conv_block( + dim, padding_type, norm_layer, use_dropout, use_bias, activation + ) + self.padding_type = padding_type + + def build_conv_block( + self, + dim, + padding_type, + norm_layer, + use_dropout, + use_bias, + activation=torch.nn.ReLU, + ): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + activation -- non-linearity layer to apply (default is ReLU) + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) + """ + p = 0 + padder = [] + if padding_type == "reflect" or padding_type.lower() == "same": + padder = [torch.nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + padder = [torch.nn.ReplicationPad2d(1)] + elif padding_type == "zeros": + p = 1 + elif padding_type == "valid": + p = "valid" + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + + conv_block = [] + conv_block += padder.copy() + + conv_block += [ + torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + activation(), + ] + if use_dropout: + conv_block += [torch.nn.Dropout(0.2)] + + conv_block += padder.copy() + conv_block += [ + torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + ] + + return torch.nn.Sequential(*conv_block) + + def crop(self, x, shape): + """Center-crop x to match spatial dimensions given by shape.""" + + x_target_size = x.size()[:-2] + shape + + offset = tuple( + torch.div((a - b), 2, rounding_mode="trunc") + for a, b in zip(x.size(), x_target_size) + ) + + slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size)) + + return x[slices] + + def forward(self, x): + """Forward function (with skip connections)""" + if self.padding_type == "valid": # crop for valid networks + res = self.conv_block(x) + out = self.crop(x, res.size()[-2:]) + res + else: + out = x + self.conv_block(x) # add skip connections + return out + + +class Resnet3D(torch.nn.Module): + """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations. 
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__( + self, + input_nc=1, + output_nc=None, + ngf=64, + norm_layer=torch.nn.InstanceNorm3d, + use_dropout=False, + n_blocks=6, + padding_type="reflect", + activation=torch.nn.ReLU, + n_downsampling=2, + ): + """Construct a Resnet + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + n_downsampling -- number of times to downsample data before ResBlocks + """ + assert n_blocks >= 0 + super().__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == torch.nn.InstanceNorm3d + else: + use_bias = norm_layer == torch.nn.InstanceNorm3d + + if output_nc is None: + output_nc = ngf + + p = 0 + updown_p = 1 + padder = [] + if padding_type.lower() == "reflect" or padding_type.lower() == "same": + padder = [torch.nn.ReflectionPad3d(3)] + elif padding_type.lower() == "replicate": + padder = [torch.nn.ReplicationPad3d(3)] + elif padding_type.lower() == "zeros": + p = 3 + elif padding_type.lower() == "valid": + p = "valid" + updown_p = 0 + + model = [] + model += padder.copy() + model += [ + torch.nn.Conv3d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), + norm_layer(ngf), + activation(), + ] + + for i in range(n_downsampling): # add downsampling layers + mult = 2**i + model += [ + torch.nn.Conv3d( + ngf * mult, + ngf * mult * 2, + kernel_size=3, + stride=2, + padding=updown_p, + bias=use_bias, + ), # TODO: Make actually use padding_type for every convolution (currently does zeros if not valid) + norm_layer(ngf * mult * 2), + activation(), + ] + + mult = 2**n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ + ResnetBlock3D( + ngf * mult, + padding_type=padding_type.lower(), + norm_layer=norm_layer, + use_dropout=use_dropout, + use_bias=use_bias, + activation=activation, + ) + ] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [ + torch.nn.ConvTranspose3d( + ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=updown_p, + output_padding=updown_p, + bias=use_bias, + ), + norm_layer(int(ngf * mult / 2)), + activation(), + ] + model += padder.copy() + model += [torch.nn.Conv3d(ngf, output_nc, kernel_size=7, padding=p)] + + self.model = torch.nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock3D(torch.nn.Module): + """Define a Resnet block""" + + def __init__( + self, + dim, + padding_type, + norm_layer, + use_dropout, + use_bias, + activation=torch.nn.ReLU, + ): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. 
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super().__init__() + self.conv_block = self.build_conv_block( + dim, padding_type, norm_layer, use_dropout, use_bias, activation + ) + self.padding_type = padding_type + + def build_conv_block( + self, + dim, + padding_type, + norm_layer, + use_dropout, + use_bias, + activation=torch.nn.ReLU, + ): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + activation -- non-linearity layer to apply (default is ReLU) + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) + """ + p = 0 + padder = [] + if padding_type == "reflect" or padding_type.lower() == "same": + padder = [torch.nn.ReflectionPad3d(1)] + elif padding_type == "replicate": + padder = [torch.nn.ReplicationPad3d(1)] + elif padding_type == "zeros": + p = 1 + elif padding_type == "valid": + p = "valid" + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + + conv_block = [] + conv_block += padder.copy() + + conv_block += [ + torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + activation(), + ] + if use_dropout: + conv_block += [torch.nn.Dropout(0.2)] + + conv_block += padder.copy() + conv_block += [ + torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + ] + + return torch.nn.Sequential(*conv_block) + + def crop(self, x, shape): + """Center-crop x to match spatial dimensions given by shape.""" + + x_target_size = x.size()[:-3] + shape + + offset = tuple( + torch.div((a - b), 2, rounding_mode="trunc") + for a, b in zip(x.size(), x_target_size) + ) + + slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size)) + + return x[slices] + + def forward(self, x): + """Forward function (with skip connections)""" + if self.padding_type == "valid": # crop for valid networks + res = self.conv_block(x) + out = self.crop(x, res.size()[-3:]) + res + else: + out = x + self.conv_block(x) # add skip connections + return out + + +class ResNet(Resnet2D, Resnet3D): + """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, ndims, **kwargs): + """Construct a Resnet + Parameters: + ndims (int) -- the number of dimensions of the input data + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images (default is ngf) + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + n_downsampling -- number of times to downsample data before ResBlocks + """ + if ndims == 2: + Resnet2D.__init__(self, **kwargs) + elif ndims == 3: + Resnet3D.__init__(self, **kwargs) + else: + raise ValueError( + ndims, + "Only 2D or 3D currently implemented. 
Feel free to contribute more!", + ) diff --git a/src/cellmap_models/pytorch/untrained_models/unet_2D.py b/src/cellmap_models/pytorch/untrained_models/unet_2D.py new file mode 100644 index 0000000..34cfccb --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/unet_2D.py @@ -0,0 +1,114 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" Full assembly of the parts to form the complete network """ +# Original source code from: +# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py + +class UNet2D(nn.Module): + def __init__(self, n_channels, n_classes, trilinear=False): + super(UNet2D, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.trilinear = trilinear + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + factor = 2 if trilinear else 1 + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, trilinear) + self.up2 = Up(512, 256 // factor, trilinear) + self.up3 = Up(256, 128 // factor, trilinear) + self.up4 = Up(128, 64, trilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits + +""" Parts of the 2D U-Net model """ +# Original source code from: +# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class 
OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) diff --git a/src/cellmap_models/pytorch/untrained_models/unet_3D.py b/src/cellmap_models/pytorch/untrained_models/unet_3D.py new file mode 100644 index 0000000..cc648b5 --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/unet_3D.py @@ -0,0 +1,128 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" Full assembly of the parts to form the complete network """ +# Original source code from: +# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py + +class UNet3D(nn.Module): + def __init__(self, n_channels, n_classes, trilinear=False): + super(UNet3D, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.trilinear = trilinear + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + factor = 2 if trilinear else 1 + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, trilinear) + self.up2 = Up(512, 256 // factor, trilinear) + self.up3 = Up(256, 128 // factor, trilinear) + self.up4 = Up(128, 64, trilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits + +""" Parts of the U-Net model """ +# Adapted from: +# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py +# By Emma Avetissian, @aemmav + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm3d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm3d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, trilinear=True): + super().__init__() + + # if trilinear, use the normal convolutions to reduce the number of channels + if trilinear: + self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] # height + diffX = x2.size()[3] - x1.size()[3] # width + diffZ = x2.size()[4] - x1.size()[4] # depth + + x1 = F.pad( + x1, + [ + diffX // 2, + diffX - diffX // 2, + diffY // 2, + diffY - diffY // 2, + diffZ // 2, + diffZ - diffZ // 2, + 
], + ) + + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) diff --git a/src/cellmap_models/pytorch/untrained_models/vitnet.py b/src/cellmap_models/pytorch/untrained_models/vitnet.py new file mode 100644 index 0000000..726eaa9 --- /dev/null +++ b/src/cellmap_models/pytorch/untrained_models/vitnet.py @@ -0,0 +1,448 @@ +# Adapted from: +# https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch +# By Emma Avetissian + +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math +import torch +import torch.nn as nn +import torch.nn.functional as nnf +from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm +from torch.nn.modules.utils import _pair, _triple +from torch.distributions.normal import Normal +import ml_collections + + +def get_3DReg_config(): + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({"size": (8, 8, 8)}) + config.patches.grid = (8, 8, 8) + config.hidden_size = 252 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.patch_size = 8 + + config.conv_first_channel = 512 + config.encoder_channels = (16, 32, 32) + config.down_factor = 2 + config.down_num = 2 + config.decoder_channels = (96, 48, 32, 32, 16) + config.skip_channels = (32, 32, 32, 32, 16) + config.n_skip = 5 + config.input_channels = 1 + return config + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = { + "gelu": torch.nn.functional.gelu, + "relu": torch.nn.functional.relu, + "swish": swish, +} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + 
self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings.""" + + def __init__(self, config, img_size): + super(Embeddings, self).__init__() + self.config = config + down_factor = config.down_factor + patch_size = _triple(config.patches["size"]) + n_patches = int( + (img_size[0] / 2**down_factor // patch_size[0]) + * (img_size[1] / 2**down_factor // patch_size[1]) + * (img_size[2] / 2**down_factor // patch_size[2]) + ) + self.hybrid_model = CNNEncoder(config, n_channels=config.input_channels) + in_channels = config["encoder_channels"][-1] + self.patch_embeddings = Conv3d( + in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size, + ) + self.position_embeddings = nn.Parameter( + torch.zeros(1, n_patches, config.hidden_size) + ) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + def forward(self, x): + x, features = self.hybrid_model(x) + x = self.patch_embeddings(x) # (B, hidden. 
n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + return encoded, attn_weights, features + + +class Conv3dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm3d(out_channels) + + super(Conv3dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv3dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv3dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class DecoderCup(nn.Module): + def __init__(self, config, img_size): + super().__init__() + self.config = config + self.down_factor = config.down_factor + head_channels = config.conv_first_channel + self.img_size = img_size + self.conv_more = Conv3dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + self.patch_size = _triple(config.patches["size"]) + skip_channels = self.config.skip_channels + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) + for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + 
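+        # register the decoder blocks as submodules so their parameters are tracked by PyTorch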
self.blocks = nn.ModuleList(blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = ( + hidden_states.size() + ) # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + l, h, w = ( + (self.img_size[0] // 2**self.down_factor // self.patch_size[0]), + (self.img_size[1] // 2**self.down_factor // self.patch_size[1]), + (self.img_size[2] // 2**self.down_factor // self.patch_size[2]), + ) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, l, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + # print(skip.shape) + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class CNNEncoder(nn.Module): + def __init__(self, config, n_channels=2): + super(CNNEncoder, self).__init__() + self.n_channels = n_channels + decoder_channels = config.decoder_channels + encoder_channels = config.encoder_channels + self.down_num = config.down_num + self.inc = DoubleConv(n_channels, encoder_channels[0]) + self.down1 = Down(encoder_channels[0], encoder_channels[1]) + self.down2 = Down(encoder_channels[1], encoder_channels[2]) + self.width = encoder_channels[-1] + + def forward(self, x): + features = [] + x1 = self.inc(x) + features.append(x1) + x2 = self.down1(x1) + features.append(x2) + feats = self.down2(x2) + features.append(feats) + feats_down = feats + for i in range(self.down_num): + feats_down = nn.MaxPool3d(2)(feats_down) + features.append(feats_down) + return feats, features[::-1] + + +class RegistrationHead(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv3d = nn.Conv3d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super().__init__(conv3d) + + +class ViTVNet(nn.Module): + def __init__( + self, out_channels, config="ViT-V-Net", img_size=(128, 128, 128), vis=False + ): + super(ViTVNet, self).__init__() + if isinstance(config, str): + config = CONFIGS[config] + else: + assert isinstance( + config, ml_collections.ConfigDict + ), "Is not a config object or the name of one" + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config, img_size) + self.reg_head = RegistrationHead( + in_channels=config.decoder_channels[-1], + out_channels=out_channels, + kernel_size=3, + ) + self.config = config + + def forward(self, x): + + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + x = self.decoder(x, features) + out = self.reg_head(x) + return out + + +CONFIGS = { + 
"ViT-V-Net": get_3DReg_config(), +}