Skip to content

Commit

Permalink
chore: support onnx export
Browse files Browse the repository at this point in the history
  • Loading branch information
wufei2 committed May 4, 2024
1 parent dcae54c commit 64b34ae
Showing 1 changed file with 59 additions and 19 deletions.
78 changes: 59 additions & 19 deletions yolo_world/models/layers/yolo_bricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,42 @@
from mmyolo.registry import MODELS
from mmyolo.models.layers import CSPLayerWithTwoConv

#AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom are compatible when exporting onnx format models
# reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
class AdaptiveAvgPool2dCustom(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool2dCustom, self).__init__()
self.output_size = torch.tensor(output_size)

def forward(self, x: torch.Tensor):
# Calculate the stride size required to achieve the desired output size
stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)

# Calculate the kernel size based on the stride size and desired output size
kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size

# Create a AvgPool2d layer with the calculated kernel and stride sizes
avg = nn.AvgPool2d(kernel_size.tolist(), stride=stride_size.tolist())

x = avg(x)
return x
class AdaptiveMaxPool2dCustom(nn.Module):
def __init__(self, output_size):
super(AdaptiveMaxPool2dCustom, self).__init__()
self.output_size = torch.tensor(output_size)

def forward(self, x: torch.Tensor):
# Calculate the stride size required to achieve the desired output size
stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)

# Calculate the kernel size based on the stride size and desired output size
kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size

# Create a MaxPool2d layer with the calculated kernel and stride sizes
max_pool = nn.MaxPool2d(kernel_size.tolist(), stride=stride_size.tolist())

x = max_pool(x)
return x

@MODELS.register_module()
class MaxSigmoidAttnBlock(BaseModule):
Expand All @@ -31,7 +67,7 @@ def __init__(self,
momentum=0.03,
eps=0.001),
init_cfg: OptMultiConfig = None,
use_einsum: bool = True) -> None:
export_onnx: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

Expand All @@ -40,7 +76,7 @@ def __init__(self,
'out_channels and embed_channels should be divisible by num_heads.'
self.num_heads = num_heads
self.head_channels = out_channels // num_heads
self.use_einsum = use_einsum
self.export_onnx = export_onnx

self.embed_conv = ConvModule(
in_channels,
Expand Down Expand Up @@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
embed = self.embed_conv(x) if self.embed_conv is not None else x
embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)

if self.use_einsum:
if self.export_onnx == False:
attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
else:
batch, m, channel, height, width = embed.shape
Expand Down Expand Up @@ -116,7 +151,7 @@ def __init__(self,
momentum=0.03,
eps=0.001),
init_cfg: OptMultiConfig = None,
use_einsum: bool = True) -> None:
export_onnx: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

Expand All @@ -125,7 +160,7 @@ def __init__(self,
'out_channels and embed_channels should be divisible by num_heads.'
self.num_heads = num_heads
self.head_channels = out_channels // num_heads
self.use_einsum = use_einsum
self.export_onnx = export_onnx

self.embed_conv = ConvModule(
in_channels,
Expand Down Expand Up @@ -191,7 +226,7 @@ def __init__(
norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None,
use_einsum: bool = True) -> None:
export_onnx: bool = True) -> None:
super().__init__(in_channels=in_channels,
out_channels=out_channels,
expand_ratio=expand_ratio,
Expand All @@ -217,7 +252,7 @@ def __init__(
with_scale=with_scale,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
use_einsum=use_einsum)
export_onnx=export_onnx)

def forward(self, x: Tensor, guide: Tensor) -> Tensor:
"""Forward process."""
Expand Down Expand Up @@ -247,7 +282,7 @@ def __init__(
norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None,
use_einsum: bool = True) -> None:
export_onnx: bool = True) -> None:
super().__init__(in_channels=in_channels,
out_channels=out_channels,
expand_ratio=expand_ratio,
Expand All @@ -274,7 +309,7 @@ def __init__(
with_scale=with_scale,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
use_einsum=use_einsum)
export_onnx=export_onnx)

def forward(self, x: Tensor, guide: Tensor) -> Tensor:
"""Forward process."""
Expand All @@ -296,7 +331,7 @@ def __init__(self,
num_feats: int = 3,
num_heads: int = 8,
pool_size: int = 3,
use_einsum: bool = True):
export_onnx: bool = True):
super().__init__()

self.text_channels = text_channels
Expand All @@ -305,7 +340,7 @@ def __init__(self,
self.num_feats = num_feats
self.head_channels = embed_channels // num_heads
self.pool_size = pool_size
self.use_einsum = use_einsum
self.export_onnx = export_onnx
if with_scale:
self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
else:
Expand All @@ -321,11 +356,16 @@ def __init__(self,
self.value = nn.Sequential(nn.LayerNorm(embed_channels),
Linear(embed_channels, embed_channels))
self.proj = Linear(embed_channels, text_channels)

self.image_pools = nn.ModuleList([
nn.AdaptiveMaxPool2d((pool_size, pool_size))
for _ in range(num_feats)
])
if self.export_onnx == False:
self.image_pools = nn.ModuleList([
nn.AdaptiveMaxPool2d((pool_size, pool_size))
for _ in range(num_feats)
])
else:
self.image_pools = nn.ModuleList([
AdaptiveMaxPool2dCustom((pool_size, pool_size))
for _ in range(num_feats)
])

def forward(self, text_features, image_features):
B = image_features[0].shape[0]
Expand All @@ -345,7 +385,7 @@ def forward(self, text_features, image_features):
q = q.reshape(B, -1, self.num_heads, self.head_channels)
k = k.reshape(B, -1, self.num_heads, self.head_channels)
v = v.reshape(B, -1, self.num_heads, self.head_channels)
if self.use_einsum:
if self.export_onnx == False:
attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
else:
q = q.permute(0, 2, 1, 3)
Expand All @@ -354,7 +394,7 @@ def forward(self, text_features, image_features):

attn_weight = attn_weight / (self.head_channels**0.5)
attn_weight = F.softmax(attn_weight, dim=-1)
if self.use_einsum:
if self.export_onnx == False:
x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
else:
v = v.permute(0, 2, 1, 3)
Expand Down

0 comments on commit 64b34ae

Please sign in to comment.