Update: Reconstruct modules.py, and update package path.
chairc committed Dec 5, 2023
1 parent 387545f commit b2caa37
Showing 11 changed files with 408 additions and 332 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -11,13 +11,19 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod
**Repository Structure**

```yaml
Industrial Defect Diffusion Model
├── datasets
│   └── dataset_demo
│       ├── class_1
│       ├── class_2
│       └── class_3
├── model
│   ├── modules
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── conv.py
│   │   ├── ema.py
│   │   └── module.py
│   ├── networks
│   │   ├── base.py
6 changes: 6 additions & 0 deletions README_zh.md
@@ -11,13 +11,19 @@
**Repository Structure**

```yaml
Industrial Defect Diffusion Model
├── datasets
│   └── dataset_demo
│       ├── class_1
│       ├── class_2
│       └── class_3
├── model
│   ├── modules
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── conv.py
│   │   ├── ema.py
│   │   └── module.py
│   ├── networks
│   │   ├── base.py
36 changes: 36 additions & 0 deletions model/modules/activation.py
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:19
@Author : chairc
@Site : https://github.com/chairc
"""
import logging
import coloredlogs
import torch.nn as nn

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def get_activation_function(name="silu", inplace=False):
    """
    Get activation function
    :param name: Activation function name
    :param inplace: Whether to perform the operation in-place
    :return: Activation function
    """
    if name == "relu":
        act = nn.ReLU(inplace=inplace)
    elif name == "relu6":
        act = nn.ReLU6(inplace=inplace)
    elif name == "silu":
        act = nn.SiLU(inplace=inplace)
    elif name == "lrelu":
        act = nn.LeakyReLU(0.1, inplace=inplace)
    elif name == "gelu":
        act = nn.GELU()
    else:
        logger.warning(msg=f"Unsupported activation function type: {name}")
        act = nn.SiLU(inplace=inplace)
    return act
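
A minimal usage sketch for the helper above (not part of the commit; it assumes the repository root is on `PYTHONPATH` so that `model.modules.activation` is importable):

```python
# Usage sketch for get_activation_function (assumes repository root on PYTHONPATH)
import torch
from model.modules.activation import get_activation_function

act = get_activation_function(name="gelu")      # returns nn.GELU()
x = torch.randn(2, 8)
print(act(x).shape)                             # torch.Size([2, 8])

# Unknown names log a warning and fall back to SiLU
fallback = get_activation_function(name="mish")
print(type(fallback).__name__)                  # SiLU
```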
53 changes: 53 additions & 0 deletions model/modules/attention.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:19
@Author : chairc
@Site : https://github.com/chairc
"""
import torch.nn as nn
from model.modules.activation import get_activation_function


class SelfAttention(nn.Module):
    """
    SelfAttention block
    """

    def __init__(self, channels, size, act="silu"):
        """
        Initialize the self-attention block
        :param channels: Channels
        :param size: Size
        :param act: Activation function
        """
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        # batch_first is not supported in PyTorch 1.8; it requires PyTorch 1.9 or above.
        # On older versions, drop batch_first=True and transpose the tensors manually instead.
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
        self.ln = nn.LayerNorm(normalized_shape=[channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm(normalized_shape=[channels]),
            nn.Linear(in_features=channels, out_features=channels),
            get_activation_function(name=act),
            nn.Linear(in_features=channels, out_features=channels),
        )

    def forward(self, x):
        """
        SelfAttention forward
        :param x: Input
        :return: attention_value
        """
        # Flatten the spatial dimensions, then use 'swapaxes' to swap dimensions 1 and 2
        # of the new tensor, giving shape (batch, size * size, channels)
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        # On PyTorch 1.8 (no batch_first), transpose x_ln to (seq, batch, channels)
        # before this call and transpose the result back
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
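
A quick shape check for the block above (a sketch, not part of the commit; assumes PyTorch 1.9+ for `batch_first`, with illustrative channel and size values):

```python
# Shape check for SelfAttention (assumes PyTorch >= 1.9 for batch_first)
import torch
from model.modules.attention import SelfAttention

attn = SelfAttention(channels=64, size=16, act="silu")
x = torch.randn(4, 64, 16, 16)   # (batch, channels, size, size)
out = attn(x)
print(out.shape)                 # torch.Size([4, 64, 16, 16])
```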
131 changes: 131 additions & 0 deletions model/modules/block.py
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:21
@Author : chairc
@Site : https://github.com/chairc
"""
import torch
import torch.nn as nn

from model.modules.conv import BaseConv, DoubleConv
from model.modules.module import CSPLayer


class DownBlock(nn.Module):
    """
    Downsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
        """
        Initialize the downsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embedding channels
        :param act: Activation function
        """
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
            DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, time):
        """
        DownBlock forward
        :param x: Input
        :param time: Time embedding
        :return: x + emb
        """
        x = self.maxpool_conv(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class UpBlock(nn.Module):
    """
    Upsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
        """
        Initialize the upsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embedding channels
        :param act: Activation function
        """
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
            DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, skip_x, time):
        """
        UpBlock forward
        :param x: Input
        :param skip_x: Skip-connection input to merge
        :param time: Time embedding
        :return: x + emb
        """
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class CSPDarkDownBlock(nn.Module):
    """
    CSPDarknet downsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
        super().__init__()
        self.conv_csp = nn.Sequential(
            BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act),
            CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act)
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, time):
        x = self.conv_csp(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class CSPDarkUpBlock(nn.Module):
    """
    CSPDarknet upsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act)
        self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act)

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, skip_x, time):
        x = self.conv(x)
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        # Refine the concatenated features with the CSP layer
        x = self.csp(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
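
As a rough sketch of how the down/up pair composes (not part of the commit), the blocks can be exercised with a dummy 256-dimensional time embedding; the channel and spatial sizes below are illustrative only:

```python
# Down/up round trip with a dummy time embedding (illustrative sizes)
import torch
from model.modules.block import DownBlock, UpBlock

down = DownBlock(in_channels=64, out_channels=128)    # emb_channels defaults to 256
up = UpBlock(in_channels=128 + 64, out_channels=64)   # input channels include the skip connection

x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)                               # stand-in for the time embedding

d = down(x, t)                                        # (2, 128, 16, 16)
u = up(d, x, t)                                       # (2, 64, 32, 32)
print(d.shape, u.shape)
```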
98 changes: 98 additions & 0 deletions model/modules/conv.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2023/12/5 10:22
@Author : chairc
@Site : https://github.com/chairc
"""
import logging
import coloredlogs

import torch.nn as nn
import torch.nn.functional as F

from model.modules.activation import get_activation_function

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


class DoubleConv(nn.Module):
    """
    Double convolution
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"):
        """
        Initialize the double convolution block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param mid_channels: Middle channels
        :param residual: Whether residual
        :param act: Activation function
        """
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.act = act
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=mid_channels),
            get_activation_function(name=self.act),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=out_channels),
        )

    def forward(self, x):
        """
        DoubleConv forward
        :param x: Input
        :return: Residual or non-residual results
        """
        if self.residual:
            out = x + self.double_conv(x)
            if self.act == "relu":
                return F.relu(out)
            elif self.act == "relu6":
                return F.relu6(out)
            elif self.act == "silu":
                return F.silu(out)
            elif self.act == "lrelu":
                return F.leaky_relu(out)
            elif self.act == "gelu":
                return F.gelu(out)
            else:
                logger.warning(msg=f"Unsupported activation function type: {self.act}")
                return F.silu(out)
        else:
            return self.double_conv(x)


class BaseConv(nn.Module):
    """
    Base convolution
    Conv2d -> GroupNorm -> Activation function block
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, act="silu"):
        """
        Initialize the base convolution
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param kernel_size: Kernel size
        :param stride: Stride
        :param groups: Groups
        :param bias: Bias
        :param act: Activation function
        """
        super().__init__()
        # Same padding
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=pad, groups=groups, bias=bias)
        self.gn = nn.GroupNorm(num_groups=1, num_channels=out_channels)
        self.act = get_activation_function(name=act, inplace=True)

    def forward(self, x):
        """
        BaseConv forward
        :param x: Input
        :return: Activated, normalized convolution output
        """
        return self.act(self.gn(self.conv(x)))
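
And a quick shape check for the two convolution blocks (again a sketch under assumed sizes, not part of the commit):

```python
# Shape checks for BaseConv and DoubleConv
import torch
from model.modules.conv import BaseConv, DoubleConv

x = torch.randn(1, 32, 64, 64)

# 3x3 stride-2 convolution with "same" padding halves the spatial size
base = BaseConv(in_channels=32, out_channels=64, kernel_size=3, stride=2)
print(base(x).shape)        # torch.Size([1, 64, 32, 32])

# Residual DoubleConv keeps the shape and adds the input back
double = DoubleConv(in_channels=32, out_channels=32, residual=True, act="gelu")
print(double(x).shape)      # torch.Size([1, 32, 64, 64])
```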