diff --git a/configs/yolov10/hyp.scratch.yaml b/configs/yolov10/hyp.scratch.yaml new file mode 100644 index 00000000..eb5883ee --- /dev/null +++ b/configs/yolov10/hyp.scratch.yaml @@ -0,0 +1,60 @@ +optimizer: + optimizer: momentum + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv10Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 4 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 490, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, 
padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov10/yolov10n.yaml b/configs/yolov10/yolov10n.yaml new file mode 100644 index 00000000..3c08470c --- /dev/null +++ b/configs/yolov10/yolov10n.yaml @@ -0,0 +1,61 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.yaml', +] + +epochs: 500 # total train epochs +per_batch_size: 32 # 32 * 8 = 256 +img_size: 640 +iou_thres: 0.7 +overflow_still_update: False +ms_loss_scaler: dynamic +ms_loss_scaler_value: 65536.0 +clip_grad: True +anchor_base: False +opencv_threads_num: 0 # opencv: disable threading optimizations + +network: + model_name: yolov10 + nc: 80 # number of classes + reg_max: 16 + + depth_multiple: 0.33 # model depth multiple + width_multiple: 0.25 # layer channel multiple + max_channels: 1024 + stride: [8, 16, 32] + + # YOLOv10.0n backbone + backbone: + # [from, repeats, module, args] + - [-1, 1, ConvNormAct, [64, 3, 2]] # 0-P1/2 + - [-1, 1, ConvNormAct, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, ConvNormAct, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + + # YOLOv10.0n head + head: + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1] ] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, ConvNormAct, [256, 3, 2]] + - [[ -1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 
3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, YOLOv10Head, [nc, reg_max, stride]] # Detect(P3, P4, P5) + \ No newline at end of file diff --git a/mindyolo/data/albumentations.py b/mindyolo/data/albumentations.py index 1ab3307f..9456cf87 100644 --- a/mindyolo/data/albumentations.py +++ b/mindyolo/data/albumentations.py @@ -56,7 +56,7 @@ def __call__(self, sample, p=1.0, **kwargs): sample['img'] = new['image'] sample['bboxes'] = np.array(new['bboxes']) - sample['cls'] = np.array(new['class_labels']) + sample['cls'] = np.array(new['class_labels']).reshape(-1, 1) sample['bbox_format'] = "xywhn" return sample diff --git a/mindyolo/models/__init__.py b/mindyolo/models/__init__.py index b6eb53fe..4621cb1a 100644 --- a/mindyolo/models/__init__.py +++ b/mindyolo/models/__init__.py @@ -1,10 +1,11 @@ from . import (heads, initializer, layers, losses, model_factory, yolov3, - yolov4, yolov5, yolov7, yolov8) + yolov4, yolov5, yolov7, yolov8, yolov10) __all__ = [] __all__.extend(heads.__all__) __all__.extend(layers.__all__) __all__.extend(losses.__all__) +__all__.extend(yolov10.__all__) __all__.extend(yolov8.__all__) __all__.extend(yolov7.__all__) __all__.extend(yolov5.__all__) @@ -25,4 +26,5 @@ from .yolov5 import * from .yolov7 import * from .yolov8 import * +from .yolov10 import * from .yolox import * diff --git a/mindyolo/models/heads/__init__.py b/mindyolo/models/heads/__init__.py index 593e3df8..9c69bc99 100644 --- a/mindyolo/models/heads/__init__.py +++ b/mindyolo/models/heads/__init__.py @@ -5,7 +5,7 @@ from .yolov7_head import * from .yolov8_head import * from .yolox_head import * - +from .yolov10_head import * __all__ = [ "YOLOv3Head", @@ -13,5 +13,6 @@ "YOLOv5Head", "YOLOv7Head", "YOLOv7AuxHead", "YOLOv8Head", "YOLOv8SegHead", - "YOLOXHead" + "YOLOXHead", + "YOLOv10Head" ] diff --git a/mindyolo/models/heads/yolov10_head.py b/mindyolo/models/heads/yolov10_head.py new file mode 
class YOLOv10Head(nn.Cell):
    """YOLOv10 detection head with a one-to-many and a one-to-one branch.

    ``cv2`` / ``cv3`` predict, per feature level, the box distribution
    (``4 * reg_max`` channels) and the class logits (``nc`` channels) of the
    one-to-many branch. ``one2one_cv2`` / ``one2one_cv3`` are deep copies of
    them and form the one-to-one branch, which is fed with gradient-stopped
    features in ``construct_end2end``.

    Args:
        nc (int): Number of classes.
        reg_max (int): Number of DFL bins per box side.
        stride (tuple | list): Non-empty; one stride per detection level.
        ch (tuple | list): Non-empty; input channels of each detection level.
        sync_bn (bool): Forwarded to every ``ConvNormAct`` of the head.
    """

    def __init__(self, nc=80, reg_max=16, stride=(), ch=(), sync_bn=False):  # detection layer
        super().__init__()

        assert isinstance(stride, (tuple, list)) and len(stride) > 0
        assert isinstance(ch, (tuple, list)) and len(ch) > 0

        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = reg_max  # DFL channels
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = Parameter(Tensor(stride, ms.int32), requires_grad=False)
        self.max_det = 300  # maximum detections kept per image by postprocess
        self.end2end = True  # route construct() through the dual-branch path
        self.export = False  # when True, inference returns only the decoded tensor

        # hidden widths of the box (c2) and class (c3) branches
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))
        self.cv2 = nn.CellList(
            [
                nn.SequentialCell(
                    [
                        ConvNormAct(x, c2, 3, sync_bn=sync_bn),
                        ConvNormAct(c2, c2, 3, sync_bn=sync_bn),
                        nn.Conv2d(c2, 4 * self.reg_max, 1, has_bias=True),
                    ]
                )
                for x in ch
            ]
        )
        # Class branch: two (depthwise 3x3 -> pointwise 1x1) pairs, then a 1x1 conv.
        # sync_bn is forwarded here as well so both branches use the same norm type.
        self.cv3 = nn.CellList(
            [
                nn.SequentialCell(
                    [
                        nn.SequentialCell(
                            [
                                ConvNormAct(x, x, 3, g=x, sync_bn=sync_bn),
                                ConvNormAct(x, c3, 1, sync_bn=sync_bn),
                            ]
                        ),
                        nn.SequentialCell(
                            [
                                ConvNormAct(c3, c3, 3, g=c3, sync_bn=sync_bn),
                                ConvNormAct(c3, c3, 1, sync_bn=sync_bn),
                            ]
                        ),
                        nn.Conv2d(c3, self.nc, 1, has_bias=True),
                    ]
                )
                for x in ch
            ]
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else Identity()

        # The one-to-one branch starts as an exact copy of the one-to-many branch.
        self.one2one_cv2 = deepcopy(self.cv2)
        self.one2one_cv3 = deepcopy(self.cv3)

    def construct(self, x):
        """Run the head on the per-level feature maps ``x``.

        With ``end2end`` set, this delegates to ``construct_end2end``. Otherwise
        only the ``cv2`` / ``cv3`` branch is evaluated.

        Returns:
            Training: tuple with one ``(B, no, H, W)`` tensor per level.
            Inference: the decoded tensor ``y`` if ``export`` is set, else
            ``(y, raw_outputs)``.
        """
        if self.end2end:
            return self.construct_end2end(x)

        # Collect into a separate tuple: the input features must stay readable
        # while the per-level outputs are being built.
        out = ()
        for i in range(self.nl):
            out += (ops.concat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1),)
        if self.training:  # Training path
            return out
        y = self._inference(out)
        return y if self.export else (y, out)

    def construct_end2end(self, x):
        """Evaluate both branches of the head.

        The one-to-one branch receives gradient-stopped copies of the features;
        the one-to-many branch receives the features unchanged.

        Args:
            x: Sequence of ``nl`` feature maps, one per detection level.

        Returns:
            Training: ``(one2many, one2one)``, each a tuple of per-level raw
            outputs.
            Inference: the post-processed one-to-one detections ``y`` if
            ``export`` is set, else ``(y, (one2many, one2one))``.
        """
        x_detach = [ops.stop_gradient(xi) for xi in x]
        one2one = ()
        for i in range(self.nl):
            one2one += (ops.concat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1),)
        one2many = ()
        for i in range(self.nl):
            one2many += (ops.concat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1),)
        if self.training:  # Training path
            return (one2many, one2one)
        y = self._inference(one2one)
        # postprocess expects (batch, anchors, 4 + nc)
        y = self.postprocess(ops.transpose(y, (0, 2, 1)), self.max_det, self.nc)
        return y if self.export else (y, (one2many, one2one))

    def _inference(self, x):
        """Decode raw per-level outputs into boxes and class probabilities.

        Returns a tensor whose axis 1 holds the decoded box followed by the
        ``nc`` sigmoid class scores, and whose axis 2 enumerates the anchors
        of all levels. Boxes are xyxy when ``end2end`` is set, xywh otherwise.
        """
        shape = x[0].shape  # BCHW
        _anchors, _strides = self.make_anchors(x, self.stride, 0.5)
        _anchors, _strides = _anchors.swapaxes(0, 1), _strides.swapaxes(0, 1)

        _x = ()
        for i in range(len(x)):
            _x += (x[i].view(shape[0], self.no, -1),)
        _x = ops.concat(_x, 2)
        box, cls = _x[:, : self.reg_max * 4, :], _x[:, self.reg_max * 4 : self.reg_max * 4 + self.nc, :]
        dbox = self.dist2bbox(self.dfl(box), ops.expand_dims(_anchors, 0), xywh=not self.end2end, axis=1) * _strides

        return ops.concat((dbox, ops.Sigmoid()(cls)), 1)

    @staticmethod
    def make_anchors(feats, strides, grid_cell_offset=0.5):
        """Generate anchor points and their strides from the feature maps.

        Returns:
            ``(anchor_points, stride_tensor)`` with shapes ``(A, 2)`` and
            ``(A, 1)``, where ``A`` is the sum of ``h * w`` over all levels.
        """
        anchor_points, stride_tensor = (), ()
        dtype = feats[0].dtype
        for i, stride in enumerate(strides):
            _, _, h, w = feats[i].shape
            sx = mnp.arange(w, dtype=dtype) + grid_cell_offset  # shift x
            sy = mnp.arange(h, dtype=dtype) + grid_cell_offset  # shift y
            # FIXME: Not supported on a specific model of machine
            sy, sx = meshgrid((sy, sx), indexing="ij")
            anchor_points += (ops.stack((sx, sy), -1).view(-1, 2),)
            stride_tensor += (ops.ones((h * w, 1), dtype) * stride,)
        return ops.concat(anchor_points), ops.concat(stride_tensor)

    @staticmethod
    def dist2bbox(distance, anchor_points, xywh=True, axis=-1):
        """Transform distance(ltrb) to box(xywh or xyxy)."""
        lt, rb = ops.split(distance, split_size_or_sections=2, axis=axis)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        if xywh:
            c_xy = (x1y1 + x2y2) / 2
            wh = x2y2 - x1y1
            return ops.concat((c_xy, wh), axis)  # xywh bbox
        return ops.concat((x1y1, x2y2), axis)  # xyxy bbox

    @staticmethod
    def postprocess(preds, max_det, nc=80):
        """Keep the highest-scoring detections of each image.

        Args:
            preds (Tensor): Predictions of shape ``(batch_size, num_anchors, 4 + nc)``;
                the last axis holds 4 box coordinates followed by ``nc`` class
                probabilities. The box coordinates are passed through unchanged
                (``construct_end2end`` supplies xyxy boxes).
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            Tensor of shape ``(batch_size, min(max_det, num_anchors), 6)`` whose
            last axis is ``[box(4), score, class_index]``.
        """
        batch_size, num_anchors, _ = preds.shape  # i.e. shape(16,8400,84)
        # topk cannot return more entries than exist; small inputs may have
        # fewer anchors than max_det.
        k = min(max_det, num_anchors)
        boxes, scores = preds.split([4, nc], axis=-1)
        # Stage 1: keep the k anchors with the best single-class score.
        _, index = ops.topk(ops.amax(scores, axis=-1), k, dim=-1)
        index = ops.expand_dims(index, -1)
        boxes = ops.gather_elements(boxes, dim=1, index=ops.tile(index, (1, 1, boxes.shape[-1])))
        scores = ops.gather_elements(scores, dim=1, index=ops.tile(index, (1, 1, nc)))

        # Stage 2: rank all (anchor, class) pairs of the kept anchors; the flat
        # index encodes anchor = index // nc and class = index % nc.
        scores, index = ops.topk(ops.flatten(scores, start_dim=1), k, dim=-1)
        i = ops.arange(batch_size)[..., None]  # batch indices
        return ops.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], -1)

    def initialize_biases(self):
        """Initialize the final-conv biases of the head; requires stride availability."""
        self._init_branch_biases(self.cv2, self.cv3)
        if self.end2end:
            self._init_branch_biases(self.one2one_cv2, self.one2one_cv3)

    def _init_branch_biases(self, box_branch, cls_branch):
        """Set box biases to 1 and class biases to a stride-dependent prior for one branch."""
        for a, b, s in zip(box_branch, cls_branch, self.stride):  # from
            s = s.asnumpy()
            a[-1].bias = ops.assign(a[-1].bias, Tensor(np.ones(a[-1].bias.shape), ms.float32))
            b_np = b[-1].bias.data.asnumpy()
            b_np[: self.nc] = math.log(5 / self.nc / (640 / int(s)) ** 2)
            b[-1].bias = ops.assign(b[-1].bias, Tensor(b_np, ms.float32))
