[On-device Training] Yolo custom loss #19464

Open
Marouan-st opened this issue on Feb 8, 2024 · Discussed in #19390 · 10 comments
Labels: training (issues related to ONNX Runtime training; typically submitted using template)


@Marouan-st

Discussed in #19390

Originally posted by Marouan-st February 2, 2024
Hello,

I would like to implement a custom loss so that I can train a yolov4-tiny model on-device for object detection.

To compute the loss, some post-processing must be performed on the model output, such as computing the bboxes' IoU and summing several losses (class loss + confidence loss + IoU loss, all cross-entropy losses): see https://www.nature.com/articles/s41598-021-02225-y/figures/3

I don't see how to implement all of these computations in the custom loss, especially how to feed the different losses with the post-processed input, since ONNX loss functions take string arguments (input names) as input.

I'm using a yolov4-tiny model built with darknet and converted to ONNX from a TensorFlow implementation of the model.

The Torch implementation of this loss function (for the model I'm using) would look like this (inspired by this YOLOv4 TensorFlow loss implementation):

import numpy as np
import torch
import torch.nn.functional as F

def compute_loss(pred, conv, label, bboxes, STRIDES=[16, 32], NUM_CLASS=1, IOU_LOSS_THRESH=0.5, i=0):
    conv_shape  = conv.size()
    batch_size  = conv_shape[0]
    output_size = conv_shape[1]
    input_size  = float(STRIDES[i] * output_size)  # plain Python float, not a tensor
    conv = torch.reshape(conv, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_conf = conv[:, :, :, :, 4:5]
    conv_raw_prob = conv[:, :, :, :, 5:]

    pred_xywh     = pred[:, :, :, :, 0:4]
    pred_conf     = pred[:, :, :, :, 4:5]

    label_xywh    = label[:, :, :, :, 0:4]
    respond_bbox  = label[:, :, :, :, 4:5]
    label_prob    = label[:, :, :, :, 5:]

    # Unsqueeze the last dim so giou broadcasts against respond_bbox (here not sure...)
    giou = torch.unsqueeze(bbox_giou(pred_xywh, label_xywh), -1)

    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    giou_loss = respond_bbox * bbox_loss_scale * (1.0 - giou)

    iou = bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :])
    # Max IoU over the ground-truth boxes (last dim), keeping the dim for broadcasting.
    max_iou = torch.max(iou, dim=-1, keepdim=True)[0]

    respond_bgd = (1.0 - respond_bbox) * (max_iou < IOU_LOSS_THRESH).to(torch.float32)

    conf_focal = torch.pow(respond_bbox - pred_conf, 2)

    # The TF reference uses sigmoid_cross_entropy_with_logits, i.e. element-wise
    # binary cross-entropy on logits, so binary_cross_entropy_with_logits is the
    # matching PyTorch call (cross_entropy expects a class dimension instead).
    conf_loss = conf_focal * (
            respond_bbox * F.binary_cross_entropy_with_logits(conv_raw_conf, respond_bbox, reduction="none")
            +
            respond_bgd * F.binary_cross_entropy_with_logits(conv_raw_conf, respond_bbox, reduction="none")
    )

    prob_loss = respond_bbox * F.binary_cross_entropy_with_logits(conv_raw_prob, label_prob, reduction="none")

    giou_loss = torch.mean(torch.sum(giou_loss, dim=[1, 2, 3, 4]))
    conf_loss = torch.mean(torch.sum(conf_loss, dim=[1, 2, 3, 4]))
    prob_loss = torch.mean(torch.sum(prob_loss, dim=[1, 2, 3, 4]))

    return giou_loss + conf_loss + prob_loss

def bbox_iou(bboxes1, bboxes2):
    """
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    # Convert (cx, cy, w, h) to corner coordinates (x1, y1, x2, y2).
    bboxes1_coor = torch.cat(
        [
            bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
            bboxes1[..., :2] + bboxes1[..., 2:] * 0.5,
        ],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [
            bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
            bboxes2[..., :2] + bboxes2[..., 2:] * 0.5,
        ],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    # torch.maximum needs tensors on both sides; clamp handles the scalar floor.
    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = torch.div(inter_area, union_area)

    return iou

def bbox_giou(bboxes1, bboxes2):
    """
    Generalized IoU
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    # Convert (cx, cy, w, h) to corner coordinates (x1, y1, x2, y2).
    bboxes1_coor = torch.cat(
        [
            bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
            bboxes1[..., :2] + bboxes1[..., 2:] * 0.5,
        ],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [
            bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
            bboxes2[..., :2] + bboxes2[..., 2:] * 0.5,
        ],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = torch.div(inter_area, union_area)

    # Smallest box enclosing both boxes, for the GIoU penalty term.
    enclose_left_up = torch.minimum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    enclose_right_down = torch.maximum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    enclose_section = enclose_right_down - enclose_left_up
    enclose_area = enclose_section[..., 0] * enclose_section[..., 1]

    giou = iou - torch.div(enclose_area - union_area, enclose_area)

    return giou
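
A hypothetical smoke test for the above with random tensors (the shapes are assumptions inferred from the reshape in compute_loss: 640x640 input, stride 16, 3 anchors per cell, 1 class, up to 50 ground-truth boxes):

# Hypothetical smoke test; shapes are assumptions, not values from the
# actual yolov4-tiny pipeline.
S = 640 // 16                                   # output_size for STRIDES[0]
conv   = torch.randn(1, S, S, 3 * (5 + 1))      # raw head output, pre-reshape
pred   = torch.randn(1, S, S, 3, 5 + 1)         # decoded predictions
label  = torch.rand(1, S, S, 3, 5 + 1)          # encoded ground truth
bboxes = torch.rand(1, 50, 4) * 640             # ground-truth boxes (xywh)
print(compute_loss(pred, conv, label, bboxes, i=0))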

Any suggestions?

Thank you

@OAHLSTM

OAHLSTM commented Feb 15, 2024

Hello, I'm working on a similar topic, aiming to retrain a fine-tuned YOLOv8 on-device using the ONNX Runtime training API, and I'm struggling to define the loss functions as onnxblocks. @baijumeswani, any help here would be appreciated.

Thank you.

@baijumeswani
Contributor

Hi there. I provided some suggestions here: #19464.

The idea being: if the loss is difficult to express in onnxblock, you could try to create an ONNX model from PyTorch that contains the loss embedded inside it.

class MyPTModelWithLoss(torch.nn.Module):
    def __init__(self):
         ...

    def forward(self, ...):
        p, q, r = compute_logits()
        loss = loss1(p) + loss2(q) + loss3(r)
        return loss

pt_model = MyPTModelWithLoss(...)
torch.onnx.export(pt_model, ...)

onnx_model = onnx.load(<exported_onnx_model_path>)
artifacts.generate_artifacts(onnx_model, requires_grad=[...], frozen_params=[...], loss=None, optimizer=...)

This becomes more complex if you already have the ONNX model and do not have access to the PyTorch model to add the loss function to. In that case, we can try to support your scenario with onnxblock. So, if this is where you are, please share your loss function, and I'll try to make onnxblock support that scenario.
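
For reference, a composite loss in onnxblock is normally written as a custom Block that composes the built-in loss and elementwise blocks. A minimal sketch along the lines of the documented onnxblock pattern (the weights, target names, and choice of MSELoss here are illustrative assumptions, not the YOLO loss):

import onnxruntime.training.onnxblock as onnxblock

# Sketch of a weighted two-term loss as a custom onnxblock Block; adapt the
# block set to your ORT version.
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()

    def build(self, loss_input_name1, loss_input_name2):
        # build() receives graph output *names* (strings) and returns the
        # name of the final loss output.
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")),
        )

Such a block instance would then be passed as the loss argument to generate_artifacts (in place of loss=None), assuming your ORT version accepts custom loss blocks there.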

sophies927 added the training label on Feb 22, 2024
@Marouan-st
Author

Hi,

I do have access to the YOLOv8n torch model from ultralytics (ultralytics doc).

I tried to include the loss computation into my model and export it to onnx as follows:

from ultralytics import YOLO
import torch

# Load a model (with pretrained weights)
model = YOLO("yolov8n.pt") 

class YOLOv8nWithLoss(torch.nn.Module):
    def __init__(self, yolov8_model):
        super(YOLOv8nWithLoss, self).__init__()
        self.model = yolov8_model

    def forward(self, batch, targets):
        outputs = self.model.model(batch)
        loss = self.model.model.loss(targets, outputs)
        return loss
    
model_with_loss = YOLOv8nWithLoss(model)
model_with_loss.model.train()

# Export the model to ONNX.
model_name = "yolov8n_with_loss_eval_mode"

# Use opset_version < 18, otherwise ReduceMin error
torch.onnx.export(
    model_with_loss,
    (torch.randn(1, 3, 640, 640), torch.Tensor([[0, 1, 0.85, 0.45, 0.57, 0.98]])),
    f"training_artifacts/{model_name}.onnx",
    input_names=["images", "targets"],
    output_names=["loss"],
    dynamic_axes={
        "images": {0: "batch", 2: "height", 3: "width"},
        "targets": {0: "batch"},
        "loss": {0: "batch", 2: "anchors"},
    },
    training=torch.onnx.TrainingMode.PRESERVE,
    opset_version=17,
)

The loss implementation is available here (cf. class v8DetectionLoss).

The export goes well (ONNX graph: yolov8n_with_loss_train_mode.zip) but with the following warning:

/venv/object_detection/lib/python3.10/site-packages/torch/onnx/utils.py:1686: UserWarning: The exported ONNX model failed ONNX shape inference. The model will not be executable by the ONNX Runtime. If this is unintended and you believe there is a bug, please report an issue at https://github.com/pytorch/pytorch/issues. Error reported by strict ONNX shape inference: [ShapeInferenceError] (op_type:Concat, node name: /Concat_21): inputs has inconsistent type tensor(int32) (Triggered internally at ../torch/csrc/jit/serialization/export.cpp:1415.)
  _C._check_onnx_proto(proto)
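
(A minimal way to reproduce that strict check outside the export call, assuming only that the model was written to the path used in the export above:)

import onnx

# Re-run the strict shape inference that torch.onnx.export performs
# internally, to surface the failing node without re-exporting.
model = onnx.load("training_artifacts/yolov8n_with_loss_eval_mode.onnx")
onnx.shape_inference.infer_shapes(model, strict_mode=True)  # raises on /Concat_21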

Then, when I try to generate the artifacts I get the following error:

InferenceError                            Traceback (most recent call last)
Cell In[9], line 14
      6 frozen_params = [
      7    param.name
      8    for param in onnx_model.graph.initializer
      9    if param.name not in requires_grad
     10 ]
     13 # Generate the training artifacts.
---> 14 artifacts.generate_artifacts(
     15    onnx_model,
     16    requires_grad=requires_grad,
     17    frozen_params=frozen_params,
     18    loss=None,
     19    optimizer=artifacts.OptimType.AdamW,
     20    artifact_directory="training_artifacts"
     21 )

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:154, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, **extra_options)
    149     custom_op_library = pathlib.Path(custom_op_library)
    151 with onnxblock.base(model), onnxblock.custom_op_library(
    152     custom_op_library
    153 ) if custom_op_library is not None else contextlib.nullcontext():
--> 154     _ = training_block(*[output.name for output in model.graph.output])
    155     training_model, eval_model = training_block.to_model_proto()
    156     model_params = training_block.parameters()

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/onnxblock.py:188, in TrainingBlock.__call__(self, *args, **kwargs)
    184 self.base = accessor._GLOBAL_ACCESSOR.model
    186 logging.debug("Building training block %s", self.__class__.__name__)
--> 188 output = self.build(*args, **kwargs)
    190 model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
    192 _graph_utils.register_graph_outputs(model, output)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:124, in generate_artifacts.<locals>._TrainingBlock.build(self, *inputs_to_loss)
    121     else:
    122         return (loss_output, *tuple(extra_options["additional_output_names"]))
--> 124 return self._loss(*inputs_to_loss)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/blocks.py:50, in Block.__call__(self, *args, **kwargs)
     46 logging.debug("Building block: %s", self.__class__.__name__)
     48 output = self.build(*args, **kwargs)
---> 50 onnx.checker.check_model(self.base, True)
     52 return output

File ~/venv/object_detection/lib/python3.10/site-packages/onnx/checker.py:148, in check_model(model, full_check, skip_opset_compatibility_check)
    144 if sys.getsizeof(protobuf_string) > MAXIMUM_PROTOBUF:
    145     raise ValueError(
    146         "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
    147     )
--> 148 C.check_model(protobuf_string, full_check, skip_opset_compatibility_check)

InferenceError: [ShapeInferenceError] (op_type:Concat, node name: /Concat_21): inputs has inconsistent type tensor(int32)

@baijumeswani Any idea what could cause this error?

Thank you

@baijumeswani
Contributor

Looking at your model after running shape inference on it, I see the Concat node like so:

[screenshot: the offending Concat node, with one int64 input and one int32 input]

The Concat node is trying to concatenate tensors of different types (int64 and int32), and this will fail with onnxruntime. All the types being concatenated need to be the same. You can try to add a Cast node to cast the int32 to int64 and see how that goes.
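
(A hypothetical way to apply that fix with the onnx helper API; the node name comes from the error message, while the file name and the assumption that input 0 is the int32 tensor would need verifying, e.g. in Netron:)

import onnx
from onnx import TensorProto, helper

model = onnx.load("yolov8n_with_loss_train_mode.onnx")  # path is illustrative

for idx, node in enumerate(model.graph.node):
    if node.name == "/Concat_21":
        src = node.input[0]  # assumed to be the int32 input
        cast_out = src + "_int64"
        cast = helper.make_node(
            "Cast", inputs=[src], outputs=[cast_out],
            to=TensorProto.INT64, name=src + "/CastToInt64",
        )
        node.input[0] = cast_out
        # Insert before the Concat so the graph stays topologically sorted.
        model.graph.node.insert(idx, cast)
        break

onnx.checker.check_model(model, full_check=True)  # should pass now
onnx.save(model, "yolov8n_with_loss_train_mode_cast.onnx")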

@Marouan-st
Author

Marouan-st commented Mar 15, 2024

I found the int32 tensor and changed its type to int64, and I no longer get the ShapeInferenceError, thanks.

Now I get another error when trying to generate the artifacts:

RuntimeError                              Traceback (most recent call last)
Cell In[16], line 15
      7 frozen_params = [
      8    param.name
      9    for param in onnx_model.graph.initializer
     10    if param.name not in requires_grad
     11 ]
     14 # Generate the training artifacts.
---> 15 artifacts.generate_artifacts(
     16    onnx_model,
     17    requires_grad=requires_grad,
     18    frozen_params=frozen_params,
     19    loss=None,
     20    optimizer=artifacts.OptimType.AdamW,
     21    artifact_directory="training_artifacts"
     22 )

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:154, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, **extra_options)
    149     custom_op_library = pathlib.Path(custom_op_library)
    151 with onnxblock.base(model), onnxblock.custom_op_library(
    152     custom_op_library
    153 ) if custom_op_library is not None else contextlib.nullcontext():
--> 154     _ = training_block(*[output.name for output in model.graph.output])
    155     training_model, eval_model = training_block.to_model_proto()
    156     model_params = training_block.parameters()

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/onnxblock.py:204, in TrainingBlock.__call__(self, *args, **kwargs)
    196 self._parameters = _training_graph_utils.get_model_parameters(model, self._requires_grad, self._frozen_params)
    198 # Build the gradient graph. The gradient graph building is composed of the following steps:
    199 #   - Move all model parameters to model inputs.
    200 #   - Run orttraining graph transformers on the model.
    201 #   - Add the gradient graph to the optimized model.
    202 # The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
    203 # The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
--> 204 self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
    205     model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
    206 )
    208 logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
    210 _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:127, in build_gradient_graph(model, requires_grad, frozen_params, output_names, custom_op_library)
    124 if custom_op_library is not None:
    125     options.register_custom_ops_library(os.fspath(custom_op_library))
--> 127 optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
    129 # Assumption is that the first graph output is the loss output
    130 gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)

RuntimeError: /local/home/user/tools/onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:1010 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&)::<lambda(const pybind11::bytes&, const std::unordered_set<std::__cxx11::basic_string<char> >&, onnxruntime::python::PySessionOptions*)> [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (/Reshape_52_output_0_log_prob).

Here's the corrected graph (yolov8n_with_loss_train_mode_cast.zip) and how I generate the artifacts:

# Load the onnx model.
model_name = "yolov8n_with_loss_train_mode_cast"
onnx_model = onnx.load(f"{model_name}.onnx")

requires_grad = ["model.model.model.22.cv3.2.2.weight", "model.model.model.22.cv3.2.2.bias"]
frozen_params = [
   param.name
   for param in onnx_model.graph.initializer
   if param.name not in requires_grad
]


# Generate the training artifacts.
artifacts.generate_artifacts(
   onnx_model,
   requires_grad=requires_grad,
   frozen_params=frozen_params,
   loss=None,
   optimizer=artifacts.OptimType.AdamW,
   artifact_directory="training_artifacts"
)
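
(For context, once artifact generation succeeds, the on-device side consumes them roughly as follows; a sketch assuming the default file names that generate_artifacts writes into training_artifacts/:)

from onnxruntime.training import api as ort_api

# Load the generated artifacts; file names are the generate_artifacts defaults.
state = ort_api.CheckpointState.load_checkpoint("training_artifacts/checkpoint")
module = ort_api.Module(
    "training_artifacts/training_model.onnx",
    state,
    "training_artifacts/eval_model.onnx",
)
optimizer = ort_api.Optimizer("training_artifacts/optimizer_model.onnx", module)

module.train()
# loss = module(images, targets)  # user inputs, in the exported input order
# optimizer.step()
# module.lazy_reset_grad()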

@baijumeswani
Contributor

^ Seems like a bug. I will add a pull request to address this issue.

Marouan-st changed the title from "[On-device Training] Yolov4 custom loss" to "[On-device Training] Yolo custom loss" on Mar 19, 2024
@Marouan-st
Author

^ Seems like a bug. I will add a pull request to address this issue.

Hello @baijumeswani, any update regarding this bug?

@baijumeswani
Contributor

I addressed the issue you highlighted here: #20016

However, there is still another problem: the model has a ReduceMax node. ORT training does not yet have a gradient kernel defined for the ReduceMax node, so the gradient graph building fails.

@Marouan-st
Author

Marouan-st commented Mar 22, 2024

Ok, thank you for your support.
Do you know how I could replace these ReduceMax nodes with operations supported for training?
Also, there are ReduceMin nodes in the graph; is there a gradient kernel for those?
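
(A quick hypothetical way to locate those nodes and check whether they all sit in the loss sub-graph; the file name matches the corrected graph above:)

import onnx

# List every ReduceMax/ReduceMin node with its inputs and outputs.
model = onnx.load("yolov8n_with_loss_train_mode_cast.onnx")
for node in model.graph.node:
    if node.op_type in ("ReduceMax", "ReduceMin"):
        print(node.op_type, node.name, "inputs:", list(node.input))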

@Marouan-st
Author

Hello @baijumeswani,
The ReduceMax and ReduceMin operations are only used in the loss computation, so the gradient is not actually required for them.
I'm following your suggested approach of creating an ONNX model from PyTorch with the loss embedded inside it. I have two questions:

1. The model I provided contains a forward graph + loss computation, so I'm wondering: is there any way to build the gradient graph only for the forward part of the model?

2. In that case, how are we supposed to feed the loss function to the loss argument of the generate_artifacts function?

Thank you
