diff --git a/docs/source/highlights.md b/docs/source/highlights.md index f7895bf65d..113fc1a1c5 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -72,7 +72,16 @@ torch.backends.cudnn.benchmark = False There are domain-specific loss functions in the medical research area which are different from the generic computer vision ones. As an important module of MONAI, these loss functions are implemented in PyTorch, such as Dice loss and generalized Dice loss. ## Network architectures -Some deep neural network architectures have shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability. +Some deep neural network architectures have been shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability. +To make common network layers and blocks reusable, MONAI provides several predefined layers and blocks which are compatible with 1D, 2D and 3D networks. Users can easily integrate the layer factories into their own networks. +For example: +```py +# add a transposed convolution layer to the network; the factory returns the type matching the spatial dimensions. +dimension = 3 +name = Conv.CONVTRANS +conv_type = Conv[name, dimension] +self.add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False)) +``` ## Evaluation To run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included: diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 5d456c1528..8b05382ba9 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -26,26 +26,14 @@ Blocks Layers ------ -`get_conv_type` -~~~~~~~~~~~~~~~ -.. automethod:: monai.networks.layers.factories.get_conv_type - -`get_dropout_type` -~~~~~~~~~~~~~~~~~~ -.. automethod:: monai.networks.layers.factories.get_dropout_type - -`get_normalize_type` -~~~~~~~~~~~~~~~~~~~~ -.. automethod:: monai.networks.layers.factories.get_normalize_type - -`get_maxpooling_type` -~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: monai.networks.layers.factories.get_maxpooling_type - -`get_avgpooling_type` -~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: monai.networks.layers.factories.get_avgpooling_type - +`Factories` +~~~~~~~~~~~ +.. automodule:: monai.networks.layers.factories +.. currentmodule:: monai.networks.layers.factories + +`LayerFactory` +############## +.. autoclass:: LayerFactory .. automodule:: monai.networks.layers.simplelayers .. currentmodule:: monai.networks.layers.simplelayers diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index b8b2814c67..3c618636d0 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -12,14 +12,17 @@ import numpy as np import torch.nn as nn -from monai.networks.layers.factories import get_conv_type, get_dropout_type, get_normalize_type +from monai.networks.layers.factories import Dropout, Norm, Act, Conv, split_args from monai.networks.layers.convutils import same_padding class Convolution(nn.Sequential): + """ + Constructs a convolution with optional dropout, normalization, and activation layers.
+ """ - def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, instance_norm=True, dropout=0, - dilation=1, bias=True, conv_only=False, is_transposed=False): + def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, act=Act.PRELU, + norm=Norm.INSTANCE, dropout=None, dilation=1, bias=True, conv_only=False, is_transposed=False): super().__init__() self.dimensions = dimensions self.in_channels = in_channels @@ -27,9 +30,25 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size self.is_transposed = is_transposed padding = same_padding(kernel_size, dilation) - normalize_type = get_normalize_type(dimensions, instance_norm) - conv_type = get_conv_type(dimensions, is_transposed) - drop_type = get_dropout_type(dimensions) + conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + + # define the normalisation type and the arguments to the constructor + norm_name, norm_args = split_args(norm) + norm_type = Norm[norm_name, dimensions] + + # define the activation type and the arguments to the constructor + act_name, act_args = split_args(act) + act_type = Act[act_name] + + if dropout: + # if dropout was specified simply as a p value, use default name and make a keyword map with the value + if isinstance(dropout, (int, float)): + drop_name = Dropout.DROPOUT + drop_args = {"p": dropout} + else: + drop_name, drop_args = split_args(dropout) + + drop_type = Dropout[drop_name, dimensions] if is_transposed: conv = conv_type(in_channels, out_channels, kernel_size, strides, padding, strides - 1, 1, bias, dilation) @@ -39,17 +58,16 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size self.add_module("conv", conv) if not conv_only: - self.add_module("norm", normalize_type(out_channels)) - if dropout > 0: # omitting Dropout2d appears faster than relying on it short-circuiting when dropout==0 - self.add_module("dropout", drop_type(dropout)) + self.add_module("norm", norm_type(out_channels, **norm_args)) + if dropout: + self.add_module("dropout", drop_type(**drop_args)) - self.add_module("prelu", nn.modules.PReLU()) + self.add_module("act", act_type(**act_args)) class ResidualUnit(nn.Module): - - def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, subunits=2, instance_norm=True, - dropout=0, dilation=1, bias=True, last_conv_only=False): + def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, subunits=2, + act=Act.PRELU, norm=Norm.INSTANCE, dropout=None, dilation=1, bias=True, last_conv_only=False): super().__init__() self.dimensions = dimensions self.in_channels = in_channels @@ -64,10 +82,13 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size for su in range(subunits): conv_only = last_conv_only and su == (subunits - 1) - unit = Convolution(dimensions, schannels, out_channels, sstrides, kernel_size, instance_norm, dropout, - dilation, bias, conv_only) + unit = Convolution(dimensions, schannels, out_channels, sstrides, + kernel_size, act, norm, dropout, dilation, bias, conv_only) + self.conv.add_module("unit%i" % su, unit) - schannels = out_channels # after first loop set channels and strides to what they should be for subsequent units + + # after first loop set channels and strides to what they should be for subsequent units + schannels = out_channels sstrides = 1 # apply convolution to input to change number of output channels and size to match that coming from self.conv @@ -79,7 
+100,7 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size rkernel_size = 1 rpadding = 0 - conv_type = get_conv_type(dimensions, False) + conv_type = Conv[Conv.CONV, dimensions] self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias) def forward(self, x): diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 139de92655..05213ab68b 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -10,46 +10,217 @@ # limitations under the License. """ -handles spatial 1D, 2D, 3D network components with a factory pattern. +Defines factories for creating layers in generic, extensible, and dimensionally independent ways. A separate factory +object is created for each type of layer, and factory functions keyed to names are added to these objects. Whenever +a layer is requested, the factory name and any necessary arguments are passed to the factory object. The return value +is typically a type but can be any callable producing a layer object. + +The factory objects contain functions keyed to names converted to upper case; these names can be referred to as members +of the factory so that they can function as constant identifiers, e.g. instance normalisation is named `Norm.INSTANCE`. + +For example, to get a transpose convolution layer, the name is given along with a dimension argument, which is +passed to the factory function: + +.. code-block:: python + + dimension = 3 + name = Conv.CONVTRANS + conv = Conv[name, dimension] + +This allows the `dimension` value to be set in the constructor, for example so that the dimensionality of a network is +parameterizable. Not all factories require arguments after the name; the caller must be aware which are required. + +Defining new factories involves creating the object and then associating it with factory functions: + +.. code-block:: python + + fact = LayerFactory() + + @fact.factory_function('test') + def make_something(x, y): + # do something with x and y to choose which layer type to return + return SomeLayerType + ... + + # request object from factory TEST with 1 and 2 as values for x and y + layer = fact[fact.TEST, 1, 2] + +Typically the caller of a factory would know what arguments to pass (i.e. the dimensionality of the requested type), but +the call can also be parameterized with the factory name and the arguments to pass to the created type at instantiation time: + +.. code-block:: python + + def use_factory(fact_args): + fact_name, type_args = split_args(fact_args) + layer_type = fact[fact_name, 1, 2] + return layer_type(**type_args) + ... + + kw_args = {'arg0': 0, 'arg1': True} + layer = use_factory((fact.TEST, kw_args)) """ -from torch import nn as nn +from typing import Callable + +import torch.nn as nn + + +class LayerFactory: + """ + Factory object for creating layers; it uses given factory functions to actually produce the types or constructing + callables. These functions are referred to by name and can be added at any time. + """ + + def __init__(self): + self.factories = {} + + @property + def names(self): + """ + Produces all factory names. + """ + + return tuple(self.factories) + + def add_factory_callable(self, name, func): + """ + Add the factory function to this object under the given name. + """ + + self.factories[name.upper()] = func + + def factory_function(self, name): + """ + Decorator for adding a factory function with the given name.
+ """ + + def _add(func): + self.add_factory_callable(name, func) + return func + + return _add + + def get_constructor(self, factory_name, *args): + """ + Get the constructor for the given factory name and arguments. + """ + + if not isinstance(factory_name, str): + raise ValueError("Factories must be selected by name") + + fact = self.factories[factory_name.upper()] + return fact(*args) + + def __getitem__(self, args): + """ + Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor + itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments. + """ + + # `args[0]` is actually a type or constructor + if callable(args): + return args + # `args` is a factory name or a name with arguments + if isinstance(args, str): + name_obj, args = args, () + else: + name_obj, *args = args -def get_conv_type(dim, is_transpose): - if is_transpose: - types = [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] + return self.get_constructor(name_obj, *args) + + def __getattr__(self, key): + """ + If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names + as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo. + """ + + if key in self.factories: + return key + + return super().__getattr__(key) + + +def split_args(args): + """ + Split arguments in a way to be suitable for using with the factory types. If `args` is a name it's interpreted + """ + + if isinstance(args, str): + return args, {} else: - types = [nn.Conv1d, nn.Conv2d, nn.Conv3d] + name_obj, args = args - return types[dim - 1] + if not isinstance(name_obj, (str, Callable)) or not isinstance(args, dict): + msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)" + raise ValueError(msg) + + return name_obj, args -def get_dropout_type(dim): +# Define factories for these layer types + +Dropout = LayerFactory() +Norm = LayerFactory() +Act = LayerFactory() +Conv = LayerFactory() +Pool = LayerFactory() + + +@Dropout.factory_function("dropout") +def dropout_factory(dim): types = [nn.Dropout, nn.Dropout2d, nn.Dropout3d] return types[dim - 1] -def get_normalize_type(dim, is_instance): - if is_instance: - types = [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d] - else: - types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] +@Norm.factory_function("instance") +def instance_factory(dim): + types = [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d] + return types[dim - 1] + +@Norm.factory_function("batch") +def batch_factory(dim): + types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] return types[dim - 1] -def get_maxpooling_type(dim, is_adaptive): - if is_adaptive: - types = [nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d] - else: - types = [nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d] +Act.add_factory_callable("relu", lambda: nn.modules.ReLU) +Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) +Act.add_factory_callable("prelu", lambda: nn.modules.PReLU) + + +@Conv.factory_function("conv") +def conv_factory(dim): + types = [nn.Conv1d, nn.Conv2d, nn.Conv3d] return types[dim - 1] -def get_avgpooling_type(dim, is_adaptive): - if is_adaptive: - types = [nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d] - else: - types = [nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d] +@Conv.factory_function("convtrans") +def convtrans_factory(dim): + types = [nn.ConvTranspose1d, 
nn.ConvTranspose2d, nn.ConvTranspose3d] + return types[dim - 1] + + +@Pool.factory_function("max") +def maxpooling_factory(dim): + types = [nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d] + return types[dim - 1] + + +@Pool.factory_function("adaptivemax") +def adaptive_maxpooling_factory(dim): + types = [nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d] + return types[dim - 1] + + +@Pool.factory_function("avg") +def avgpooling_factory(dim): + types = [nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d] + return types[dim - 1] + + +@Pool.factory_function("adaptiveavg") +def adaptive_avgpooling_factory(dim): + types = [nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d] return types[dim - 1] diff --git a/monai/networks/nets/densenet3d.py b/monai/networks/nets/densenet3d.py index 78fab167c4..cbc90d209b 100644 --- a/monai/networks/nets/densenet3d.py +++ b/monai/networks/nets/densenet3d.py @@ -14,8 +14,7 @@ import torch import torch.nn as nn -from monai.networks.layers.factories import (get_avgpooling_type, get_conv_type, get_dropout_type, get_maxpooling_type, - get_normalize_type) +from monai.networks.layers.factories import Conv, Dropout, Pool, Norm def densenet121(**kwargs): @@ -44,17 +43,20 @@ def __init__(self, spatial_dims, in_channels, growth_rate, bn_size, dropout_prob super(_DenseLayer, self).__init__() out_channels = bn_size * growth_rate - conv_type = get_conv_type(spatial_dims, is_transpose=False) - self.add_module('norm1', get_normalize_type(spatial_dims, is_instance=False)(in_channels)) + conv_type = Conv[Conv.CONV, spatial_dims] + norm_type = Norm[Norm.BATCH, spatial_dims] + dropout_type = Dropout[Dropout.DROPOUT, spatial_dims] + + self.add_module('norm1', norm_type(in_channels)) self.add_module('relu1', nn.ReLU(inplace=True)) self.add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False)) - self.add_module('norm2', get_normalize_type(spatial_dims, is_instance=False)(out_channels)) + self.add_module('norm2', norm_type(out_channels)) self.add_module('relu2', nn.ReLU(inplace=True)) self.add_module('conv2', conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False)) if dropout_prob > 0: - self.add_module('dropout', get_dropout_type(spatial_dims)(dropout_prob)) + self.add_module('dropout', dropout_type(dropout_prob)) def forward(self, x): new_features = super(_DenseLayer, self).forward(x) @@ -75,12 +77,15 @@ class _Transition(nn.Sequential): def __init__(self, spatial_dims, in_channels, out_channels): super(_Transition, self).__init__() - conv_type = get_conv_type(spatial_dims, is_transpose=False) - self.add_module('norm', get_normalize_type(spatial_dims, is_instance=False)(in_channels)) + conv_type = Conv[Conv.CONV, spatial_dims] + norm_type = Norm[Norm.BATCH, spatial_dims] + pool_type = Pool[Pool.AVG, spatial_dims] + + self.add_module('norm', norm_type(in_channels)) self.add_module('relu', nn.ReLU(inplace=True)) self.add_module('conv', conv_type(in_channels, out_channels, kernel_size=1, bias=False)) - self.add_module('pool', get_avgpooling_type(spatial_dims, is_adaptive=False)(kernel_size=2, stride=2)) + self.add_module('pool', pool_type(kernel_size=2, stride=2)) class DenseNet(nn.Module): @@ -112,15 +117,18 @@ def __init__(self, dropout_prob=0): super(DenseNet, self).__init__() - conv_type = get_conv_type(spatial_dims, is_transpose=False) - norm_type = get_normalize_type(spatial_dims, is_instance=False) + + conv_type = Conv[Conv.CONV, spatial_dims] + norm_type = Norm[Norm.BATCH, spatial_dims] + pool_type = Pool[Pool.MAX, 
spatial_dims] + avg_pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims] self.features = nn.Sequential( OrderedDict([ ('conv0', conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)), ('norm0', norm_type(init_features)), ('relu0', nn.ReLU(inplace=True)), - ('pool0', get_maxpooling_type(spatial_dims, is_adaptive=False)(kernel_size=3, stride=2, padding=1)), + ('pool0', pool_type(kernel_size=3, stride=2, padding=1)), ])) in_channels = init_features @@ -145,7 +153,7 @@ def __init__(self, self.class_layers = nn.Sequential( OrderedDict([ ('relu', nn.ReLU(inplace=True)), - ('norm', get_avgpooling_type(spatial_dims, is_adaptive=True)(1)), + ('norm', avg_pool_type(1)), ('flatten', nn.Flatten(1)), ('class', nn.Linear(in_channels, out_channels)), ])) diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index 4486f5b585..711ab815af 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -13,11 +13,11 @@ import torch.nn.functional as F from monai.networks.layers.convutils import same_padding -from monai.networks.layers.factories import (get_conv_type, get_dropout_type, get_normalize_type) +from monai.networks.layers.factories import Conv, Dropout, Norm SUPPORTED_NORM = { - 'batch': lambda spatial_dims: get_normalize_type(spatial_dims, is_instance=False), - 'instance': lambda spatial_dims: get_normalize_type(spatial_dims, is_instance=True), + 'batch': lambda spatial_dims: Norm[Norm.BATCH, spatial_dims], + 'instance': lambda spatial_dims: Norm[Norm.INSTANCE, spatial_dims], } SUPPORTED_ACTI = {'relu': nn.ReLU, 'prelu': nn.PReLU, 'relu6': nn.ReLU6} DEFAULT_LAYER_PARAMS_3D = ( @@ -48,7 +48,7 @@ def __init__(self, layers = nn.ModuleList() - conv_type = get_conv_type(spatial_dims, is_transpose=False) + conv_type = Conv[Conv.CONV, spatial_dims] padding_size = same_padding(kernel_size) conv = conv_type(in_channels, out_channels, kernel_size, padding=padding_size) layers.append(conv) @@ -58,7 +58,7 @@ def __init__(self, if acti_type is not None: layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) if dropout_prob is not None: - dropout_type = get_dropout_type(spatial_dims) + dropout_type = Dropout[Dropout.DROPOUT, spatial_dims] layers.append(dropout_type(p=dropout_prob)) self.layers = nn.Sequential(*layers) @@ -84,7 +84,7 @@ def __init__(self, with either zero padding ('pad') or a trainable conv with kernel size 1 ('project'). 
""" super(HighResBlock, self).__init__() - conv_type = get_conv_type(spatial_dims, is_transpose=False) + conv_type = Conv[Conv.CONV, spatial_dims] self.project, self.pad = None, None if in_channels != out_channels: diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index ad9b3ddbf4..6789f958f5 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -12,6 +12,7 @@ import torch.nn as nn from monai.networks.blocks.convolutions import Convolution, ResidualUnit +from monai.networks.layers.factories import Norm, Act from monai.networks.layers.simplelayers import SkipConnection from monai.utils import export from monai.utils.aliases import alias @@ -22,7 +23,7 @@ class UNet(nn.Module): def __init__(self, dimensions, in_channels, out_channels, channels, strides, kernel_size=3, up_kernel_size=3, - num_res_units=0, instance_norm=True, dropout=0): + num_res_units=0, act=Act.PRELU, norm=Norm.INSTANCE, dropout=0): super().__init__() assert len(channels) == (len(strides) + 1) self.dimensions = dimensions @@ -33,7 +34,8 @@ def __init__(self, dimensions, in_channels, out_channels, channels, strides, ker self.kernel_size = kernel_size self.up_kernel_size = up_kernel_size self.num_res_units = num_res_units - self.instance_norm = instance_norm + self.act = act + self.norm = norm self.dropout = dropout def _create_block(inc, outc, channels, strides, is_top): @@ -62,35 +64,21 @@ def _create_block(inc, outc, channels, strides, is_top): def _get_down_layer(self, in_channels, out_channels, strides, is_top): if self.num_res_units > 0: return ResidualUnit(self.dimensions, in_channels, out_channels, strides, self.kernel_size, self.num_res_units, - self.instance_norm, self.dropout) + self.act, self.norm, self.dropout) else: - return Convolution(self.dimensions, in_channels, out_channels, strides, self.kernel_size, self.instance_norm, + return Convolution(self.dimensions, in_channels, out_channels, strides, self.kernel_size, self.act, self.norm, self.dropout) def _get_bottom_layer(self, in_channels, out_channels): return self._get_down_layer(in_channels, out_channels, 1, False) def _get_up_layer(self, in_channels, out_channels, strides, is_top): - conv = Convolution(self.dimensions, - in_channels, - out_channels, - strides, - self.up_kernel_size, - self.instance_norm, - self.dropout, - conv_only=is_top and self.num_res_units == 0, - is_transposed=True) + conv = Convolution(self.dimensions, in_channels, out_channels, strides, self.up_kernel_size, self.act, self.norm, + self.dropout, conv_only=is_top and self.num_res_units == 0, is_transposed=True) if self.num_res_units > 0: - ru = ResidualUnit(self.dimensions, - out_channels, - out_channels, - 1, - self.kernel_size, - 1, - self.instance_norm, - self.dropout, - last_conv_only=is_top) + ru = ResidualUnit(self.dimensions, out_channels, out_channels, 1, self.kernel_size, 1, self.act, self.norm, + self.dropout, last_conv_only=is_top) return nn.Sequential(conv, ru) else: return conv diff --git a/tests/test_unet.py b/tests/test_unet.py index 5b8e85f915..be661a4dbb 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -14,8 +14,23 @@ import torch from parameterized import parameterized +from monai.networks.layers.factories import Norm, Act from monai.networks.nets.unet import UNet + +TEST_CASE_0 = [ # single channel 2D, batch 16, no residual + { + 'dimensions': 2, + 'in_channels': 1, + 'out_channels': 3, + 'channels': (16, 32, 64), + 'strides': (2, 2), + 'num_res_units': 0, + }, + torch.randn(16, 1, 32, 32), + (16, 3, 32, 32), +] 
+ TEST_CASE_1 = [ # single channel 2D, batch 16 { 'dimensions': 2, @@ -55,10 +70,54 @@ (16, 3, 32, 64, 48), ] +TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalisation + { + 'dimensions': 3, + 'in_channels': 4, + 'out_channels': 3, + 'channels': (16, 32, 64), + 'strides': (2, 2), + 'num_res_units': 1, + 'norm': Norm.BATCH, + }, + torch.randn(16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation + { + 'dimensions': 3, + 'in_channels': 4, + 'out_channels': 3, + 'channels': (16, 32, 64), + 'strides': (2, 2), + 'num_res_units': 1, + 'act': (Act.LEAKYRELU, {'negative_slope': 0.2}), + }, + torch.randn(16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit + { + 'dimensions': 3, + 'in_channels': 4, + 'out_channels': 3, + 'channels': (16, 32, 64), + 'strides': (2, 2), + 'num_res_units': 1, + 'act': (torch.nn.LeakyReLU, {'negative_slope': 0.2}), + }, + torch.randn(16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + class TestUNET(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(CASES) def test_shape(self, input_param, input_data, expected_shape): net = UNet(**input_param) net.eval()
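For reviewers trying out the new API, the following is a minimal usage sketch based on the factory interface added in `monai/networks/layers/factories.py`. It only assumes the names defined in this diff; the `group` normalisation registration at the end is hypothetical and is not part of the change.

```py
import torch.nn as nn

from monai.networks.layers.factories import Conv, Norm, Pool, split_args

# look up dimensionally independent layer types by factory name
conv_type = Conv[Conv.CONV, 2]         # nn.Conv2d
norm_type = Norm[Norm.BATCH, 3]        # nn.BatchNorm3d
pool_type = Pool[Pool.ADAPTIVEAVG, 1]  # nn.AdaptiveAvgPool1d

# split a layer specifier into its factory name and constructor keyword arguments
drop_name, drop_args = split_args(("dropout", {"p": 0.1}))

# register a new factory function (hypothetical example; nn.GroupNorm is
# dimension-agnostic, so every spatial dimension maps to the same type)
@Norm.factory_function("group")
def group_factory(dim):
    return nn.GroupNorm
```

The design intent, as described in the new module docstring, is that a network can be written once and parameterized over spatial dimensionality by passing the dimension to the relevant factory at construction time.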
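A second sketch, mirroring TEST_CASE_5 from the updated tests, shows how the reworked `Convolution` block and `UNet` accept the new `act`, `norm` and `dropout` specifiers; the expected output shape is the one asserted by the tests.

```py
import torch

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Norm
from monai.networks.nets.unet import UNet

# a 2D convolution block with batch normalisation, LeakyReLU activation and dropout p=0.2
conv = Convolution(2, in_channels=1, out_channels=8,
                   act=(Act.LEAKYRELU, {"negative_slope": 0.1}), norm=Norm.BATCH, dropout=0.2)

# a 3D UNet configured as in TEST_CASE_5
net = UNet(dimensions=3, in_channels=4, out_channels=3, channels=(16, 32, 64), strides=(2, 2),
           num_res_units=1, act=(Act.LEAKYRELU, {"negative_slope": 0.2}))
net.eval()

with torch.no_grad():
    out = net(torch.randn(16, 4, 32, 64, 48))  # expected shape: (16, 3, 32, 64, 48)
```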