diff --git a/pyproject.toml b/pyproject.toml
index 32ced4a..8a9fc71 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,8 @@ dependencies = [
'torchvision',
'numpy',
'tqdm',
- 'cellpose'
+ 'cellpose',
+ 'ml-collections',
]
[project.optional-dependencies]
diff --git a/src/cellmap_models/__init__.py b/src/cellmap_models/__init__.py
index bbf2be9..fcbffb4 100644
--- a/src/cellmap_models/__init__.py
+++ b/src/cellmap_models/__init__.py
@@ -3,4 +3,4 @@
"""
from .utils import download_url_to_file
-from .pytorch import cosem, cellpose
+from .pytorch import cosem, cellpose, untrained_models
diff --git a/src/cellmap_models/pytorch/untrained_models/README.md b/src/cellmap_models/pytorch/untrained_models/README.md
new file mode 100644
index 0000000..9205fef
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/README.md
@@ -0,0 +1,22 @@
+# Untrained Models
+
+This directory contains various untrained PyTorch model architectures for training from scratch.
+
+## Models
+
+***ResNet***: Parameterizable 2D and 3D ResNet models (selected via the `ndims` argument) with a variable number of layers and channels. This model is based on the original ResNet architecture, with the addition of a decoding path after the bottleneck that mirrors the encoder and produces an image output.
+
+***UNet2D***: A simple 2D UNet model with a variable number of output channels.
+
+***UNet3D***: A simple 3D UNet model with a variable number of output channels.
+
+***ViTVNet***: A 3D VNet model based on the original VNet architecture, with a Vision Transformer (ViT) encoder in place of the original convolutional encoder.
+
+## Usage
+
+To use these models, import them directly from the `cellmap_models.pytorch.untrained_models` module. For example, to import and instantiate the ResNet model:
+
+```python
+from cellmap_models.pytorch.untrained_models import ResNet
+model = ResNet(ndims=2, input_nc=1, output_nc=3, n_blocks=18)
+```
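+
+The other architectures are instantiated the same way. The snippet below is illustrative: the channel counts and image size shown are arbitrary example choices, not required values.
+
+```python
+from cellmap_models.pytorch.untrained_models import UNet2D, UNet3D, ViTVNet
+
+# 2D and 3D UNets take the number of input channels and output classes
+unet2d = UNet2D(n_channels=1, n_classes=3)
+unet3d = UNet3D(n_channels=1, n_classes=3)
+
+# ViTVNet takes the number of output channels, a config (or the name of a
+# registered one), and the expected input volume size
+vitvnet = ViTVNet(out_channels=3, config="ViT-V-Net", img_size=(128, 128, 128))
+```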
diff --git a/src/cellmap_models/pytorch/untrained_models/__init__.py b/src/cellmap_models/pytorch/untrained_models/__init__.py
new file mode 100644
index 0000000..5bedbad
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/__init__.py
@@ -0,0 +1,4 @@
+from .resnet import ResNet
+from .vitnet import ViTVNet
+from .unet_2D import UNet2D
+from .unet_3D import UNet3D
diff --git a/src/cellmap_models/pytorch/untrained_models/resnet.py b/src/cellmap_models/pytorch/untrained_models/resnet.py
new file mode 100644
index 0000000..3b5a5b7
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/resnet.py
@@ -0,0 +1,458 @@
+import functools
+import torch
+
+
+class Resnet2D(torch.nn.Module):
+    """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations.
+    We adapt Torch code and ideas from Justin Johnson's neural style transfer project
+    (https://github.com/jcjohnson/fast-neural-style).
+    """
+
+ def __init__(
+ self,
+ input_nc=1,
+ output_nc=None,
+ ngf=64,
+ norm_layer=torch.nn.InstanceNorm2d,
+ use_dropout=False,
+ n_blocks=6,
+ padding_type="reflect",
+ activation=torch.nn.ReLU,
+ n_downsampling=2,
+ ):
+ """Construct a Resnet
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images (default is ngf)
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid
+ activation -- non-linearity layer to apply (default is ReLU)
+ n_downsampling -- number of times to downsample data before ResBlocks
+ """
+ assert n_blocks >= 0
+ super().__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == torch.nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == torch.nn.InstanceNorm2d
+
+ if output_nc is None:
+ output_nc = ngf
+
+ p = 0
+ updown_p = 1
+ padder = []
+ if padding_type.lower() == "reflect" or padding_type.lower() == "same":
+ padder = [torch.nn.ReflectionPad2d(3)]
+ elif padding_type.lower() == "replicate":
+ padder = [torch.nn.ReplicationPad2d(3)]
+ elif padding_type.lower() == "zeros":
+ p = 3
+        elif padding_type.lower() == "valid":
+            p = "valid"
+            updown_p = 0
+        else:
+            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
+
+ model = []
+ model += padder.copy()
+ model += [
+ torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias),
+ norm_layer(ngf),
+ activation(),
+ ]
+
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2**i
+ model += [
+ torch.nn.Conv2d(
+ ngf * mult,
+ ngf * mult * 2,
+ kernel_size=3,
+ stride=2,
+ padding=updown_p,
+ bias=use_bias,
+ ),
+ norm_layer(ngf * mult * 2),
+ activation(),
+ ]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [
+ ResnetBlock2D(
+ ngf * mult,
+ padding_type=padding_type.lower(),
+ norm_layer=norm_layer,
+ use_dropout=use_dropout,
+ use_bias=use_bias,
+ activation=activation,
+ )
+ ]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ model += [
+ torch.nn.ConvTranspose2d(
+ ngf * mult,
+ int(ngf * mult / 2),
+ kernel_size=3,
+ stride=2,
+ padding=updown_p,
+ output_padding=updown_p,
+ bias=use_bias,
+ ),
+ norm_layer(int(ngf * mult / 2)),
+ activation(),
+ ]
+ model += padder.copy()
+ model += [torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)]
+
+ self.model = torch.nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetBlock2D(torch.nn.Module):
+ """Define a Resnet block"""
+
+ def __init__(
+ self,
+ dim,
+ padding_type,
+ norm_layer,
+ use_dropout,
+ use_bias,
+ activation=torch.nn.ReLU,
+ ):
+ """Initialize the Resnet block
+ A resnet block is a conv block with skip connections
+ We construct a conv block with build_conv_block function,
+ and implement skip connections in function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ """
+ super().__init__()
+ self.conv_block = self.build_conv_block(
+ dim, padding_type, norm_layer, use_dropout, use_bias, activation
+ )
+ self.padding_type = padding_type
+
+ def build_conv_block(
+ self,
+ dim,
+ padding_type,
+ norm_layer,
+ use_dropout,
+ use_bias,
+ activation=torch.nn.ReLU,
+ ):
+ """Construct a convolutional block.
+ Parameters:
+ dim (int) -- the number of channels in the conv layer.
+ padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ use_bias (bool) -- if the conv layer uses bias or not
+ activation -- non-linearity layer to apply (default is ReLU)
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer)
+ """
+ p = 0
+ padder = []
+ if padding_type == "reflect" or padding_type.lower() == "same":
+ padder = [torch.nn.ReflectionPad2d(1)]
+ elif padding_type == "replicate":
+ padder = [torch.nn.ReplicationPad2d(1)]
+ elif padding_type == "zeros":
+ p = 1
+ elif padding_type == "valid":
+ p = "valid"
+ else:
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
+
+ conv_block = []
+ conv_block += padder.copy()
+
+ conv_block += [
+ torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+ norm_layer(dim),
+ activation(),
+ ]
+ if use_dropout:
+ conv_block += [torch.nn.Dropout(0.2)]
+
+ conv_block += padder.copy()
+ conv_block += [
+ torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+ norm_layer(dim),
+ ]
+
+ return torch.nn.Sequential(*conv_block)
+
+ def crop(self, x, shape):
+ """Center-crop x to match spatial dimensions given by shape."""
+
+ x_target_size = x.size()[:-2] + shape
+
+ offset = tuple(
+ torch.div((a - b), 2, rounding_mode="trunc")
+ for a, b in zip(x.size(), x_target_size)
+ )
+
+ slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
+
+ return x[slices]
+
+ def forward(self, x):
+ """Forward function (with skip connections)"""
+ if self.padding_type == "valid": # crop for valid networks
+ res = self.conv_block(x)
+ out = self.crop(x, res.size()[-2:]) + res
+ else:
+ out = x + self.conv_block(x) # add skip connections
+ return out
+
+
+class Resnet3D(torch.nn.Module):
+    """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations.
+    We adapt Torch code and ideas from Justin Johnson's neural style transfer project
+    (https://github.com/jcjohnson/fast-neural-style).
+    """
+
+ def __init__(
+ self,
+ input_nc=1,
+ output_nc=None,
+ ngf=64,
+ norm_layer=torch.nn.InstanceNorm3d,
+ use_dropout=False,
+ n_blocks=6,
+ padding_type="reflect",
+ activation=torch.nn.ReLU,
+ n_downsampling=2,
+ ):
+ """Construct a Resnet
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid
+ activation -- non-linearity layer to apply (default is ReLU)
+ n_downsampling -- number of times to downsample data before ResBlocks
+ """
+ assert n_blocks >= 0
+ super().__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == torch.nn.InstanceNorm3d
+ else:
+ use_bias = norm_layer == torch.nn.InstanceNorm3d
+
+ if output_nc is None:
+ output_nc = ngf
+
+ p = 0
+ updown_p = 1
+ padder = []
+ if padding_type.lower() == "reflect" or padding_type.lower() == "same":
+ padder = [torch.nn.ReflectionPad3d(3)]
+ elif padding_type.lower() == "replicate":
+ padder = [torch.nn.ReplicationPad3d(3)]
+ elif padding_type.lower() == "zeros":
+ p = 3
+        elif padding_type.lower() == "valid":
+            p = "valid"
+            updown_p = 0
+        else:
+            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
+
+ model = []
+ model += padder.copy()
+ model += [
+ torch.nn.Conv3d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias),
+ norm_layer(ngf),
+ activation(),
+ ]
+
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2**i
+ model += [
+ torch.nn.Conv3d(
+ ngf * mult,
+ ngf * mult * 2,
+ kernel_size=3,
+ stride=2,
+ padding=updown_p,
+ bias=use_bias,
+                ),  # TODO: actually use padding_type for every convolution (currently zero-pads unless padding_type is "valid")
+ norm_layer(ngf * mult * 2),
+ activation(),
+ ]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [
+ ResnetBlock3D(
+ ngf * mult,
+ padding_type=padding_type.lower(),
+ norm_layer=norm_layer,
+ use_dropout=use_dropout,
+ use_bias=use_bias,
+ activation=activation,
+ )
+ ]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ model += [
+ torch.nn.ConvTranspose3d(
+ ngf * mult,
+ int(ngf * mult / 2),
+ kernel_size=3,
+ stride=2,
+ padding=updown_p,
+ output_padding=updown_p,
+ bias=use_bias,
+ ),
+ norm_layer(int(ngf * mult / 2)),
+ activation(),
+ ]
+ model += padder.copy()
+ model += [torch.nn.Conv3d(ngf, output_nc, kernel_size=7, padding=p)]
+
+ self.model = torch.nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetBlock3D(torch.nn.Module):
+ """Define a Resnet block"""
+
+ def __init__(
+ self,
+ dim,
+ padding_type,
+ norm_layer,
+ use_dropout,
+ use_bias,
+ activation=torch.nn.ReLU,
+ ):
+ """Initialize the Resnet block
+ A resnet block is a conv block with skip connections
+ We construct a conv block with build_conv_block function,
+ and implement skip connections in function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ """
+ super().__init__()
+ self.conv_block = self.build_conv_block(
+ dim, padding_type, norm_layer, use_dropout, use_bias, activation
+ )
+ self.padding_type = padding_type
+
+ def build_conv_block(
+ self,
+ dim,
+ padding_type,
+ norm_layer,
+ use_dropout,
+ use_bias,
+ activation=torch.nn.ReLU,
+ ):
+ """Construct a convolutional block.
+ Parameters:
+ dim (int) -- the number of channels in the conv layer.
+ padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ use_bias (bool) -- if the conv layer uses bias or not
+ activation -- non-linearity layer to apply (default is ReLU)
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer)
+ """
+ p = 0
+ padder = []
+ if padding_type == "reflect" or padding_type.lower() == "same":
+ padder = [torch.nn.ReflectionPad3d(1)]
+ elif padding_type == "replicate":
+ padder = [torch.nn.ReplicationPad3d(1)]
+ elif padding_type == "zeros":
+ p = 1
+ elif padding_type == "valid":
+ p = "valid"
+ else:
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
+
+ conv_block = []
+ conv_block += padder.copy()
+
+ conv_block += [
+ torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+ norm_layer(dim),
+ activation(),
+ ]
+ if use_dropout:
+ conv_block += [torch.nn.Dropout(0.2)]
+
+ conv_block += padder.copy()
+ conv_block += [
+ torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+ norm_layer(dim),
+ ]
+
+ return torch.nn.Sequential(*conv_block)
+
+ def crop(self, x, shape):
+ """Center-crop x to match spatial dimensions given by shape."""
+
+ x_target_size = x.size()[:-3] + shape
+
+ offset = tuple(
+ torch.div((a - b), 2, rounding_mode="trunc")
+ for a, b in zip(x.size(), x_target_size)
+ )
+
+ slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
+
+ return x[slices]
+
+ def forward(self, x):
+ """Forward function (with skip connections)"""
+ if self.padding_type == "valid": # crop for valid networks
+ res = self.conv_block(x)
+ out = self.crop(x, res.size()[-3:]) + res
+ else:
+ out = x + self.conv_block(x) # add skip connections
+ return out
+
+
+class ResNet(Resnet2D, Resnet3D):
+    """Resnet that consists of Resnet blocks between a few downsampling/upsampling operations.
+    We adapt Torch code and ideas from Justin Johnson's neural style transfer project
+    (https://github.com/jcjohnson/fast-neural-style).
+    """
+
+ def __init__(self, ndims, **kwargs):
+ """Construct a Resnet
+ Parameters:
+ ndims (int) -- the number of dimensions of the input data
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images (default is ngf)
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid
+ activation -- non-linearity layer to apply (default is ReLU)
+ n_downsampling -- number of times to downsample data before ResBlocks
+            n_downsampling -- number of times to downsample data before ResBlocks
+        """
+ if ndims == 2:
+ Resnet2D.__init__(self, **kwargs)
+ elif ndims == 3:
+ Resnet3D.__init__(self, **kwargs)
+ else:
+            raise ValueError(
+                f"ndims={ndims}: only 2D or 3D is currently implemented. "
+                "Feel free to contribute more!"
+            )
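+
+
+# A minimal smoke test (illustrative only; the channel counts, block counts, and input
+# sizes below are arbitrary example choices): build small 2D and 3D ResNets and check
+# that the output spatial shape matches the input with the default "reflect" padding.
+if __name__ == "__main__":
+    net2d = ResNet(ndims=2, input_nc=1, output_nc=3, n_blocks=2)
+    print(net2d(torch.rand(1, 1, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])
+
+    net3d = ResNet(ndims=3, input_nc=1, output_nc=3, n_blocks=2, ngf=16)
+    print(net3d(torch.rand(1, 1, 32, 32, 32)).shape)  # torch.Size([1, 3, 32, 32, 32])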
diff --git a/src/cellmap_models/pytorch/untrained_models/unet_2D.py b/src/cellmap_models/pytorch/untrained_models/unet_2D.py
new file mode 100644
index 0000000..34cfccb
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/unet_2D.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+""" Full assembly of the parts to form the complete network """
+# Original source code from:
+# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
+
+class UNet2D(nn.Module):
+    def __init__(self, n_channels, n_classes, bilinear=False):
+ super(UNet2D, self).__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+        self.bilinear = bilinear
+
+ self.inc = DoubleConv(n_channels, 64)
+ self.down1 = Down(64, 128)
+ self.down2 = Down(128, 256)
+ self.down3 = Down(256, 512)
+        factor = 2 if bilinear else 1
+ self.down4 = Down(512, 1024 // factor)
+        self.up1 = Up(1024, 512 // factor, bilinear)
+        self.up2 = Up(512, 256 // factor, bilinear)
+        self.up3 = Up(256, 128 // factor, bilinear)
+        self.up4 = Up(128, 64, bilinear)
+ self.outc = OutConv(64, n_classes)
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ logits = self.outc(x)
+ return logits
+
+""" Parts of the 2D U-Net model """
+# Original source code from:
+# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
+
+class DoubleConv(nn.Module):
+ """(convolution => [BN] => ReLU) * 2"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Down(nn.Module):
+ """Downscaling with maxpool then double conv"""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+ """Upscaling then double conv"""
+
+ def __init__(self, in_channels, out_channels, bilinear=True):
+ super().__init__()
+
+ # if bilinear, use the normal convolutions to reduce the number of channels
+ if bilinear:
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+ else:
+ self.up = nn.ConvTranspose2d(
+ in_channels, in_channels // 2, kernel_size=2, stride=2
+ )
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x1, x2):
+ x1 = self.up(x1)
+        # input is NCHW
+ diffY = x2.size()[2] - x1.size()[2]
+ diffX = x2.size()[3] - x1.size()[3]
+
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+ # if you have padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+ x = torch.cat([x2, x1], dim=1)
+ return self.conv(x)
+
+
+class OutConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(OutConv, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ return self.conv(x)
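+
+
+# A minimal smoke test (illustrative only; the class count and input size are arbitrary
+# example choices): the default configuration uses transposed-convolution upsampling and
+# returns logits with the same spatial shape as the input.
+if __name__ == "__main__":
+    model = UNet2D(n_channels=1, n_classes=2)
+    image = torch.rand(1, 1, 64, 64)
+    print(model(image).shape)  # torch.Size([1, 2, 64, 64])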
diff --git a/src/cellmap_models/pytorch/untrained_models/unet_3D.py b/src/cellmap_models/pytorch/untrained_models/unet_3D.py
new file mode 100644
index 0000000..cc648b5
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/unet_3D.py
@@ -0,0 +1,128 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+""" Full assembly of the parts to form the complete network """
+# Original source code from:
+# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
+
+class UNet3D(nn.Module):
+ def __init__(self, n_channels, n_classes, trilinear=False):
+ super(UNet3D, self).__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.trilinear = trilinear
+
+ self.inc = DoubleConv(n_channels, 64)
+ self.down1 = Down(64, 128)
+ self.down2 = Down(128, 256)
+ self.down3 = Down(256, 512)
+ factor = 2 if trilinear else 1
+ self.down4 = Down(512, 1024 // factor)
+ self.up1 = Up(1024, 512 // factor, trilinear)
+ self.up2 = Up(512, 256 // factor, trilinear)
+ self.up3 = Up(256, 128 // factor, trilinear)
+ self.up4 = Up(128, 64, trilinear)
+ self.outc = OutConv(64, n_classes)
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ logits = self.outc(x)
+ return logits
+
+""" Parts of the U-Net model """
+# Adapted from:
+# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
+# By Emma Avetissian, @aemmav
+
+class DoubleConv(nn.Module):
+ """(convolution => [BN] => ReLU) * 2"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm3d(mid_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm3d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Down(nn.Module):
+ """Downscaling with maxpool then double conv"""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+ """Upscaling then double conv"""
+
+ def __init__(self, in_channels, out_channels, trilinear=True):
+ super().__init__()
+
+ # if trilinear, use the normal convolutions to reduce the number of channels
+ if trilinear:
+ self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+ else:
+ self.up = nn.ConvTranspose3d(
+ in_channels, in_channels // 2, kernel_size=2, stride=2
+ )
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x1, x2):
+ x1 = self.up(x1)
+        # input is NCDHW
+        diffZ = x2.size()[2] - x1.size()[2]  # depth
+        diffY = x2.size()[3] - x1.size()[3]  # height
+        diffX = x2.size()[4] - x1.size()[4]  # width
+
+ x1 = F.pad(
+ x1,
+ [
+ diffX // 2,
+ diffX - diffX // 2,
+ diffY // 2,
+ diffY - diffY // 2,
+ diffZ // 2,
+ diffZ - diffZ // 2,
+ ],
+ )
+
+ # if you have padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+ x = torch.cat([x2, x1], dim=1)
+ return self.conv(x)
+
+
+class OutConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(OutConv, self).__init__()
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ return self.conv(x)
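+
+
+# A minimal smoke test (illustrative only; the class count and input size are arbitrary
+# example choices): this builds the trilinear-upsampling variant and returns logits with
+# the same spatial shape as the input volume.
+if __name__ == "__main__":
+    model = UNet3D(n_channels=1, n_classes=2, trilinear=True)
+    volume = torch.rand(1, 1, 32, 32, 32)
+    print(model(volume).shape)  # torch.Size([1, 2, 32, 32, 32])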
diff --git a/src/cellmap_models/pytorch/untrained_models/vitnet.py b/src/cellmap_models/pytorch/untrained_models/vitnet.py
new file mode 100644
index 0000000..726eaa9
--- /dev/null
+++ b/src/cellmap_models/pytorch/untrained_models/vitnet.py
@@ -0,0 +1,448 @@
+# Adapted from:
+# https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch
+# By Emma Avetissian
+
+# coding=utf-8
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import logging
+import math
+import torch
+import torch.nn as nn
+from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm
+from torch.nn.modules.utils import _triple
+from torch.distributions.normal import Normal
+import ml_collections
+
+
+def get_3DReg_config():
+ config = ml_collections.ConfigDict()
+ config.patches = ml_collections.ConfigDict({"size": (8, 8, 8)})
+ config.patches.grid = (8, 8, 8)
+ config.hidden_size = 252
+ config.transformer = ml_collections.ConfigDict()
+ config.transformer.mlp_dim = 3072
+ config.transformer.num_heads = 12
+ config.transformer.num_layers = 12
+ config.transformer.attention_dropout_rate = 0.0
+ config.transformer.dropout_rate = 0.1
+ config.patch_size = 8
+
+ config.conv_first_channel = 512
+ config.encoder_channels = (16, 32, 32)
+ config.down_factor = 2
+ config.down_num = 2
+ config.decoder_channels = (96, 48, 32, 32, 16)
+ config.skip_channels = (32, 32, 32, 32, 16)
+ config.n_skip = 5
+ config.input_channels = 1
+ return config
+
+
+logger = logging.getLogger(__name__)
+
+
+ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
+ATTENTION_K = "MultiHeadDotProductAttention_1/key"
+ATTENTION_V = "MultiHeadDotProductAttention_1/value"
+ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
+FC_0 = "MlpBlock_3/Dense_0"
+FC_1 = "MlpBlock_3/Dense_1"
+ATTENTION_NORM = "LayerNorm_0"
+MLP_NORM = "LayerNorm_2"
+
+
+def np2th(weights, conv=False):
+ """Possibly convert HWIO to OIHW."""
+ if conv:
+ weights = weights.transpose([3, 2, 0, 1])
+ return torch.from_numpy(weights)
+
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+ACT2FN = {
+ "gelu": torch.nn.functional.gelu,
+ "relu": torch.nn.functional.relu,
+ "swish": swish,
+}
+
+
+class Attention(nn.Module):
+ def __init__(self, config, vis):
+ super(Attention, self).__init__()
+ self.vis = vis
+ self.num_attention_heads = config.transformer["num_heads"]
+ self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = Linear(config.hidden_size, self.all_head_size)
+ self.key = Linear(config.hidden_size, self.all_head_size)
+ self.value = Linear(config.hidden_size, self.all_head_size)
+
+ self.out = Linear(config.hidden_size, config.hidden_size)
+ self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
+ self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
+
+ self.softmax = Softmax(dim=-1)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ attention_probs = self.softmax(attention_scores)
+ weights = attention_probs if self.vis else None
+ attention_probs = self.attn_dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ attention_output = self.out(context_layer)
+ attention_output = self.proj_dropout(attention_output)
+ return attention_output, weights
+
+
+class Mlp(nn.Module):
+ def __init__(self, config):
+ super(Mlp, self).__init__()
+ self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
+ self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
+ self.act_fn = ACT2FN["gelu"]
+ self.dropout = Dropout(config.transformer["dropout_rate"])
+
+ self._init_weights()
+
+ def _init_weights(self):
+ nn.init.xavier_uniform_(self.fc1.weight)
+ nn.init.xavier_uniform_(self.fc2.weight)
+ nn.init.normal_(self.fc1.bias, std=1e-6)
+ nn.init.normal_(self.fc2.bias, std=1e-6)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+
+class Embeddings(nn.Module):
+ """Construct the embeddings from patch, position embeddings."""
+
+ def __init__(self, config, img_size):
+ super(Embeddings, self).__init__()
+ self.config = config
+ down_factor = config.down_factor
+ patch_size = _triple(config.patches["size"])
+ n_patches = int(
+ (img_size[0] / 2**down_factor // patch_size[0])
+ * (img_size[1] / 2**down_factor // patch_size[1])
+ * (img_size[2] / 2**down_factor // patch_size[2])
+ )
+ self.hybrid_model = CNNEncoder(config, n_channels=config.input_channels)
+ in_channels = config["encoder_channels"][-1]
+ self.patch_embeddings = Conv3d(
+ in_channels=in_channels,
+ out_channels=config.hidden_size,
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ self.position_embeddings = nn.Parameter(
+ torch.zeros(1, n_patches, config.hidden_size)
+ )
+
+ self.dropout = Dropout(config.transformer["dropout_rate"])
+
+ def forward(self, x):
+ x, features = self.hybrid_model(x)
+        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/3), n_patches^(1/3), n_patches^(1/3))
+ x = x.flatten(2)
+ x = x.transpose(-1, -2) # (B, n_patches, hidden)
+ embeddings = x + self.position_embeddings
+ embeddings = self.dropout(embeddings)
+ return embeddings, features
+
+
+class Block(nn.Module):
+ def __init__(self, config, vis):
+ super(Block, self).__init__()
+ self.hidden_size = config.hidden_size
+ self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
+ self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
+ self.ffn = Mlp(config)
+ self.attn = Attention(config, vis)
+
+ def forward(self, x):
+ h = x
+
+ x = self.attention_norm(x)
+ x, weights = self.attn(x)
+ x = x + h
+
+ h = x
+ x = self.ffn_norm(x)
+ x = self.ffn(x)
+ x = x + h
+ return x, weights
+
+
+class Encoder(nn.Module):
+ def __init__(self, config, vis):
+ super(Encoder, self).__init__()
+ self.vis = vis
+ self.layer = nn.ModuleList()
+ self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
+ for _ in range(config.transformer["num_layers"]):
+ layer = Block(config, vis)
+ self.layer.append(copy.deepcopy(layer))
+
+ def forward(self, hidden_states):
+ attn_weights = []
+ for layer_block in self.layer:
+ hidden_states, weights = layer_block(hidden_states)
+ if self.vis:
+ attn_weights.append(weights)
+ encoded = self.encoder_norm(hidden_states)
+ return encoded, attn_weights
+
+
+class Transformer(nn.Module):
+ def __init__(self, config, img_size, vis):
+ super(Transformer, self).__init__()
+ self.embeddings = Embeddings(config, img_size=img_size)
+ self.encoder = Encoder(config, vis)
+
+ def forward(self, input_ids):
+ embedding_output, features = self.embeddings(input_ids)
+ encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
+ return encoded, attn_weights, features
+
+
+class Conv3dReLU(nn.Sequential):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=0,
+ stride=1,
+ use_batchnorm=True,
+ ):
+ conv = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=not (use_batchnorm),
+ )
+ relu = nn.ReLU(inplace=True)
+
+ bn = nn.BatchNorm3d(out_channels)
+
+ super(Conv3dReLU, self).__init__(conv, bn, relu)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ skip_channels=0,
+ use_batchnorm=True,
+ ):
+ super().__init__()
+ self.conv1 = Conv3dReLU(
+ in_channels + skip_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ use_batchnorm=use_batchnorm,
+ )
+ self.conv2 = Conv3dReLU(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ use_batchnorm=use_batchnorm,
+ )
+ self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
+
+ def forward(self, x, skip=None):
+ x = self.up(x)
+ if skip is not None:
+ x = torch.cat([x, skip], dim=1)
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+class DecoderCup(nn.Module):
+ def __init__(self, config, img_size):
+ super().__init__()
+ self.config = config
+ self.down_factor = config.down_factor
+ head_channels = config.conv_first_channel
+ self.img_size = img_size
+ self.conv_more = Conv3dReLU(
+ config.hidden_size,
+ head_channels,
+ kernel_size=3,
+ padding=1,
+ use_batchnorm=True,
+ )
+ decoder_channels = config.decoder_channels
+ in_channels = [head_channels] + list(decoder_channels[:-1])
+ out_channels = decoder_channels
+ self.patch_size = _triple(config.patches["size"])
+ skip_channels = self.config.skip_channels
+ blocks = [
+ DecoderBlock(in_ch, out_ch, sk_ch)
+ for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
+ ]
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, hidden_states, features=None):
+ B, n_patch, hidden = (
+ hidden_states.size()
+        )  # reshape from (B, n_patch, hidden) to (B, hidden, l, h, w)
+ l, h, w = (
+ (self.img_size[0] // 2**self.down_factor // self.patch_size[0]),
+ (self.img_size[1] // 2**self.down_factor // self.patch_size[1]),
+ (self.img_size[2] // 2**self.down_factor // self.patch_size[2]),
+ )
+ x = hidden_states.permute(0, 2, 1)
+ x = x.contiguous().view(B, hidden, l, h, w)
+ x = self.conv_more(x)
+ for i, decoder_block in enumerate(self.blocks):
+ if features is not None:
+ skip = features[i] if (i < self.config.n_skip) else None
+ # print(skip.shape)
+ else:
+ skip = None
+ x = decoder_block(x, skip=skip)
+ return x
+
+
+class DoubleConv(nn.Module):
+ """(convolution => [BN] => ReLU) * 2"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Down(nn.Module):
+ """Downscaling with maxpool then double conv"""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+
+class CNNEncoder(nn.Module):
+ def __init__(self, config, n_channels=2):
+ super(CNNEncoder, self).__init__()
+ self.n_channels = n_channels
+ decoder_channels = config.decoder_channels
+ encoder_channels = config.encoder_channels
+ self.down_num = config.down_num
+ self.inc = DoubleConv(n_channels, encoder_channels[0])
+ self.down1 = Down(encoder_channels[0], encoder_channels[1])
+ self.down2 = Down(encoder_channels[1], encoder_channels[2])
+ self.width = encoder_channels[-1]
+
+ def forward(self, x):
+ features = []
+ x1 = self.inc(x)
+ features.append(x1)
+ x2 = self.down1(x1)
+ features.append(x2)
+ feats = self.down2(x2)
+ features.append(feats)
+ feats_down = feats
+ for i in range(self.down_num):
+ feats_down = nn.MaxPool3d(2)(feats_down)
+ features.append(feats_down)
+ return feats, features[::-1]
+
+
+class RegistrationHead(nn.Sequential):
+ def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
+ conv3d = nn.Conv3d(
+ in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+ conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape))
+ conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape))
+ super().__init__(conv3d)
+
+
+class ViTVNet(nn.Module):
+ def __init__(
+ self, out_channels, config="ViT-V-Net", img_size=(128, 128, 128), vis=False
+ ):
+ super(ViTVNet, self).__init__()
+ if isinstance(config, str):
+ config = CONFIGS[config]
+ else:
+ assert isinstance(
+ config, ml_collections.ConfigDict
+            ), "config must be an ml_collections.ConfigDict or the name of a registered config"
+ self.transformer = Transformer(config, img_size, vis)
+ self.decoder = DecoderCup(config, img_size)
+ self.reg_head = RegistrationHead(
+ in_channels=config.decoder_channels[-1],
+ out_channels=out_channels,
+ kernel_size=3,
+ )
+ self.config = config
+
+ def forward(self, x):
+
+ x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
+ x = self.decoder(x, features)
+ out = self.reg_head(x)
+ return out
+
+
+CONFIGS = {
+ "ViT-V-Net": get_3DReg_config(),
+}