81 layer factory (Project-MONAI#127)
* Adding new factory concept.
* Update to networks
* [DLMED] add introduction of layer factory to highlights (Project-MONAI#217)

* adds factories doc to docs/ (Project-MONAI#218)

Co-authored-by: Nic Ma <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
3 people authored Mar 27, 2020
1 parent 5ae87f8 commit e657db0
Showing 8 changed files with 347 additions and 103 deletions.
11 changes: 10 additions & 1 deletion docs/source/highlights.md
@@ -72,7 +72,16 @@ torch.backends.cudnn.benchmark = False
The medical research area has domain-specific loss functions that differ from the generic computer vision ones. As an important module of MONAI, these loss functions are implemented in PyTorch, for example Dice loss and generalized Dice loss.
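As a simplified illustration (a minimal sketch in plain PyTorch, not MONAI's actual implementation), a soft Dice loss for binary segmentation can be written as:

```py
import torch

def soft_dice_loss(pred, target, eps=1e-5):
    # pred and target are (batch, ...) tensors; pred holds foreground probabilities
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    intersection = (pred * target).sum(dim=1)
    denominator = pred.sum(dim=1) + target.sum(dim=1)
    # the Dice score is 2|A∩B| / (|A| + |B|); the loss is its complement
    return (1.0 - (2.0 * intersection + eps) / (denominator + eps)).mean()
```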

## Network architectures
Some deep neural network architectures have been shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability.
To leverage common network layers and blocks, MONAI provides several predefined layers and blocks that are compatible with 1D, 2D and 3D networks. Users can easily integrate these layer factories into their own networks.
For example:
```py
# add a transposed convolution layer to the network; the factory makes it compatible with different spatial dimensions.
dimension = 3
name = Conv.CONVTRANS
conv_type = Conv[name, dimension]
add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False))
```
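Factory names can also be paired with constructor arguments and split apart with `split_args`. Building on the snippet above (a hedged sketch; the argument values are illustrative):

```py
# choose a normalisation layer by name plus keyword arguments
norm_name, norm_args = split_args((Norm.INSTANCE, {"affine": True}))
norm_type = Norm[norm_name, dimension]
add_module('norm1', norm_type(out_channels, **norm_args))
```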

## Evaluation
To run model inference and evaluate model quality, MONAI provides reference implementations of the relevant widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included:
28 changes: 8 additions & 20 deletions docs/source/networks.rst
@@ -26,26 +26,14 @@
Layers
------

`Factories`
~~~~~~~~~~~
.. automodule:: monai.networks.layers.factories
.. currentmodule:: monai.networks.layers.factories

`LayerFactory`
##############
.. autoclass:: LayerFactory

.. automodule:: monai.networks.layers.simplelayers
.. currentmodule:: monai.networks.layers.simplelayers
55 changes: 38 additions & 17 deletions monai/networks/blocks/convolutions.py
@@ -12,24 +12,43 @@
import numpy as np
import torch.nn as nn

from monai.networks.layers.factories import Dropout, Norm, Act, Conv, split_args
from monai.networks.layers.convutils import same_padding


class Convolution(nn.Sequential):
"""
Constructs a convolution with optional dropout, normalization, and activation layers.
"""

def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, act=Act.PRELU,
norm=Norm.INSTANCE, dropout=None, dilation=1, bias=True, conv_only=False, is_transposed=False):
super().__init__()
self.dimensions = dimensions
self.in_channels = in_channels
self.out_channels = out_channels
self.is_transposed = is_transposed

padding = same_padding(kernel_size, dilation)
conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]

# define the normalisation type and the arguments to the constructor
norm_name, norm_args = split_args(norm)
norm_type = Norm[norm_name, dimensions]

# define the activation type and the arguments to the constructor
act_name, act_args = split_args(act)
act_type = Act[act_name]

if dropout:
# if dropout was specified simply as a p value, use default name and make a keyword map with the value
if isinstance(dropout, (int, float)):
drop_name = Dropout.DROPOUT
drop_args = {"p": dropout}
else:
drop_name, drop_args = split_args(dropout)

drop_type = Dropout[drop_name, dimensions]

if is_transposed:
conv = conv_type(in_channels, out_channels, kernel_size, strides, padding, strides - 1, 1, bias, dilation)
@@ -39,17 +58,16 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size
self.add_module("conv", conv)

if not conv_only:
self.add_module("norm", normalize_type(out_channels))
if dropout > 0: # omitting Dropout2d appears faster than relying on it short-circuiting when dropout==0
self.add_module("dropout", drop_type(dropout))
self.add_module("norm", norm_type(out_channels, **norm_args))
if dropout:
self.add_module("dropout", drop_type(**drop_args))

self.add_module("prelu", nn.modules.PReLU())
self.add_module("act", act_type(**act_args))


class ResidualUnit(nn.Module):
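    """
    Residual unit composed of a sequence of convolution subunits plus a parallel
    residual branch that matches the output channels and spatial size.
    """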

def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size=3, subunits=2,
act=Act.PRELU, norm=Norm.INSTANCE, dropout=None, dilation=1, bias=True, last_conv_only=False):
super().__init__()
self.dimensions = dimensions
self.in_channels = in_channels
@@ -64,10 +82,13 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size

for su in range(subunits):
conv_only = last_conv_only and su == (subunits - 1)
unit = Convolution(dimensions, schannels, out_channels, sstrides,
kernel_size, act, norm, dropout, dilation, bias, conv_only)

self.conv.add_module("unit%i" % su, unit)

# after first loop set channels and strides to what they should be for subsequent units
schannels = out_channels
sstrides = 1

# apply convolution to input to change number of output channels and size to match that coming from self.conv
@@ -79,7 +100,7 @@ def __init__(self, dimensions, in_channels, out_channels, strides=1, kernel_size
rkernel_size = 1
rpadding = 0

conv_type = Conv[Conv.CONV, dimensions]
self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias)

def forward(self, x):
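# A hedged usage sketch (illustrative, not part of this commit): a 2D residual
# unit with two conv subunits; norm=Norm.BATCH resolves to nn.BatchNorm2d here.
unit = ResidualUnit(dimensions=2, in_channels=8, out_channels=16, strides=2,
                    kernel_size=3, subunits=2, act=Act.PRELU, norm=Norm.BATCH)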
217 changes: 194 additions & 23 deletions monai/networks/layers/factories.py
@@ -10,46 +10,217 @@
# limitations under the License.

"""
Defines factories for creating layers in generic, extensible, and dimensionally independent ways. A separate factory
object is created for each type of layer, and factory functions keyed to names are added to these objects. Whenever
a layer is requested the factory name and any necessary arguments are passed to the factory object. The return value
is typically a type but can be any callable producing a layer object.

The factory objects contain functions keyed to names converted to upper case; these names can be referred to as members
of the factory so that they function as constant identifiers, e.g. instance normalisation is named `Norm.INSTANCE`.

For example, to get a transpose convolution layer the name is needed and then a dimension argument is provided which is
passed to the factory function:

.. code-block:: python

    dimension = 3
    name = Conv.CONVTRANS
    conv = Conv[name, dimension]

This allows the `dimension` value to be set in the constructor, for example so that the dimensionality of a network is
parameterizable. Not all factories require arguments after the name; the caller must be aware which are required.

Defining new factories involves creating the object then associating it with factory functions:

.. code-block:: python

    fact = LayerFactory()

    @fact.factory_function('test')
    def make_something(x, y):
        # do something with x and y to choose which layer type to return
        return SomeLayerType

    ...

    # request object from factory TEST with 1 and 2 as values for x and y
    layer = fact[fact.TEST, 1, 2]

Typically the caller of a factory would know what arguments to pass (ie. the dimensionality of the requested type) but
can be parameterized with the factory name and the arguments to pass to the created type at instantiation time:

.. code-block:: python

    def use_factory(fact_args):
        fact_name, type_args = split_args(fact_args)
        layer_type = fact[fact_name, 1, 2]
        return layer_type(**type_args)

    ...

    kw_args = {'arg0': 0, 'arg1': True}
    layer = use_factory((fact.TEST, kw_args))
"""

from typing import Callable

import torch.nn as nn


class LayerFactory:
"""
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
callables. These functions are referred to by name and can be added at any time.
"""

def __init__(self):
self.factories = {}

@property
def names(self):
"""
Produces all factory names.
"""

return tuple(self.factories)

def add_factory_callable(self, name, func):
"""
Add the factory function to this object under the given name.
"""

self.factories[name.upper()] = func

def factory_function(self, name):
"""
Decorator for adding a factory function with the given name.
"""

def _add(func):
self.add_factory_callable(name, func)
return func

return _add

def get_constructor(self, factory_name, *args):
"""
Get the constructor for the given factory name and arguments.
"""

if not isinstance(factory_name, str):
raise ValueError("Factories must be selected by name")

fact = self.factories[factory_name.upper()]
return fact(*args)

def __getitem__(self, args):
"""
Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor
itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments.
"""

# `args` is actually a type or constructor
if callable(args):
return args

# `args` is a factory name or a name with arguments
if isinstance(args, str):
name_obj, args = args, ()
else:
name_obj, *args = args

return self.get_constructor(name_obj, *args)

def __getattr__(self, key):
"""
If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names
as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
"""

if key in self.factories:
return key

return super().__getattribute__(key)
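
# A hedged sketch (not part of this commit) of defining and using a new factory;
# the `Pad` factory and its "constant" function are illustrative names.
Pad = LayerFactory()

@Pad.factory_function("constant")
def constant_pad_factory(dim):
    types = [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]
    return types[dim - 1]

# factory names behave as constants via __getattr__: Pad.CONSTANT == "CONSTANT"
pad_type = Pad[Pad.CONSTANT, 2]             # -> nn.ConstantPad2d
pad_layer = pad_type(padding=1, value=0.0)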


def split_args(args):
"""
Split arguments in a way to be suitable for using with the factory types. If `args` is a name it's interpreted
as a factory name with no arguments; otherwise it must be a pair of the name (or a callable type) and an argument dictionary.
"""

if isinstance(args, str):
return args, {}
else:
name_obj, args = args

if not isinstance(name_obj, (str, Callable)) or not isinstance(args, dict):
msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)"
raise ValueError(msg)

return name_obj, args
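
# A hedged sketch (not part of this commit): `split_args` accepts a bare name or a
# (name, kwargs) pair; resolution is shown with the Dropout factory defined below.
name, args = split_args(("dropout", {"p": 0.25}))   # -> ("dropout", {"p": 0.25})
drop_layer = Dropout[name, 3](**args)               # -> nn.Dropout3d(p=0.25)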


# Define factories for these layer types

Dropout = LayerFactory()
Norm = LayerFactory()
Act = LayerFactory()
Conv = LayerFactory()
Pool = LayerFactory()


@Dropout.factory_function("dropout")
def dropout_factory(dim):
types = [nn.Dropout, nn.Dropout2d, nn.Dropout3d]
return types[dim - 1]


@Norm.factory_function("instance")
def instance_factory(dim):
types = [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d]
return types[dim - 1]


@Norm.factory_function("batch")
def batch_factory(dim):
types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
return types[dim - 1]


Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)


@Conv.factory_function("conv")
def conv_factory(dim):
types = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
return types[dim - 1]


@Conv.factory_function("convtrans")
def convtrans_factory(dim):
types = [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]
return types[dim - 1]


@Pool.factory_function("max")
def maxpooling_factory(dim):
types = [nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]
return types[dim - 1]


@Pool.factory_function("adaptivemax")
def adaptive_maxpooling_factory(dim):
types = [nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d]
return types[dim - 1]


@Pool.factory_function("avg")
def avgpooling_factory(dim):
types = [nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]
return types[dim - 1]


@Pool.factory_function("adaptiveavg")
def adaptive_avgpooling_factory(dim):
types = [nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]
return types[dim - 1]
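
# A hedged usage sketch (not part of this commit): resolving concrete PyTorch types
# from the factories defined above for a 2D network.
conv_type = Conv[Conv.CONV, 2]           # -> nn.Conv2d
pool_type = Pool[Pool.ADAPTIVEMAX, 2]    # -> nn.AdaptiveMaxPool2d
norm_type = Norm[Norm.INSTANCE, 2]       # -> nn.InstanceNorm2d
act_type = Act[Act.RELU]                 # -> nn.ReLU (no dimension argument needed)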