diff --git a/simulai/models/__init__.py b/simulai/models/__init__.py
index 28399c78..d1605ba8 100644
--- a/simulai/models/__init__.py
+++ b/simulai/models/__init__.py
@@ -26,6 +26,7 @@
     ImprovedDeepONet,
     ImprovedDenseNetwork,
     Transformer,
+    UNet,
     MetaModel,
     ModelMaker,
     MultiNetwork,
diff --git a/simulai/models/_pytorch_models/__init__.py b/simulai/models/_pytorch_models/__init__.py
index 9795c858..4ec59745 100644
--- a/simulai/models/_pytorch_models/__init__.py
+++ b/simulai/models/_pytorch_models/__init__.py
@@ -7,4 +7,5 @@
 )
 from ._deeponet import DeepONet, FlexibleDeepONet, ImprovedDeepONet, ResDeepONet
 from ._transformer import Transformer
+from ._unet import UNet
 from ._miscellaneous import ImprovedDenseNetwork, MetaModel, ModelMaker, MultiNetwork, MoEPool, SplitPool
diff --git a/simulai/models/_pytorch_models/_unet.py b/simulai/models/_pytorch_models/_unet.py
new file mode 100644
index 00000000..d6244b4d
--- /dev/null
+++ b/simulai/models/_pytorch_models/_unet.py
@@ -0,0 +1,184 @@
+import copy
+import numpy as np
+import torch
+from typing import Union, List, Tuple, Optional
+
+from simulai.templates import NetworkTemplate, as_tensor, channels_dim
+from simulai.regression import DenseNetwork, SLFNN, ConvolutionalNetwork
+
+# A CNN U-Net encoder or decoder is no more than a conventional CNN
+# in which intermediary outputs and inputs are also stored.
+class CNNUnetEncoder(ConvolutionalNetwork):
+    name = "convunetencoder"
+    engine = "torch"
+
+    def __init__(
+        self,
+        layers: list = None,
+        activations: list = None,
+        pre_layer: Optional[torch.nn.Module] = None,
+        case: str = "2d",
+        last_activation: str = "identity",
+        transpose: bool = False,
+        flatten: bool = False,
+        intermediary_outputs_indices: List[int] = None,
+        name: str = None,
+    ) -> None:
+
+        super(CNNUnetEncoder, self).__init__(
+            layers=layers,
+            activations=activations,
+            pre_layer=pre_layer,
+            case=case,
+            last_activation=last_activation,
+            transpose=transpose,
+            flatten=flatten,
+            name=name,
+        )
+
+        self.intermediary_outputs_indices = intermediary_outputs_indices
+
+        # Identity placeholders are dropped so that the pipeline indices
+        # line up with the user-provided intermediary_outputs_indices.
+        self.pipeline = torch.nn.Sequential(
+            *[
+                layer_j
+                for layer_j in self.list_of_layers
+                if not isinstance(layer_j, torch.nn.Identity)
+            ]
+        )
+
+    @as_tensor
+    @channels_dim
+    def forward(
+        self, input_data: Union[torch.Tensor, np.ndarray] = None
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+
+        intermediary_outputs = list()
+
+        # Each sliced sub-pipeline re-evaluates the network up to index j,
+        # producing the skip-connection tensors consumed by the decoder.
+        for j in self.intermediary_outputs_indices:
+            intermediary_outputs.append(self.pipeline[:j](input_data))
+
+        main_output = self.pipeline(input_data)
+
+        return main_output, intermediary_outputs
+
+
+class CNNUnetDecoder(ConvolutionalNetwork):
+    name = "convunetdecoder"
+    engine = "torch"
+
+    def __init__(
+        self,
+        layers: list = None,
+        activations: list = None,
+        pre_layer: Optional[torch.nn.Module] = None,
+        case: str = "2d",
+        last_activation: str = "identity",
+        transpose: bool = False,
+        flatten: bool = False,
+        intermediary_inputs_indices: List[int] = None,
+        name: str = None,
+        channels_last: bool = False,
+    ) -> None:
+
+        super(CNNUnetDecoder, self).__init__(
+            layers=layers,
+            activations=activations,
+            pre_layer=pre_layer,
+            case=case,
+            last_activation=last_activation,
+            transpose=transpose,
+            flatten=flatten,
+            name=name,
+        )
+
+        self.intermediary_inputs_indices = intermediary_inputs_indices
+
+        if channels_last:
+            self.concat_axis = -1
+        else:
+            self.concat_axis = 1
+
+        self.list_of_layers = [
+            layer_j
+            for layer_j in self.list_of_layers
+            if not isinstance(layer_j, torch.nn.Identity)
+        ]
+        self.pipeline = torch.nn.Sequential(*self.list_of_layers)
+
+    # The as_tensor/channels_dim decorators are not applied here, since this
+    # forward also receives the list of encoder skip tensors as a second argument.
+    def forward(
+        self,
+        input_data: Union[torch.Tensor, np.ndarray] = None,
+        intermediary_encoder_outputs: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        current_input = input_data
+        # The encoder outputs are consumed in reverse order (deepest first).
+        intermediary_encoder_outputs = intermediary_encoder_outputs[::-1]
+
+        for j, layer_j in enumerate(self.list_of_layers):
+
+            if j in self.intermediary_inputs_indices:
+                i = self.intermediary_inputs_indices.index(j)
+
+                input_j = torch.cat(
+                    [current_input, intermediary_encoder_outputs[i]],
+                    dim=self.concat_axis,
+                )
+            else:
+                input_j = current_input
+
+            output_j = layer_j(input_j)
+            current_input = output_j
+
+        return current_input
+
+
+class UNet(NetworkTemplate):
+
+    def __init__(
+        self,
+        layers_config: dict = None,
+        intermediary_outputs_indices: List = None,
+        intermediary_inputs_indices: List = None,
+        encoder_extra_args: dict = dict(),
+        decoder_extra_args: dict = dict(),
+    ) -> None:
+
+        super(UNet, self).__init__()
+
+        self.layers_config = layers_config
+        self.intermediary_outputs_indices = intermediary_outputs_indices
+        self.intermediary_inputs_indices = intermediary_inputs_indices
+
+        self.layers_config_encoder = self.layers_config["encoder"]
+        self.layers_config_decoder = self.layers_config["decoder"]
+
+        self.encoder_activations = self.layers_config["encoder_activations"]
+        self.decoder_activations = self.layers_config["decoder_activations"]
+
+        self.encoder_horizontal_outputs = dict()
+
+        # Configuring the encoder
+        encoder_type = self.layers_config_encoder.get("type")
+
+        if encoder_type == "cnn":
+            self.encoder = CNNUnetEncoder(
+                layers=self.layers_config_encoder["architecture"],
+                activations=self.encoder_activations,
+                intermediary_outputs_indices=self.intermediary_outputs_indices,
+                case="2d",
+                name="encoder",
+                **encoder_extra_args,
+            )
+        else:
+            raise Exception(f"Option {encoder_type} is not available.")
+
+        # Configuring the decoder
+        decoder_type = self.layers_config_decoder.get("type")
+
+        if decoder_type == "cnn":
+            self.decoder = CNNUnetDecoder(
+                layers=self.layers_config_decoder["architecture"],
+                activations=self.decoder_activations,
+                intermediary_inputs_indices=self.intermediary_inputs_indices,
+                case="2d",
+                name="decoder",
+                **decoder_extra_args,
+            )
+        else:
+            raise Exception(f"Option {decoder_type} is not available.")
+
+        self.add_module("encoder", self.encoder)
+        self.add_module("decoder", self.decoder)
+
+    @as_tensor
+    def forward(
+        self, input_data: Union[torch.Tensor, np.ndarray] = None
+    ) -> torch.Tensor:
+
+        encoder_main_output, encoder_intermediary_outputs = self.encoder(
+            input_data=input_data
+        )
+        output = self.decoder(
+            input_data=encoder_main_output,
+            intermediary_encoder_outputs=encoder_intermediary_outputs,
+        )
+
+        return output
+
+    def summary(self):
+        print(self)
diff --git a/simulai/templates/_pytorch_network.py b/simulai/templates/_pytorch_network.py
index f7868675..3c92600b 100644
--- a/simulai/templates/_pytorch_network.py
+++ b/simulai/templates/_pytorch_network.py
@@ -159,9 +159,9 @@ def _setup_activations(
     ):
         activations_list = list()
         for activation_name in activation:
-            activation_op = self._get_operation(operation=activation_name)
+            activation_op = self._get_operation(operation=activation_name, is_activation=True)
 
-            activations_list.append(activation_op())
+            activations_list.append(activation_op)
 
         return activations_list, activation
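Reviewer note (not part of the diff): the skip connections in CNNUnetEncoder hinge on the fact that slicing a torch.nn.Sequential yields another callable Sequential, so pipeline[:j](x) re-evaluates the first j modules and returns the intermediate tensor that CNNUnetDecoder later concatenates. A minimal, self-contained sketch of that mechanism, with illustrative layer sizes that are not taken from the diff:

import torch

# A toy two-stage pipeline, analogous to CNNUnetEncoder.pipeline.
pipeline = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)

x = torch.rand(1, 3, 16, 16)

# Slicing a Sequential returns another Sequential, so this runs only
# the first two modules and yields the skip-connection tensor.
skip = pipeline[:2](x)   # shape: (1, 8, 16, 16)
out = pipeline(x)        # shape: (1, 16, 16, 16)

One design observation: each slice re-runs the network from the input, so an encoder with many tap points repeats work; caching activations with forward hooks would avoid that, at the cost of extra bookkeeping.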
diff --git a/tests/network/test_unet.py b/tests/network/test_unet.py
new file mode 100644
index 00000000..df5955f8
--- /dev/null
+++ b/tests/network/test_unet.py
@@ -0,0 +1,307 @@
+# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from unittest import TestCase
+
+import numpy as np
+from tests.config import configure_dtype
+torch = configure_dtype()
+
+from utils import configure_device
+
+from simulai import ARRAY_DTYPE
+from simulai.file import SPFile
+from simulai.optimization import Optimizer
+
+DEVICE = configure_device()
+
+
+def generate_data(
+    n_samples: int = None,
+    image_size: tuple = None,
+    n_inputs: int = None,
+    n_outputs: int = None,
+) -> (torch.Tensor, torch.Tensor):
+
+    input_data = np.random.rand(n_samples, n_inputs, *image_size)
+    output_data = np.random.rand(n_samples, n_outputs, *image_size)
+
+    return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
+        output_data.astype(ARRAY_DTYPE)
+    )
+
+
+# Model template
+def model_2d():
+    from simulai.models import UNet
+
+    # Configuring the model
+    n_inputs = 3
+    n_outputs = 1
+    n_ch_0 = 2
+
+    layers = {
+        "encoder": {
+            "type": "cnn",
+            "architecture": [
+                # Stage 1: n_inputs -> n_ch_0, then downsample
+                {"in_channels": n_inputs, "out_channels": n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": n_ch_0, "out_channels": n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "after_conv": {"type": "maxpool2d", "kernel_size": 2, "stride": 2}},
+                # Stage 2: n_ch_0 -> 2*n_ch_0, then downsample
+                {"in_channels": n_ch_0, "out_channels": 2 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 2 * n_ch_0, "out_channels": 2 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "after_conv": {"type": "maxpool2d", "kernel_size": 2, "stride": 2}},
+                # Stage 3: 2*n_ch_0 -> 4*n_ch_0, then downsample
+                {"in_channels": 2 * n_ch_0, "out_channels": 4 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 4 * n_ch_0, "out_channels": 4 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "after_conv": {"type": "maxpool2d", "kernel_size": 2, "stride": 2}},
+                # Stage 4: 4*n_ch_0 -> 8*n_ch_0, then downsample
+                {"in_channels": 4 * n_ch_0, "out_channels": 8 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 8 * n_ch_0, "out_channels": 8 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "after_conv": {"type": "maxpool2d", "kernel_size": 2, "stride": 2}},
+                # Bottleneck: 8*n_ch_0 -> 16*n_ch_0
+                {"in_channels": 8 * n_ch_0, "out_channels": 16 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+            ],
+        },
+        "decoder": {
+            "type": "cnn",
+            "architecture": [
+                {"in_channels": 16 * n_ch_0, "out_channels": 16 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                # Upsample, then halve the channels; the doubled in_channels of
+                # the following layer accounts for the concatenated encoder skip.
+                {"in_channels": 16 * n_ch_0, "out_channels": 8 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "before_conv": {"type": "upsample", "scale_factor": 2, "mode": "bicubic"}},
+                {"in_channels": 16 * n_ch_0, "out_channels": 8 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 8 * n_ch_0, "out_channels": 8 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 8 * n_ch_0, "out_channels": 4 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "before_conv": {"type": "upsample", "scale_factor": 2, "mode": "bicubic"}},
+                {"in_channels": 8 * n_ch_0, "out_channels": 4 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 4 * n_ch_0, "out_channels": 4 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 4 * n_ch_0, "out_channels": 2 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "before_conv": {"type": "upsample", "scale_factor": 2, "mode": "bicubic"}},
+                {"in_channels": 4 * n_ch_0, "out_channels": 2 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 2 * n_ch_0, "out_channels": 2 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 2 * n_ch_0, "out_channels": 1 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1,
+                 "before_conv": {"type": "upsample", "scale_factor": 2, "mode": "bicubic"}},
+                {"in_channels": 2 * n_ch_0, "out_channels": 1 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 1 * n_ch_0, "out_channels": 1 * n_ch_0,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+                {"in_channels": 1 * n_ch_0, "out_channels": n_outputs,
+                 "kernel_size": 3, "stride": 1, "padding": 1},
+            ],
+        },
+        "encoder_activations": [
+            "relu", "relu", "relu", "relu", "relu",
+            "relu", "relu", "relu", "relu", "relu",
+            "relu",
+        ],
+        "decoder_activations": [
+            "relu", "identity", "relu", "relu", "identity",
+            "relu", "relu", "identity", "relu", "relu",
+            "identity", "relu", "relu", "identity",
+        ],
+    }
+
+    unet = UNet(
+        layers_config=layers,
+        intermediary_outputs_indices=[4, 9, 14, 19],
+        intermediary_inputs_indices=[4, 10, 16, 22],
+    )
+
+    return unet
+
+
+class TestUNet2D(TestCase):
+    def setUp(self) -> None:
+        pass
+
+    def test_unet_2d_n_parameters(self):
+        unet = model_2d()
+
+        assert type(unet.n_parameters) == int
+
+    def test_unet_2d_eval(self):
+        input_data, output_data = generate_data(
+            n_samples=100, image_size=(16, 16), n_inputs=3, n_outputs=1
+        )
+
+        unet = model_2d()
+        unet.summary()
+
+        estimated_output_data = unet.eval(input_data=input_data)
+
+        assert estimated_output_data.shape == output_data.shape, (
+            "The output of eval is not correct."
+            f" Expected {output_data.shape},"
+            f" but received {estimated_output_data.shape}."
+        )
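For completeness, a minimal usage sketch of the new model outside the test harness. It assumes the model_2d() builder from the test file above is in scope; the batch size and the 16x16 resolution are arbitrary choices (any spatial size divisible by 2**4 = 16 fits the four pooling/upsampling stages configured above):

import torch

unet = model_2d()  # builder defined in tests/network/test_unet.py

# (batch, channels, height, width); channels must match n_inputs = 3.
x = torch.rand(10, 3, 16, 16)
y = unet.eval(input_data=x)

print(y.shape)  # expected: (10, 1, 16, 16)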