From 64b34ae77928aeba85403684129d16a39c2b8f71 Mon Sep 17 00:00:00 2001
From: wufei2
Date: Sat, 4 May 2024 14:20:44 +0800
Subject: [PATCH] chore: support onnx export

---
 yolo_world/models/layers/yolo_bricks.py | 78 +++++++++++++++++++------
 1 file changed, 59 insertions(+), 19 deletions(-)

diff --git a/yolo_world/models/layers/yolo_bricks.py b/yolo_world/models/layers/yolo_bricks.py
index 7ba797cf..8dae7da4 100644
--- a/yolo_world/models/layers/yolo_bricks.py
+++ b/yolo_world/models/layers/yolo_bricks.py
@@ -11,6 +11,42 @@
 from mmyolo.registry import MODELS
 from mmyolo.models.layers import CSPLayerWithTwoConv
 
+# AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom are ONNX-export-compatible replacements for the adaptive pooling layers.
+# Reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
+class AdaptiveAvgPool2dCustom(nn.Module):
+    def __init__(self, output_size):
+        super(AdaptiveAvgPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to achieve the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and the desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create an AvgPool2d layer with the calculated kernel and stride sizes
+        avg = nn.AvgPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+        x = avg(x)
+        return x
+
+
+class AdaptiveMaxPool2dCustom(nn.Module):
+    def __init__(self, output_size):
+        super(AdaptiveMaxPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to achieve the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and the desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create a MaxPool2d layer with the calculated kernel and stride sizes
+        max_pool = nn.MaxPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+        x = max_pool(x)
+        return x
 
 @MODELS.register_module()
 class MaxSigmoidAttnBlock(BaseModule):
@@ -31,7 +67,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -40,7 +76,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
         embed = self.embed_conv(x) if self.embed_conv is not None else x
         embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
-
-        if self.use_einsum:
+        if not self.export_onnx:
             attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
         else:
             batch, m, channel, height, width = embed.shape
@@ -116,7 +151,7 @@ def __init__(self,
                                             momentum=0.03,
                                             eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -125,7 +160,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -191,7 +226,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -217,7 +252,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -247,7 +282,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -274,7 +309,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -296,7 +331,7 @@ def __init__(self,
                  num_feats: int = 3,
                  num_heads: int = 8,
                  pool_size: int = 3,
-                 use_einsum: bool = True):
+                 export_onnx: bool = True):
         super().__init__()
 
         self.text_channels = text_channels
@@ -305,7 +340,7 @@ def __init__(self,
         self.num_feats = num_feats
         self.head_channels = embed_channels // num_heads
         self.pool_size = pool_size
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
         if with_scale:
             self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
         else:
@@ -321,11 +356,16 @@ def __init__(self,
         self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                    Linear(embed_channels, embed_channels))
         self.proj = Linear(embed_channels, text_channels)
-
-        self.image_pools = nn.ModuleList([
-            nn.AdaptiveMaxPool2d((pool_size, pool_size))
-            for _ in range(num_feats)
-        ])
+        if not self.export_onnx:
+            self.image_pools = nn.ModuleList([
+                nn.AdaptiveMaxPool2d((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
+        else:
+            self.image_pools = nn.ModuleList([
+                AdaptiveMaxPool2dCustom((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
 
     def forward(self, text_features, image_features):
         B = image_features[0].shape[0]
@@ -345,7 +385,7 @@ def forward(self, text_features, image_features):
         q = q.reshape(B, -1, self.num_heads, self.head_channels)
         k = k.reshape(B, -1, self.num_heads, self.head_channels)
         v = v.reshape(B, -1, self.num_heads, self.head_channels)
-        if self.use_einsum:
+        if not self.export_onnx:
             attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
         else:
             q = q.permute(0, 2, 1, 3)
@@ -354,7 +394,7 @@ def forward(self, text_features, image_features):
 
         attn_weight = attn_weight / (self.head_channels**0.5)
         attn_weight = F.softmax(attn_weight, dim=-1)
-        if self.use_einsum:
+        if not self.export_onnx:
             x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
         else:
             v = v.permute(0, 2, 1, 3)
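
Note: AdaptiveAvgPool2dCustom/AdaptiveMaxPool2dCustom reproduce the adaptive
pooling layers exactly only when the input's spatial size is an integer
multiple of output_size (stride = floor(H / out), kernel = H - (out - 1) *
stride); for non-divisible sizes the pooling windows differ slightly, so the
outputs are approximations. A minimal sanity check, assuming this patch is
applied so the class is importable from the patched module:

    import torch
    import torch.nn as nn
    from yolo_world.models.layers.yolo_bricks import AdaptiveMaxPool2dCustom

    pool_ref = nn.AdaptiveMaxPool2d((3, 3))
    pool_onnx = AdaptiveMaxPool2dCustom((3, 3))

    # Divisible case: 9 = 3 * 3, so both pools use identical 3x3 windows.
    x = torch.randn(1, 8, 9, 9)
    assert torch.allclose(pool_ref(x), pool_onnx(x))

    # Non-divisible case (e.g. a 20x20 feature map pooled to 3x3): the
    # window boundaries differ, so the outputs only approximate each other.
    x = torch.randn(1, 8, 20, 20)
    print((pool_ref(x) - pool_onnx(x)).abs().max())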
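
Note: the export branches rewrite einsum as permute + matmul so the exported
graph contains plain Transpose/MatMul nodes; einsum lowers poorly to some
ONNX opsets and runtimes. The else-branches are truncated in the hunks above,
so the following is a sketch of the standard equivalent rewrite for the two
contractions in the image-pooling attention, not a verbatim copy of the
patched file:

    import torch

    B, N, K, M, C = 2, 5, 7, 8, 16
    q = torch.randn(B, N, M, C)
    k = torch.randn(B, K, M, C)
    v = torch.randn(B, K, M, C)

    # 'bnmc,bkmc->bmnk' as Transpose + MatMul: (B,M,N,C) @ (B,M,C,K)
    attn = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1))
    assert torch.allclose(attn, torch.einsum('bnmc,bkmc->bmnk', q, k),
                          atol=1e-5)

    # 'bmnk,bkmc->bnmc' as MatMul + Transpose: (B,M,N,K) @ (B,M,K,C)
    out = torch.matmul(attn, v.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
    assert torch.allclose(out, torch.einsum('bmnk,bkmc->bnmc', attn, v),
                          atol=1e-5)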
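
Note: with export_onnx=True the kernel and stride are computed from the
example input's static shape at trace time, so the exported graph contains a
fixed MaxPool node rather than an adaptive pooling op that ONNX does not
define. A minimal export sketch (the file name and opset below are
illustrative, not part of the patch):

    import torch
    from yolo_world.models.layers.yolo_bricks import AdaptiveMaxPool2dCustom

    pool = AdaptiveMaxPool2dCustom((3, 3))
    dummy = torch.randn(1, 512, 20, 20)
    torch.onnx.export(pool, dummy, 'adaptive_max_pool.onnx', opset_version=12)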