class SCDown(nn.Cell):
    """Downsampling block: 1x1 ConvNormAct (c1 -> c2), then a depthwise k x k
    ConvNormAct with stride ``s`` and no activation."""

    def __init__(self, c1, c2, k, s):
        super().__init__()
        self.cv1 = ConvNormAct(c1, c2, k=1, s=1)
        self.cv2 = ConvNormAct(c2, c2, k=k, s=s, g=c2, act=False)

    def construct(self, x):
        return self.cv2(self.cv1(x))


class Attention(nn.Cell):
    """Multi-head self-attention over the ``H * W`` positions of a feature map.

    Args:
        dim (int): Number of input/output channels.
        num_heads (int): Number of heads; ``head_dim = dim // num_heads``.
        attn_ratio (float): ``key_dim = int(head_dim * attn_ratio)``.
    """

    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim ** -0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2  # channels for q and k (nh_kd each) plus v (dim)
        self.qkv = ConvNormAct(c1=dim, c2=h, k=1, act=False)
        self.proj = ConvNormAct(c1=dim, c2=dim, k=1, act=False)
        self.pe = ConvNormAct(c1=dim, c2=dim, k=3, s=1, g=dim, act=False)  # depthwise branch on v

    def construct(self, x):
        B, C, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        # (B, heads, 2 * key_dim + head_dim, N) split into q, k, v along axis 2
        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
            [self.key_dim, self.key_dim, self.head_dim], axis=2
        )

        attn = (ops.transpose(q, (0, 1, 3, 2)) @ k) * self.scale  # (B, heads, N, N)
        attn = ops.softmax(attn, axis=-1)
        x = (v @ ops.transpose(attn, (0, 1, 3, 2))).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
        x = self.proj(x)
        return x


