From b2caa377b27be1b49f13696a351a07d8c08fce39 Mon Sep 17 00:00:00 2001
From: ChairC <974833488@qq.com>
Date: Tue, 5 Dec 2023 23:24:16 +0800
Subject: [PATCH] Update: Reconstruct modules.py, and update package path.

---
 README.md                     |   6 +
 README_zh.md                  |   6 +
 model/modules/activation.py   |  36 ++++
 model/modules/attention.py    |  53 ++++++
 model/modules/block.py        | 131 ++++++++++++++
 model/modules/conv.py         |  98 ++++++++++
 model/modules/ema.py          |  69 +++++++
 model/modules/module.py       | 331 +---------------------------------
 model/networks/cspdarkunet.py |   4 +-
 model/networks/unet.py        |   4 +-
 tools/train.py                |   2 +-
 11 files changed, 408 insertions(+), 332 deletions(-)
 create mode 100644 model/modules/activation.py
 create mode 100644 model/modules/attention.py
 create mode 100644 model/modules/block.py
 create mode 100644 model/modules/conv.py
 create mode 100644 model/modules/ema.py

diff --git a/README.md b/README.md
index c988304..81284a6 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
 **Repository Structure**
 
 ```yaml
+Industrial Defect Diffusion Model
 ├── datasets
 │   └── dataset_demo
 │       ├── class_1
@@ -18,6 +19,11 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
 │       └── class_3
 ├── model
 │   ├── modules
+│   │   ├── activation.py
+│   │   ├── attention.py
+│   │   ├── block.py
+│   │   ├── conv.py
+│   │   ├── ema.py
 │   │   └── module.py
 │   ├── networks
 │   │   ├── base.py
diff --git a/README_zh.md b/README_zh.md
index c96da69..c479ea3 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -11,6 +11,7 @@
 **本仓库整体结构**
 
 ```yaml
+Industrial Defect Diffusion Model
 ├── datasets
 │   └── dataset_demo
 │       ├── class_1
@@ -18,6 +19,11 @@
 │       └── class_3
 ├── model
 │   ├── modules
+│   │   ├── activation.py
+│   │   ├── attention.py
+│   │   ├── block.py
+│   │   ├── conv.py
+│   │   ├── ema.py
 │   │   └── module.py
 │   ├── networks
 │   │   ├── base.py
diff --git a/model/modules/activation.py b/model/modules/activation.py
new file mode 100644
index 0000000..f844394
--- /dev/null
+++ b/model/modules/activation.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2023/12/5 10:19
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+import logging
+import coloredlogs
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+coloredlogs.install(level="INFO")
+
+
+def get_activation_function(name="silu", inplace=False):
+    """
+    Get activation function
+    :param name: Activation function name
+    :param inplace: Whether to perform the operation in-place
+    :return: Activation function
+    """
+    if name == "relu":
+        act = nn.ReLU(inplace=inplace)
+    elif name == "relu6":
+        act = nn.ReLU6(inplace=inplace)
+    elif name == "silu":
+        act = nn.SiLU(inplace=inplace)
+    elif name == "lrelu":
+        act = nn.LeakyReLU(0.1, inplace=inplace)
+    elif name == "gelu":
+        act = nn.GELU()
+    else:
+        logger.warning(msg=f"Unsupported activation function type: {name}")
+        act = nn.SiLU(inplace=inplace)
+    return act
diff --git a/model/modules/attention.py b/model/modules/attention.py
new file mode 100644
index 0000000..9d4e230
--- /dev/null
+++ b/model/modules/attention.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2023/12/5 10:19
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+import torch.nn as nn
+from model.modules.activation import get_activation_function
+
+
+class SelfAttention(nn.Module):
+    """
+    SelfAttention block
+    """
+
+    def __init__(self, channels, size, act="silu"):
+        """
+        Initialize the self-attention block
+        :param channels: Number of input channels
+        :param size: Spatial size of the feature map
+        :param act: Activation function
+        """
+        super(SelfAttention, self).__init__()
+        self.channels = channels
+        self.size = size
+        # batch_first is not supported in PyTorch 1.8.
+        # To use it, upgrade to 1.9 or above; otherwise, transpose the tensor dimensions manually
+        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
+        self.ln = nn.LayerNorm(normalized_shape=[channels])
+        self.ff_self = nn.Sequential(
+            nn.LayerNorm(normalized_shape=[channels]),
+            nn.Linear(in_features=channels, out_features=channels),
+            get_activation_function(name=act),
+            nn.Linear(in_features=channels, out_features=channels),
+        )
+
+    def forward(self, x):
+        """
+        SelfAttention forward
+        :param x: Input
+        :return: attention_value
+        """
+        # First perform the shape transformation, then use 'swapaxes' to exchange the first
+        # and second dimensions of the new tensor
+        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
+        x_ln = self.ln(x)
+        # batch_first is not supported in PyTorch 1.8.
+        # To use it, upgrade to 1.9 or above; otherwise, transpose the tensor dimensions manually
+        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+        attention_value = attention_value + x
+        attention_value = self.ff_self(attention_value) + attention_value
+        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
diff --git a/model/modules/block.py b/model/modules/block.py
new file mode 100644
index 0000000..1b1ad84
--- /dev/null
+++ b/model/modules/block.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2023/12/5 10:21
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+import torch
+import torch.nn as nn
+
+from model.modules.conv import BaseConv, DoubleConv
+from model.modules.module import CSPLayer
+
+
+class DownBlock(nn.Module):
+    """
+    Downsample block
+    """
+
+    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
+        """
+        Initialize the downsample block
+        :param in_channels: Input channels
+        :param out_channels: Output channels
+        :param emb_channels: Embed channels
+        :param act: Activation function
+        """
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
+            DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(in_features=emb_channels, out_features=out_channels),
+        )
+
+    def forward(self, x, time):
+        """
+        DownBlock forward
+        :param x: Input
+        :param time: Time
+        :return: x + emb
+        """
+        x = self.maxpool_conv(x)
+        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+class UpBlock(nn.Module):
+    """
+    Upsample block
+    """
+
+    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
+        """
+        Initialize the upsample block
+        :param in_channels: Input channels
+        :param out_channels: Output channels
+        :param emb_channels: Embed channels
+        :param act: Activation function
+        """
+        super().__init__()
+
+        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv = nn.Sequential(
+            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
+            DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(in_features=emb_channels, out_features=out_channels),
+        )
+
+    def forward(self, x, skip_x, time):
+        """
+        UpBlock forward
+        :param x: Input
+        :param skip_x: Merged input
+        :param time: Time
+        :return: x + emb
+        """
+        x = self.up(x)
+        x = torch.cat([skip_x, x], dim=1)
+        x = self.conv(x)
+        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+class CSPDarkDownBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
+        super().__init__()
+        self.conv_csp = nn.Sequential(
+            BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act),
+            CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act)
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(in_features=emb_channels, out_features=out_channels),
+        )
+
+    def forward(self, x, time):
+        x = self.conv_csp(x)
+        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+class CSPDarkUpBlock(nn.Module):
+
+    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
+        super().__init__()
+        self.up = nn.Upsample(scale_factor=2, mode="nearest")
+        self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act)
+        self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act)
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(in_features=emb_channels, out_features=out_channels),
+        )
+
+    def forward(self, x, skip_x, time):
+        x = self.conv(x)
+        x = self.up(x)
+        x = torch.cat([skip_x, x], dim=1)
+        x = self.csp(x)
+        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
diff --git a/model/modules/conv.py b/model/modules/conv.py
new file mode 100644
index 0000000..6114bf0
--- /dev/null
+++ b/model/modules/conv.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2023/12/5 10:22
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+import logging
+import coloredlogs
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.modules.activation import get_activation_function
+
+logger = logging.getLogger(__name__)
+coloredlogs.install(level="INFO")
+
+
+class DoubleConv(nn.Module):
+    """
+    Double convolution
+    """
+
+    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"):
+        """
+        Initialize the double convolution block
+        :param in_channels: Input channels
+        :param out_channels: Output channels
+        :param mid_channels: Middle channels
+        :param residual: Whether to use a residual connection
+        :param act: Activation function
+        """
+        super().__init__()
+        self.residual = residual
+        if not mid_channels:
+            mid_channels = out_channels
+        self.act = act
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(num_groups=1, num_channels=mid_channels),
+            get_activation_function(name=self.act),
+            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(num_groups=1, num_channels=out_channels),
+        )
+
+    def forward(self, x):
+        """
+        DoubleConv forward
+        :param x: Input
+        :return: Residual or non-residual results
+        """
+        if self.residual:
+            out = x + self.double_conv(x)
+            if self.act == "relu":
+                return F.relu(out)
+            elif self.act == "relu6":
+                return F.relu6(out)
+            elif self.act == "silu":
+                return F.silu(out)
+            elif self.act == "lrelu":
+                return F.leaky_relu(out)
+            elif self.act == "gelu":
+                return F.gelu(out)
+            else:
+                logger.warning(msg=f"Unsupported activation function type: {self.act}")
+                return F.silu(out)
+        else:
+            return self.double_conv(x)
+
+
+class BaseConv(nn.Module):
+    """
+    Base convolution
+    Conv2d -> GroupNorm -> Activation function block
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, act="silu"):
+        """
+        Initialize the Base convolution
+        :param in_channels: Input channels
+        :param out_channels: Output channels
+        :param kernel_size: Kernel size
+        :param stride: Stride
+        :param groups: Groups
+        :param bias: Bias
+        :param act: Activation function
+        """
+        super().__init__()
+        # Same padding
+        pad = (kernel_size - 1) // 2
+        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
+                              stride=stride, padding=pad, groups=groups, bias=bias)
+        self.gn = nn.GroupNorm(num_groups=1, num_channels=out_channels)
+        self.act = get_activation_function(name=act, inplace=True)
+
+    def forward(self, x):
+        return self.act(self.gn(self.conv(x)))
diff --git a/model/modules/ema.py b/model/modules/ema.py
new file mode 100644
index 0000000..f04fdba
--- /dev/null
+++ b/model/modules/ema.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+    @Date   : 2023/12/5 10:18
+    @Author : chairc
+    @Site   : https://github.com/chairc
+"""
+
+
+class EMA:
+    """
+    Exponential Moving Average
+    """
+
+    def __init__(self, beta):
+        """
+        Initialize EMA
+        :param beta: Decay rate β
+        """
+        super().__init__()
+        self.beta = beta
+        self.step = 0
+
+    def update_model_average(self, ema_model, current_model):
+        """
+        Update model average
+        :param ema_model: EMA model
+        :param current_model: Current model
+        :return: None
+        """
+        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
+            old_weight, up_weight = ema_params, current_params.data
+            ema_params.data = self.update_average(old_weight, up_weight)
+
+    def update_average(self, old_weight, new_weight):
+        """
+        Update average
+        :param old_weight: Old weight
+        :param new_weight: New weight
+        :return: new_weight or old_weight * self.beta + (1 - self.beta) * new_weight
+        """
+        if old_weight is None:
+            return new_weight
+        return old_weight * self.beta + (1 - self.beta) * new_weight
+
+    def step_ema(self, ema_model, model, step_start_ema=2000):
+        """
+        EMA step
+        :param ema_model: EMA model
+        :param model: Original model
+        :param step_start_ema: Start EMA step
+        :return: None
+        """
+        if self.step < step_start_ema:
+            self.reset_parameters(ema_model, model)
+            self.step += 1
+            return
+        self.update_model_average(ema_model, model)
+        self.step += 1
+
+    @staticmethod
+    def reset_parameters(ema_model, model):
+        """
+        Reset parameters
+        :param ema_model: EMA model
+        :param model: Original model
+        :return: None
+        """
+        ema_model.load_state_dict(model.state_dict())
diff --git a/model/modules/module.py b/model/modules/module.py
index b073518..024b2b7 100644
--- a/model/modules/module.py
+++ b/model/modules/module.py
@@ -10,299 +10,13 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+
+from model.modules.conv import BaseConv
 
 logger = logging.getLogger(__name__)
 coloredlogs.install(level="INFO")
 
 
-def get_activation_function(name="silu", inplace=False):
-    """
-    Get activation function
-    :param name: Activation function name
-    :param inplace: can optionally do the operation in-place
-    :return Activation function
-    """
-    if name == "relu":
-        act = nn.ReLU(inplace=inplace)
-    elif name == "relu6":
-        act = nn.ReLU6(inplace=inplace)
-    elif name == "silu":
-        act = nn.SiLU(inplace=inplace)
-    elif name == "lrelu":
-        act = nn.LeakyReLU(0.1, inplace=inplace)
-    elif name == "gelu":
-        act = nn.GELU()
-    else:
-        logger.warning(msg=f"Unsupported activation function type: {name}")
-        act = nn.SiLU(inplace=inplace)
-    return act
-
-
-class EMA:
-    """
-    Exponential Moving Average
-    """
-
-    def __init__(self, beta):
-        """
-        Initialize EMA
-        :param beta: β
-        """
-        super().__init__()
-        self.beta = beta
-        self.step = 0
-
-    def update_model_average(self, ema_model, current_model):
-        """
-        Update model average
-        :param ema_model: EMA model
-        :param current_model: Current model
-        :return: None
-        """
-        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
-            old_weight, up_weight = ema_params, current_params.data
-            ema_params.data = self.update_average(old_weight, up_weight)
-
-    def update_average(self, old_weight, new_weight):
-        """
-        Update average
-        :param old_weight: Old weight
-        :param new_weight: New weight
-        :return: new_weight or old_weight * self.beta + (1 - self.beta) * new_weight
-        """
-        if old_weight is None:
-            return new_weight
-        return old_weight * self.beta + (1 - self.beta) * new_weight
-
-    def step_ema(self, ema_model, model, step_start_ema=2000):
-        """
-        EMA step
-        :param ema_model: EMA model
-        :param model: Original model
-        :param step_start_ema: Start EMA step
-        :return: None
-        """
-        if self.step < step_start_ema:
-            self.reset_parameters(ema_model, model)
-            self.step += 1
-            return
-        self.update_model_average(ema_model, model)
-        self.step += 1
-
-    def reset_parameters(self, ema_model, model):
-        """
-        Reset parameters
-        :param ema_model: EMA model
-        :param model: Original model
-        :return: None
-        """
-        ema_model.load_state_dict(model.state_dict())
-
-
-class SelfAttention(nn.Module):
-    """
-    SelfAttention block
-    """
-
-    def __init__(self, channels, size, act="silu"):
-        """
-        Initialize the self-attention block
-        :param channels: Channels
-        :param size: Size
-        :param act: Activation function
-        """
-        super(SelfAttention, self).__init__()
-        self.channels = channels
-        self.size = size
-        # batch_first is not supported in pytorch 1.8.
-        # If you want to support upgrading to 1.9 and above, or use the following code to transpose
-        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
-        self.ln = nn.LayerNorm(normalized_shape=[channels])
-        self.ff_self = nn.Sequential(
-            nn.LayerNorm(normalized_shape=[channels]),
-            nn.Linear(in_features=channels, out_features=channels),
-            get_activation_function(name=act),
-            nn.Linear(in_features=channels, out_features=channels),
-        )
-
-    def forward(self, x):
-        """
-        SelfAttention forward
-        :param x: Input
-        :return: attention_value
-        """
-        # First perform the shape transformation, and then use 'swapaxes' to exchange the first
-        # second dimensions of the new tensor
-        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
-        x_ln = self.ln(x)
-        # batch_first is not supported in pytorch 1.8.
-        # If you want to support upgrading to 1.9 and above, or use the following code to transpose
-        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
-        attention_value = attention_value + x
-        attention_value = self.ff_self(attention_value) + attention_value
-        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
-
-
-class DoubleConv(nn.Module):
-    """
-    Double convolution
-    """
-
-    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"):
-        """
-        Initialize the double convolution block
-        :param in_channels: Input channels
-        :param out_channels: Output channels
-        :param mid_channels: Middle channels
-        :param residual: Whether residual
-        :param act: Activation function
-        """
-        super().__init__()
-        self.residual = residual
-        if not mid_channels:
-            mid_channels = out_channels
-        self.act = act
-        self.double_conv = nn.Sequential(
-            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False),
-            nn.GroupNorm(num_groups=1, num_channels=mid_channels),
-            get_activation_function(name=self.act),
-            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
-            nn.GroupNorm(num_groups=1, num_channels=out_channels),
-        )
-
-    def forward(self, x):
-        """
-        DoubleConv forward
-        :param x: Input
-        :return: Residual or non-residual results
-        """
-        if self.residual:
-            out = x + self.double_conv(x)
-            if self.act == "relu":
-                return F.relu(out)
-            elif self.act == "relu6":
-                return F.relu6(out)
-            elif self.act == "silu":
-                return F.silu(out)
-            elif self.act == "lrelu":
-                return F.leaky_relu(out)
-            elif self.act == "gelu":
-                return F.gelu(out)
-            else:
-                logger.warning(msg=f"Unsupported activation function type: {self.act}")
-                return F.silu(out)
-        else:
-            return self.double_conv(x)
-
-
-class DownBlock(nn.Module):
-    """
-    Downsample block
-    """
-
-    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
-        """
-        Initialize the downsample block
-        :param in_channels: Input channels
-        :param out_channels: Output channels
-        :param emb_channels: Embed channels
-        :param act: Activation function
-        """
-        super().__init__()
-        self.maxpool_conv = nn.Sequential(
-            nn.MaxPool2d(2),
-            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
-            DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
-        )
-
-        self.emb_layer = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(in_features=emb_channels, out_features=out_channels),
-        )
-
-    def forward(self, x, time):
-        """
-        DownBlock forward
-        :param x: Input
-        :param time: Time
-        :return: x + emb
-        """
-        x = self.maxpool_conv(x)
-        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
-        return x + emb
-
-
-class UpBlock(nn.Module):
-    """
-    Upsample Block
-    """
-
-    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
-        """
-        Initialize the upsample block
-        :param in_channels: Input channels
-        :param out_channels: Output channels
-        :param emb_channels: Embed channels
-        :param act: Activation function
-        """
-        super().__init__()
-
-        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-        self.conv = nn.Sequential(
-            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
-            DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
-        )
-
-        self.emb_layer = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(in_features=emb_channels, out_features=out_channels),
-        )
-
-    def forward(self, x, skip_x, time):
-        """
-        UpBlock forward
-        :param x: Input
-        :param skip_x: Merged input
-        :param time: Time
-        :return: x + emb
-        """
-        x = self.up(x)
-        x = torch.cat([skip_x, x], dim=1)
-        x = self.conv(x)
-        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
-        return x + emb
-
-
-class BaseConv(nn.Module):
-    """
-    Base convolution
-    Conv2d -> BatchNorm -> Activation function block
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, act="silu"):
-        """
-        Initialize the Base convolution
-        :param in_channels: Input channels
-        :param out_channels: Output channels
-        :param kernel_size: Kernel size
-        :param stride: Stride
-        :param groups: Groups
-        :param bias: Bias
-        :param act: Activation function
-        """
-        super().__init__()
-        # Same padding
-        pad = (kernel_size - 1) // 2
-        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
-                              stride=stride, padding=pad, groups=groups, bias=bias)
-        self.gn = nn.GroupNorm(num_groups=1, num_channels=out_channels)
-        self.act = get_activation_function(name=act, inplace=True)
-
-    def forward(self, x):
-        return self.act(self.gn(self.conv(x)))
-
-
 class Bottleneck(nn.Module):
     """
     Standard bottleneck
@@ -394,44 +108,3 @@ def forward(self, x):
         x_1 = self.m(x_1)
         x = torch.cat([x_1, x_2], dim=1)
         return self.conv3(x)
-
-
-class CSPDarkDownBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
-        super().__init__()
-        self.conv_csp = nn.Sequential(
-            BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act),
-            CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act)
-        )
-
-        self.emb_layer = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(in_features=emb_channels, out_features=out_channels),
-        )
-
-    def forward(self, x, time):
-        x = self.conv_csp(x)
-        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
-        return x + emb
-
-
-class CSPDarkUpBlock(nn.Module):
-
-    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
-        super().__init__()
-        self.up = nn.Upsample(scale_factor=2, mode="nearest")
-        self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act)
-        self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act)
-
-        self.emb_layer = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(in_features=emb_channels, out_features=out_channels),
-        )
-
-    def forward(self, x, skip_x, time):
-        x = self.conv(x)
-        x = self.up(x)
-        x = torch.cat([skip_x, x], dim=1)
-        x = self.conv(x)
-        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
-        return x + emb
diff --git a/model/networks/cspdarkunet.py b/model/networks/cspdarkunet.py
index ec9b7cf..21f052e 100644
--- a/model/networks/cspdarkunet.py
+++ b/model/networks/cspdarkunet.py
@@ -9,7 +9,9 @@
 import torch.nn as nn
 
 from model.networks.base import BaseNet
-from model.modules.module import SelfAttention, CSPDarkUpBlock, CSPDarkDownBlock, BaseConv
+from model.modules.attention import SelfAttention
+from model.modules.block import CSPDarkDownBlock, CSPDarkUpBlock
+from model.modules.conv import BaseConv
 
 
 class CSPDarkUnet(BaseNet):
diff --git a/model/networks/unet.py b/model/networks/unet.py
index e0189b0..9bb99c8 100644
--- a/model/networks/unet.py
+++ b/model/networks/unet.py
@@ -9,7 +9,9 @@
 import torch.nn as nn
 
 from model.networks.base import BaseNet
-from model.modules.module import UpBlock, DownBlock, DoubleConv, SelfAttention
+from model.modules.attention import SelfAttention
+from model.modules.block import DownBlock, UpBlock
+from model.modules.conv import DoubleConv
 
 
 class UNet(BaseNet):
diff --git a/tools/train.py b/tools/train.py
index 34c7abe..ec61aea 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -22,7 +22,7 @@
 from tqdm import tqdm
 
 sys.path.append(os.path.dirname(sys.path[0]))
-from model.modules.module import EMA
+from model.modules.ema import EMA
 from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
     sample_initializer, lr_initializer, fp16_initializer
 from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging
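
A quick smoke test of the reorganized package is sketched below. It is not part of the patch: it assumes the repository root is on `PYTHONPATH` (the same assumption `tools/train.py` makes via `sys.path.append`), and the batch size, channel counts, and `beta` value are illustrative placeholders, not values taken from the repository.

```python
import torch

from model.modules.activation import get_activation_function
from model.modules.attention import SelfAttention
from model.modules.block import DownBlock, UpBlock
from model.modules.conv import DoubleConv
from model.modules.ema import EMA

# The factory falls back to SiLU (with a warning) for unknown names.
act = get_activation_function(name="gelu")

# Dummy batch: 4 images, 64 channels, 32x32, plus a 256-channel time embedding.
x = torch.randn(4, 64, 32, 32)
time = torch.randn(4, 256)

down = DownBlock(in_channels=64, out_channels=128)  # emb_channels defaults to 256
attn = SelfAttention(channels=128, size=16)         # 32x32 halves to 16x16 after MaxPool2d(2)
y = attn(down(x, time))
print(y.shape)  # torch.Size([4, 128, 16, 16])

# UpBlock concatenates the skip connection, so in_channels = 128 (upsampled) + 64 (skip).
up = UpBlock(in_channels=192, out_channels=64)
z = up(y, x, time)
print(z.shape)  # torch.Size([4, 64, 32, 32])

# EMA mirrors the weights verbatim until step_start_ema, then decays toward them.
model = DoubleConv(in_channels=3, out_channels=3)
ema_model = DoubleConv(in_channels=3, out_channels=3)
ema = EMA(beta=0.995)
ema.step_ema(ema_model=ema_model, model=model, step_start_ema=2000)
```

Because `SelfAttention` is constructed with `batch_first=True`, this snippet requires PyTorch 1.9 or later, as the comments in `attention.py` note.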