diff --git a/yolo_world/models/detectors/yolo_world.py b/yolo_world/models/detectors/yolo_world.py
index 6d4b0910..ba8d7439 100644
--- a/yolo_world/models/detectors/yolo_world.py
+++ b/yolo_world/models/detectors/yolo_world.py
@@ -79,6 +79,9 @@ def extract_feat(
         if batch_data_samples is None:
             texts = self.texts
             txt_feats = self.text_feats
+            batch_size = batch_inputs.shape[0]
+            texts = texts * batch_size
+            txt_feats = txt_feats.repeat(batch_size, 1, 1)
         elif isinstance(batch_data_samples,
                         dict) and 'texts' in batch_data_samples:
             texts = batch_data_samples['texts']
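Note: the three added lines broadcast the cached text embeddings across the image batch at inference time, since `self.text_feats` is computed once from a single prompt set and keeps a batch dimension of 1. A minimal sketch of the shape logic, assuming the cached tensor has shape (1, num_classes, text_dim); the sizes are illustrative, not taken from the patch:

    import torch

    num_classes, text_dim = 80, 512                    # assumed sizes, for illustration only
    txt_feats = torch.randn(1, num_classes, text_dim)  # cached single-prompt-set features
    batch_inputs = torch.randn(4, 3, 640, 640)         # a batch of four images

    batch_size = batch_inputs.shape[0]
    txt_feats = txt_feats.repeat(batch_size, 1, 1)     # one copy of the text features per image
    assert txt_feats.shape == (batch_size, num_classes, text_dim)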
diff --git a/yolo_world/models/layers/yolo_bricks.py b/yolo_world/models/layers/yolo_bricks.py
index 0c39131c..0379fad8 100644
--- a/yolo_world/models/layers/yolo_bricks.py
+++ b/yolo_world/models/layers/yolo_bricks.py
@@ -11,6 +11,42 @@
 from mmyolo.registry import MODELS
 from mmyolo.models.layers import CSPLayerWithTwoConv
 
+# AdaptiveAvgPool2dCustom and AdaptiveMaxPool2dCustom stay compatible when exporting models to ONNX format
+# reference: https://github.com/pytorch/pytorch/issues/42653#issuecomment-1168816422
+class AdaptiveAvgPool2dCustom(nn.Module):
+    def __init__(self, output_size):
+        super(AdaptiveAvgPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to reach the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create an AvgPool2d layer with the calculated kernel and stride sizes
+        avg = nn.AvgPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = avg(x)
+        return x
+class AdaptiveMaxPool2dCustom(nn.Module):
+    def __init__(self, output_size):
+        super(AdaptiveMaxPool2dCustom, self).__init__()
+        self.output_size = torch.tensor(output_size)
+
+    def forward(self, x: torch.Tensor):
+        # Calculate the stride required to reach the desired output size
+        stride_size = torch.floor(torch.tensor(x.shape[-2:]) / self.output_size).to(torch.int32)
+
+        # Calculate the kernel size from the stride and desired output size
+        kernel_size = torch.tensor(x.shape[-2:]) - (self.output_size - 1) * stride_size
+
+        # Create a MaxPool2d layer with the calculated kernel and stride sizes
+        max_pool = nn.MaxPool2d(kernel_size.tolist(), stride=stride_size.tolist())
+
+        x = max_pool(x)
+        return x
 
 
 @MODELS.register_module()
 class MaxSigmoidAttnBlock(BaseModule):
@@ -31,7 +67,7 @@ def __init__(self,
                                              momentum=0.03,
                                              eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
 
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
@@ -40,7 +76,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -73,8 +109,7 @@ def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
         embed = self.embed_conv(x) if self.embed_conv is not None else x
         embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
-
-        if self.use_einsum:
+        if not self.export_onnx:
             attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
         else:
             batch, m, channel, height, width = embed.shape
@@ -116,7 +151,7 @@ def __init__(self,
                                              momentum=0.03,
                                              eps=0.001),
                  init_cfg: OptMultiConfig = None,
-                 use_einsum: bool = True) -> None:
+                 export_onnx: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
 
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
@@ -125,7 +160,7 @@ def __init__(self,
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -272,7 +307,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -298,7 +333,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -328,7 +363,7 @@ def __init__(
             norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
             act_cfg: ConfigType = dict(type='SiLU', inplace=True),
             init_cfg: OptMultiConfig = None,
-            use_einsum: bool = True) -> None:
+            export_onnx: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -412,7 +447,7 @@ def __init__(
             with_scale=with_scale,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
-            use_einsum=use_einsum)
+            export_onnx=export_onnx)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -434,7 +469,7 @@ def __init__(self,
                  num_feats: int = 3,
                  num_heads: int = 8,
                  pool_size: int = 3,
-                 use_einsum: bool = True):
+                 export_onnx: bool = True):
         super().__init__()
 
         self.text_channels = text_channels
@@ -443,7 +478,7 @@ def __init__(self,
         self.num_feats = num_feats
         self.head_channels = embed_channels // num_heads
         self.pool_size = pool_size
-        self.use_einsum = use_einsum
+        self.export_onnx = export_onnx
         if with_scale:
             self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
         else:
@@ -459,11 +494,16 @@ def __init__(self,
         self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                    Linear(embed_channels, embed_channels))
         self.proj = Linear(embed_channels, text_channels)
-
-        self.image_pools = nn.ModuleList([
-            nn.AdaptiveMaxPool2d((pool_size, pool_size))
-            for _ in range(num_feats)
-        ])
+        if not self.export_onnx:
+            self.image_pools = nn.ModuleList([
+                nn.AdaptiveMaxPool2d((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
+        else:
+            self.image_pools = nn.ModuleList([
+                AdaptiveMaxPool2dCustom((pool_size, pool_size))
+                for _ in range(num_feats)
+            ])
 
     def forward(self, text_features, image_features):
         B = image_features[0].shape[0]
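Note: ONNX has no general adaptive-pooling operator, so exporting nn.AdaptiveMaxPool2d with an output size larger than 1x1 typically fails; the custom classes above compute a fixed kernel and stride from the traced input shape instead (see the linked PyTorch issue). The substitution is exact only when the feature map's spatial size is an integer multiple of the output size, which holds for the fixed-size inputs used during export. A quick check of the arithmetic, mirroring AdaptiveMaxPool2dCustom with pool_size=3 on an assumed 24x24 feature map:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 16, 24, 24)         # 24 is an exact multiple of 3
    ref = nn.AdaptiveMaxPool2d((3, 3))(x)

    # AdaptiveMaxPool2dCustom's arithmetic: stride = floor(24 / 3) = 8,
    # kernel = 24 - (3 - 1) * 8 = 8, so fixed 8x8 windows cover the same bins.
    custom = nn.MaxPool2d(kernel_size=8, stride=8)(x)
    assert torch.allclose(ref, custom)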
@@ -483,7 +523,7 @@ def forward(self, text_features, image_features):
         q = q.reshape(B, -1, self.num_heads, self.head_channels)
         k = k.reshape(B, -1, self.num_heads, self.head_channels)
         v = v.reshape(B, -1, self.num_heads, self.head_channels)
-        if self.use_einsum:
+        if not self.export_onnx:
             attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
         else:
             q = q.permute(0, 2, 1, 3)
@@ -492,7 +532,7 @@ def forward(self, text_features, image_features):
         attn_weight = attn_weight / (self.head_channels**0.5)
         attn_weight = F.softmax(attn_weight, dim=-1)
 
-        if self.use_einsum:
+        if not self.export_onnx:
             x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
         else:
             v = v.permute(0, 2, 1, 3)
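Note: with export_onnx set, both attention computations route around torch.einsum, which some ONNX toolchains translate poorly, and use permute plus matmul instead. The two branches are numerically identical; a quick equivalence check with illustrative shapes (the sizes here are not from the patch):

    import torch

    B, N, K, M, C = 2, 5, 7, 8, 16         # batch, queries, keys, heads, head channels
    q = torch.randn(B, N, M, C)
    k = torch.randn(B, K, M, C)

    w_einsum = torch.einsum('bnmc,bkmc->bmnk', q, k)
    # Export branch: bring heads forward, then batch-matmul over the channel axis.
    w_matmul = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1))
    assert torch.allclose(w_einsum, w_matmul, atol=1e-6)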