class PSA(nn.Cell):
    """Partial self-attention block.

    ``cv1`` output is split into two halves of ``self.c`` channels; only the
    second half passes through attention and a feed-forward block (each with
    a residual connection) before both halves are concatenated and fused by
    ``cv2``.

    Args:
        c1 (int): Input channels.
        c2 (int): Output channels; must equal ``c1``.
        e (float): Ratio giving the per-half width ``self.c = int(c1 * e)``.
    """

    def __init__(self, c1, c2, e=0.5):
        super().__init__()
        assert(c1 == c2)
        self.c = int(c1 * e)
        self.cv1 = ConvNormAct(c1, 2 * self.c, 1, 1)
        self.cv2 = ConvNormAct(2 * self.c, c1, 1)

        # One head per 64 channels, but never zero heads: narrow layers
        # (self.c < 64) would otherwise divide by zero inside Attention.
        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=max(1, self.c // 64))
        self.ffn = nn.SequentialCell(
            [
                ConvNormAct(self.c, self.c * 2, 1),
                ConvNormAct(self.c * 2, self.c, 1, act=False)
            ]
        )

    def construct(self, x):
        a, b = self.cv1(x).split((self.c, self.c), axis=1)
        b = b + self.attn(b)
        b = b + self.ffn(b)
        return self.cv2(ops.concat((a, b), 1))


class RepVGGDW(nn.Cell):
    """Sum of a depthwise 7x7 and a depthwise 3x3 ConvNormAct (both without
    activation), followed by SiLU."""

    def __init__(self, ed):
        super().__init__()
        self.conv = ConvNormAct(ed, ed, k=7, s=1, p=3, g=ed, act=False)
        self.conv1 = ConvNormAct(ed, ed, k=3, s=1, p=1, g=ed, act=False)
        self.dim = ed
        self.act = nn.SiLU()

    def construct(self, x):
        return self.act(self.conv(x) + self.conv1(x))


class CIB(nn.Cell):
    """Bottleneck built from alternating depthwise 3x3 and pointwise 1x1 convs.

    Layer sequence: depthwise 3x3 (c1) -> 1x1 (c1 -> 2 * c_) -> depthwise 3x3
    on ``2 * c_`` channels (replaced by ``RepVGGDW`` when ``lk`` is set) ->
    1x1 (2 * c_ -> c2) -> depthwise 3x3 (c2), with ``c_ = int(c2 * e)``.
    The input is added to the output when ``shortcut`` is set and ``c1 == c2``.
    """

    def __init__(
        self, c1, c2, shortcut=True, e=0.5, lk=False, act=True, momentum=0.97, eps=1e-3, sync_bn=False
    ):  # ch_in, ch_out, shortcut, expansion, large-kernel, activation
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = nn.SequentialCell(
            [
                ConvNormAct(c1, c1, 3, g=c1, act=act, momentum=momentum, eps=eps, sync_bn=sync_bn),
                ConvNormAct(c1, 2 * c_, 1, act=act, momentum=momentum, eps=eps, sync_bn=sync_bn),
                ConvNormAct(2 * c_, 2 * c_, 3, g=2 * c_, act=act, momentum=momentum, eps=eps, sync_bn=sync_bn) if not lk else RepVGGDW(2 * c_),
                ConvNormAct(2 * c_, c2, 1, act=act, momentum=momentum, eps=eps, sync_bn=sync_bn),
                ConvNormAct(c2, c2, 3, g=c2, act=act, momentum=momentum, eps=eps, sync_bn=sync_bn),
            ]
        )
        self.add = shortcut and c1 == c2

    def construct(self, x):
        if self.add:
            out = x + self.cv1(x)
        else:
            out = self.cv1(x)
        return out


class C2fCIB(C2f):
    """C2f variant whose ``n`` inner blocks are ``CIB`` cells.

    The parent constructor builds the surrounding convs (and its own inner
    blocks); ``self.m`` is then replaced by ``n`` ``CIB`` cells operating on
    ``self.c`` channels.
    """

    def __init__(
        self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5, momentum=0.97, eps=1e-3, sync_bn=False
    ):  # ch_in, ch_out, number, shortcut, large-kernel, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e, momentum, eps, sync_bn)
        self.m = nn.CellList(
            [
                CIB(self.c, self.c, shortcut, e=1.0, lk=lk, momentum=momentum, eps=eps, sync_bn=sync_bn)
                for _ in range(n)
            ]
        )
__all__ = ["YOLOv10Loss"]


@register_model
class YOLOv10Loss(nn.Cell):
    """Training loss for YOLOv10: the sum of two ``YOLOv8Loss`` terms.

    One term supervises the one-to-many predictions and assigns ``tal_topk``
    candidates per target; the other supervises the one-to-one predictions
    and always assigns a single candidate (``tal_topk=1``).

    Args:
        box, cls, dfl, stride, nc, reg_max: Forwarded unchanged to both
            ``YOLOv8Loss`` instances.
        tal_topk (int): Assigner top-k used by the one-to-many term.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, box, cls, dfl, stride, nc, reg_max=16, tal_topk=10, **kwargs):
        super().__init__()
        # Honor the caller-supplied top-k for the one-to-many branch; the
        # one-to-one branch is fixed at 1 by definition.
        self.one2many = YOLOv8Loss(box, cls, dfl, stride, nc, reg_max, tal_topk=tal_topk)
        self.one2one = YOLOv8Loss(box, cls, dfl, stride, nc, reg_max, tal_topk=1)
        # branch name returned by lossitem for print
        self.loss_item_name = ["loss", "lbox", "lcls", "dfl"]

    def construct(self, preds, batch, imgs):
        """Return the element-wise sum of both branch losses.

        Args:
            preds: ``(one2many, one2one)`` head outputs.
            batch, imgs: Forwarded unchanged to each ``YOLOv8Loss``.

        Returns:
            A 2-tuple: element 0 and element 1 of the two branch results,
            each summed across branches (presumably total loss and the
            per-item losses — verify against ``YOLOv8Loss.construct``).
        """
        loss_one2many = self.one2many(preds[0], batch, imgs)
        loss_one2one = self.one2one(preds[1], batch, imgs)
        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
__init__(self, box, cls, dfl, stride, nc, reg_max=16, tal_topk=10, **kwargs): super(YOLOv8Loss, self).__init__() self.bce = nn.BCEWithLogitsLoss(reduction="none") @@ -27,7 +27,7 @@ def __init__(self, box, cls, dfl, stride, nc, reg_max=16, **kwargs): self.reg_max = reg_max self.use_dfl = reg_max > 1 - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0) self.bbox_loss = BboxLoss(reg_max, use_dfl=self.use_dfl) self.proj = mnp.arange(reg_max) diff --git a/mindyolo/models/model_factory.py b/mindyolo/models/model_factory.py index 5d7c5751..61424c8a 100644 --- a/mindyolo/models/model_factory.py +++ b/mindyolo/models/model_factory.py @@ -150,6 +150,9 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) DWConvNormAct, DWBottleneck, DWC3, + SCDown, + PSA, + C2fCIB, ): c1, c2 = ch[f], args[0] if max_channels: @@ -172,7 +175,7 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) DWC3, ): kwargs["sync_bn"] = sync_bn - if m in (DownC, SPPCSPC, C3, C2f, DWC3): + if m in (DownC, SPPCSPC, C3, C2f, DWC3, C2fCIB): args.insert(2, n) # number of repeats n = 1 elif m in (nn.BatchNorm2d, nn.SyncBatchNorm): @@ -185,7 +188,7 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) - elif m in (YOLOv8Head, YOLOv8SegHead, YOLOXHead): # head of anchor free + elif m in (YOLOv10Head, YOLOv8Head, YOLOv8SegHead, YOLOXHead): # head of anchor free args.append([ch[x] for x in f]) if m in (YOLOv8SegHead,): args[3] = math.ceil(min(args[3], max_channels) * gw / 8) * 8 diff --git a/mindyolo/models/yolov10.py b/mindyolo/models/yolov10.py new file mode 100644 index 00000000..10c1b3f1 --- /dev/null +++ b/mindyolo/models/yolov10.py @@ -0,0 +1,49 @@ +import numpy as np + +import 
__all__ = ["YOLOv10", "yolov10"]


def _cfg(url="", **kwargs):
    """Build a default-config entry: ``{"url": url, **kwargs}``."""
    return {"url": url, **kwargs}


default_cfgs = {"yolov10": _cfg(url="")}


class YOLOv10(nn.Cell):
    """YOLOv10 model wrapper around the network built from ``cfg``.

    Args:
        cfg: Model configuration; ``cfg.stride`` is read here and the whole
            object is passed to ``build_model_from_cfg``.
        in_channels (int): Number of input image channels.
        num_classes (int | None): Number of classes. When ``None``, the value
            of ``cfg.nc`` is used.
        sync_bn (bool): Forwarded to ``build_model_from_cfg``.
    """

    def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False):
        super(YOLOv10, self).__init__()
        self.cfg = cfg
        self.stride = Tensor(np.array(cfg.stride), ms.int32)
        self.stride_max = int(max(self.cfg.stride))
        if num_classes is None:
            # No override supplied: use the class count declared in the model
            # config (assumes cfg exposes `nc` the same way it exposes `stride`).
            num_classes = cfg.nc
        ch, nc = in_channels, num_classes

        self.nc = nc  # override yaml value
        self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn)
        self.names = [str(i) for i in range(nc)]  # default names

        self.initialize_weights()

    def construct(self, x):
        return self.model(x)

    def initialize_weights(self):
        """Reset the parameters of the detection head, if the last layer is one."""
        m = self.model.model[-1]
        if isinstance(m, YOLOv10Head):
            m.initialize_biases()
            # The head installs an Identity in place of DFL when reg_max <= 1;
            # only initialize the projection when the cell provides it.
            if hasattr(m.dfl, "initialize_conv_weight"):
                m.dfl.initialize_conv_weight()


@register_model
def yolov10(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv10:
    """Get yolov10 model."""
    model = YOLOv10(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs)
    return model