Merge pull request #18 from chairc/dev
Reconstruct modules.py, and update package path; Rewrite checkpoint storage and loading functions, add checkpoint files
Showing 15 changed files with 592 additions and 419 deletions.
New file, +36 lines: activation function module (model.modules.activation)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
    @Date   : 2023/12/5 10:19
    @Author : chairc
    @Site   : https://github.com/chairc
"""
import logging
import coloredlogs
import torch.nn as nn

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def get_activation_function(name="silu", inplace=False):
    """
    Get activation function
    :param name: Activation function name
    :param inplace: Can optionally do the operation in-place
    :return: Activation function
    """
    if name == "relu":
        act = nn.ReLU(inplace=inplace)
    elif name == "relu6":
        act = nn.ReLU6(inplace=inplace)
    elif name == "silu":
        act = nn.SiLU(inplace=inplace)
    elif name == "lrelu":
        act = nn.LeakyReLU(0.1, inplace=inplace)
    elif name == "gelu":
        act = nn.GELU()
    else:
        logger.warning(msg=f"Unsupported activation function type: {name}")
        act = nn.SiLU(inplace=inplace)
    return act
New file, +53 lines: self-attention module (SelfAttention)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
    @Date   : 2023/12/5 10:19
    @Author : chairc
    @Site   : https://github.com/chairc
"""
import torch.nn as nn

from model.modules.activation import get_activation_function


class SelfAttention(nn.Module):
    """
    SelfAttention block
    """

    def __init__(self, channels, size, act="silu"):
        """
        Initialize the self-attention block
        :param channels: Channels
        :param size: Size
        :param act: Activation function
        """
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        # batch_first is not supported in PyTorch 1.8; it requires PyTorch 1.9 or above.
        # To stay on 1.8, drop batch_first and transpose the tensors around the attention call instead.
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
        self.ln = nn.LayerNorm(normalized_shape=[channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm(normalized_shape=[channels]),
            nn.Linear(in_features=channels, out_features=channels),
            get_activation_function(name=act),
            nn.Linear(in_features=channels, out_features=channels),
        )

    def forward(self, x):
        """
        SelfAttention forward
        :param x: Input
        :return: attention_value
        """
        # First reshape to (batch, channels, size * size), then use 'swapaxes' to exchange
        # the first and second dimensions of the new tensor
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        # batch_first is not supported in PyTorch 1.8; on 1.8, transpose x_ln before this call
        # and the attention output after it instead.
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
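A minimal shape check for the block above; the channel count and feature-map size are arbitrary examples, and the import path model.modules.attention is an assumption that mirrors the import style used in these files.

import torch
from model.modules.attention import SelfAttention  # hypothetical module path

attn = SelfAttention(channels=128, size=16, act="silu")
x = torch.randn(4, 128, 16, 16)  # (batch, channels, size, size); channels must be divisible by num_heads=4
out = attn(x)
print(out.shape)  # torch.Size([4, 128, 16, 16]): the block preserves the input shape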
New file, +131 lines: up/downsample block module (DownBlock, UpBlock, CSPDarkDownBlock, CSPDarkUpBlock)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
    @Date   : 2023/12/5 10:21
    @Author : chairc
    @Site   : https://github.com/chairc
"""
import torch
import torch.nn as nn

from model.modules.conv import BaseConv, DoubleConv
from model.modules.module import CSPLayer


class DownBlock(nn.Module):
    """
    Downsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
        """
        Initialize the downsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embed channels
        :param act: Activation function
        """
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
            DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, time):
        """
        DownBlock forward
        :param x: Input
        :param time: Time
        :return: x + emb
        """
        x = self.maxpool_conv(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class UpBlock(nn.Module):
    """
    Upsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"):
        """
        Initialize the upsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embed channels
        :param act: Activation function
        """
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act),
            DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, skip_x, time):
        """
        UpBlock forward
        :param x: Input
        :param skip_x: Merged input
        :param time: Time
        :return: x + emb
        """
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class CSPDarkDownBlock(nn.Module):
    """
    CSPDarknet downsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
        """
        Initialize the CSPDarknet downsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embed channels
        :param n: Number of bottlenecks in the CSP layer
        :param act: Activation function
        """
        super().__init__()
        self.conv_csp = nn.Sequential(
            BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act),
            CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act)
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, time):
        """
        CSPDarkDownBlock forward
        :param x: Input
        :param time: Time
        :return: x + emb
        """
        x = self.conv_csp(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class CSPDarkUpBlock(nn.Module):
    """
    CSPDarknet upsample block
    """

    def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"):
        """
        Initialize the CSPDarknet upsample block
        :param in_channels: Input channels
        :param out_channels: Output channels
        :param emb_channels: Embed channels
        :param n: Number of bottlenecks in the CSP layer
        :param act: Activation function
        """
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act)
        self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act)

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=emb_channels, out_features=out_channels),
        )

    def forward(self, x, skip_x, time):
        """
        CSPDarkUpBlock forward
        :param x: Input
        :param skip_x: Merged input
        :param time: Time
        :return: x + emb
        """
        x = self.conv(x)
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        # Fuse the concatenated features with the CSP layer
        x = self.csp(x)
        emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
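A minimal sketch of how the plain UNet blocks above are wired together with a time embedding. The module path model.modules.block and all channel sizes are assumptions for illustration; the actual network definitions live in the other changed files of this commit.

import torch
from model.modules.block import DownBlock, UpBlock  # hypothetical module path

down = DownBlock(in_channels=64, out_channels=128, emb_channels=256)
up = UpBlock(in_channels=192, out_channels=64, emb_channels=256)  # 192 = 128 (upsampled) + 64 (skip)

x = torch.randn(2, 64, 32, 32)  # feature map
t = torch.randn(2, 256)         # time embedding, matching emb_channels
skip = x                        # skip connection from the encoder side

h = down(x, t)                  # -> (2, 128, 16, 16): max pooling halves the spatial size
y = up(h, skip, t)              # -> (2, 64, 32, 32): upsample back and fuse the skip connection
print(h.shape, y.shape)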