[On-device Training] Yolo custom loss #19464
Comments
Hello, I'm working on a similar topic: retraining a fine-tuned YOLOv8 on the device using the ONNX Runtime training API. I'm struggling to define the loss functions as onnxblocks. @baijumeswani, any help here would be appreciated. Thank you. |
Hi there. I provided some suggestions here: #19464. The idea is that if the loss is difficult to express in onnxblock, you could create an ONNX model from PyTorch that has the loss embedded inside it:
import onnx
import torch
from onnxruntime.training import artifacts

class MyPTModelWithLoss(torch.nn.Module):
    def __init__(self):
        ...
    def forward(self, ...):
        p, q, r = compute_logits()
        # Combine the individual loss terms into the single value the graph returns.
        loss = loss1(p) + loss2(q) + loss3(r)
        return loss

pt_model = MyPTModelWithLoss(...)
torch.onnx.export(pt_model, ...)
onnx_model = onnx.load(<exported_onnx_model_path>)
# loss=None because the exported graph already computes the loss.
artifacts.generate_artifacts(onnx_model, requires_grad=[...], frozen_params=[...], loss=None, optimizer=...)
This might become more complex if you already have the ONNX model and do not have access to the PyTorch model to add the loss function to. In that case, we can try to support your scenario with onnxblock. So, if this is where you are, please share your loss function, and I'll try to make onnxblock support that scenario. |
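The `requires_grad=[...]` and `frozen_params=[...]` arguments above are left elided. As a minimal sketch of one way to fill them, the helper below splits initializer names into a trainable list and a frozen list, and fails early if a requested name is not in the model. The function name `split_parameters` and the demo parameter names are hypothetical, and treating every other initializer as frozen is an assumption, not something the API requires.

```python
from types import SimpleNamespace


def split_parameters(initializers, trainable_names):
    """Return (requires_grad, frozen_params) name lists.

    `initializers` is any iterable of objects with a `.name` attribute,
    for example `onnx_model.graph.initializer`.
    """
    all_names = [init.name for init in initializers]
    missing = [name for name in trainable_names if name not in all_names]
    if missing:
        raise ValueError(f"Not found among the model initializers: {missing}")
    trainable = set(trainable_names)
    # Keep the model's own ordering in both lists.
    requires_grad = [name for name in all_names if name in trainable]
    frozen_params = [name for name in all_names if name not in trainable]
    return requires_grad, frozen_params


# Stand-in objects with hypothetical names, so the sketch runs without onnx.
demo_initializers = [
    SimpleNamespace(name=n) for n in ("head.weight", "head.bias", "backbone.weight")
]
requires_grad, frozen_params = split_parameters(
    demo_initializers, ["head.weight", "head.bias"]
)
print(requires_grad)
print(frozen_params)
```

With a real model, the two returned lists would be passed as the `requires_grad` and `frozen_params` arguments of `artifacts.generate_artifacts`.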
Hi, I do have access to the YOLOv8n torch model from ultralytics. (ultralytics doc) I tried to include the loss computation into my model and export it to onnx as follows:
from ultralytics import YOLO
import torch

# Load a model (with pretrained weights)
model = YOLO("yolov8n.pt")

class YOLOv8nWithLoss(torch.nn.Module):
    def __init__(self, yolov8_model):
        super(YOLOv8nWithLoss, self).__init__()
        self.model = yolov8_model

    def forward(self, batch, targets):
        outputs = self.model.model(batch)
        loss = self.model.model.loss(targets, outputs)
        return loss

model_with_loss = YOLOv8nWithLoss(model)
model_with_loss.model.train()

# Export the model to ONNX.
model_name = "yolov8n_with_loss_eval_mode"
# Use opset_version < 18, otherwise ReduceMin error
torch.onnx.export(
    model_with_loss,
    (torch.randn(1, 3, 640, 640), torch.Tensor([[0, 1, 0.85, 0.45, 0.57, 0.98]])),
    f"training_artifacts/{model_name}.onnx",
    input_names=["images", "targets"],
    output_names=["loss"],
    dynamic_axes={
        "images": {0: "batch", 2: "height", 3: "width"},
        "targets": {0: "batch"},
        "loss": {0: "batch", 2: "anchors"},
    },
    training=torch.onnx.TrainingMode.PRESERVE,
    opset_version=17,
)
Loss implementation is available here (cf class v8DetectionLoss)
The export goes well (onnx graph yolov8n_with_loss_train_mode.zip) but with the following warning:
Then, when I try to generate the artifacts I get the following error:
@baijumeswani Any idea of what could cause this error? Thank you |
Looking at your model after doing shape inferencing on it, I see the concat node like so: The concat node is trying to concat tensors of different types (int64 and int32) and this will fail with onnxruntime. All the types being concatenated need to be the same. You can try to add a cast node to cast the int32 to int64 and see how that goes. |
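To locate Concat nodes like the one described above, a small scan over the graph can help. The following is a rough sketch, not part of ONNX Runtime: `find_mixed_type_concats` is a hypothetical helper that takes the graph nodes and a name-to-element-type mapping. With a real model, the mapping could be built after `onnx.shape_inference.infer_shapes` from `value_info.type.tensor_type.elem_type` and `initializer.data_type`; stand-in objects are used here so the example runs without onnx installed.

```python
from types import SimpleNamespace

# Numeric values of onnx.TensorProto.INT32 and onnx.TensorProto.INT64.
INT32, INT64 = 6, 7


def find_mixed_type_concats(nodes, elem_type_of):
    """Return (node_name, {input_name: elem_type}) for each Concat node
    whose inputs with a known element type do not all agree."""
    mismatches = []
    for node in nodes:
        if node.op_type != "Concat":
            continue
        known = {name: elem_type_of[name] for name in node.input if name in elem_type_of}
        if len(set(known.values())) > 1:
            mismatches.append((node.name, known))
    return mismatches


# Hypothetical graph fragment: one bad Concat, one good Concat, one unrelated node.
demo_nodes = [
    SimpleNamespace(name="/Concat_a", op_type="Concat", input=["x", "y"]),
    SimpleNamespace(name="/Concat_b", op_type="Concat", input=["x", "z"]),
    SimpleNamespace(name="/Add_0", op_type="Add", input=["x", "y"]),
]
demo_types = {"x": INT64, "y": INT32, "z": INT64}

mixed = find_mixed_type_concats(demo_nodes, demo_types)
for node_name, types in mixed:
    print(node_name, types)
```

Each reported input with the odd type is a candidate for a Cast node (for example `onnx.helper.make_node("Cast", ..., to=onnx.TensorProto.INT64)`) or for a dtype change on the PyTorch side before export.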
I found the int32 tensor and changed its type to int64 and I don't get the ShapeInference error anymore, thanks. Now I get another error when trying to generate the artifacts:
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 15
      7 frozen_params = [
      8     param.name
      9     for param in onnx_model.graph.initializer
     10     if param.name not in requires_grad
     11 ]
     14 # Generate the training artifacts.
---> 15 artifacts.generate_artifacts(
     16     onnx_model,
     17     requires_grad=requires_grad,
     18     frozen_params=frozen_params,
     19     loss=None,
     20     optimizer=artifacts.OptimType.AdamW,
     21     artifact_directory="training_artifacts"
     22 )

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:154, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, **extra_options)
    149     custom_op_library = pathlib.Path(custom_op_library)
    151 with onnxblock.base(model), onnxblock.custom_op_library(
    152     custom_op_library
    153 ) if custom_op_library is not None else contextlib.nullcontext():
--> 154     _ = training_block(*[output.name for output in model.graph.output])
    155     training_model, eval_model = training_block.to_model_proto()
    156     model_params = training_block.parameters()

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/onnxblock.py:204, in TrainingBlock.__call__(self, *args, **kwargs)
    196 self._parameters = _training_graph_utils.get_model_parameters(model, self._requires_grad, self._frozen_params)
    198 # Build the gradient graph. The gradient graph building is composed of the following steps:
    199 #   - Move all model parameters to model inputs.
    200 #   - Run orttraining graph transformers on the model.
    201 #   - Add the gradient graph to the optimized model.
    202 # The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
    203 # The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
--> 204 self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
    205     model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
    206 )
    208 logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
    210 _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:127, in build_gradient_graph(model, requires_grad, frozen_params, output_names, custom_op_library)
    124 if custom_op_library is not None:
    125     options.register_custom_ops_library(os.fspath(custom_op_library))
--> 127 optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
    129 # Assumption is that the first graph output is the loss output
    130 gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)

RuntimeError: /local/home/user/tools/onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:1010 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&)::<lambda(const pybind11::bytes&, const std::unordered_set<std::__cxx11::basic_string<char> >&, onnxruntime::python::PySessionOptions*)> [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (/Reshape_52_output_0_log_prob).
Here's the corrected graph yolov8n_with_loss_train_mode_cast.zip and how I generate the artifacts:
# Load the onnx model.
model_name = "yolov8n_with_loss_train_mode_cast"
onnx_model = onnx.load(f"{model_name}.onnx")
requires_grad = ["model.model.model.22.cv3.2.2.weight", "model.model.model.22.cv3.2.2.bias"]
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]
# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=None,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts"
) |
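When a "Duplicate definition of name" error like the one above appears, one quick check is whether the exported graph itself already produces the same value name twice. The sketch below is a hypothetical diagnostic, not an ONNX Runtime API: `find_duplicate_outputs` works on any objects with an `.output` list (such as `onnx_model.graph.node`), and the node and value names in the demo are made up. A clean result only says the export is not the source; it does not rule out a name being introduced later during graph processing.

```python
from collections import Counter
from types import SimpleNamespace


def find_duplicate_outputs(nodes):
    """Return the sorted value names that more than one node output produces."""
    counts = Counter(name for node in nodes for name in node.output if name)
    return sorted(name for name, count in counts.items() if count > 1)


# Hypothetical nodes: two of them write to the same value name.
demo_nodes = [
    SimpleNamespace(name="/Reshape_a", output=["/value_0"]),
    SimpleNamespace(name="/Softmax_a", output=["/value_1", "/value_1_log_prob"]),
    SimpleNamespace(name="/Softmax_b", output=["/value_2", "/value_1_log_prob"]),
]

duplicates = find_duplicate_outputs(demo_nodes)
print(duplicates)
```

Only top-level node outputs are inspected here; nodes inside subgraphs (for example in `If` or `Loop` bodies) are not traversed.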
^ Seems like a bug. I will add a pull-request to address this issue. |
Hello @baijumeswani, any update regarding this bug? |
I addressed the issue you highlighted here: #20016. However, there is another problem: the model has a ReduceMax node, and ORT training does not yet have a gradient kernel defined for ReduceMax, so building the gradient graph fails. |
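To see where such nodes sit in an exported graph before attempting to build the gradient graph, the nodes can be filtered by operator type. This is a minimal sketch with a hypothetical helper name, `find_nodes_by_op_type`; the only operator known from this thread to lack a gradient is ReduceMax, so that is the only one queried, and the demo node names are made up.

```python
from types import SimpleNamespace


def find_nodes_by_op_type(nodes, op_types):
    """Return (op_type, node_name) for every node whose op_type is in `op_types`."""
    wanted = set(op_types)
    return [(node.op_type, node.name) for node in nodes if node.op_type in wanted]


# Stand-in nodes; with a real model, pass onnx_model.graph.node instead.
demo_nodes = [
    SimpleNamespace(op_type="Conv", name="/model.0/conv/Conv"),
    SimpleNamespace(op_type="ReduceMax", name="/ReduceMax"),
    SimpleNamespace(op_type="Sigmoid", name="/Sigmoid"),
    SimpleNamespace(op_type="ReduceMax", name="/ReduceMax_1"),
]

hits = find_nodes_by_op_type(demo_nodes, {"ReduceMax"})
for op_type, node_name in hits:
    print(op_type, node_name)
```

Knowing which nodes are affected makes it easier to trace them back to the corresponding lines of the PyTorch loss.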
Ok, thank you for your support. |
Hello @baijumeswani,
1. The model I provided contains the forward graph plus the loss computation, so I'm wondering whether there is any way to build the gradient graph only for the forward part of the model.
2. In that case, how are we supposed to pass the loss function to the loss argument of the generate_artifacts function?
Thank you |
Discussed in #19390
Originally posted by Marouan-st February 2, 2024
Hello,
I would like to implement a custom loss so that I can train a yolov4-tiny object detection model on-device.
Computing the loss requires some post-processing of the model output, such as computing bounding-box IoU and summing several losses (class loss + confidence loss + IoU loss: cross-entropy losses): see https://www.nature.com/articles/s41598-021-02225-y/figures/3
I don't see how to implement all of these computations in the custom loss, and in particular how to feed the post-processed values to the individual losses, since the onnx loss functions take string arguments (input names) as input.
I'm using a yolov4-tiny model compiled from darknet and converted to onnx from a tensorflow implementation of the model.
The Torch implementation of this loss function (for the model I'm using) would look like this (inspired by this yolov4 loss tensorflow implementation):
Any suggestions?
Thank you
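The post-processing described in the post above (box IoU followed by a sum of loss terms) can be illustrated in plain Python. This is a simplified sketch for a single matched prediction, not the yolov4 loss from the post: the function names, the corner-format boxes, the use of binary cross-entropy for the confidence and class terms, and the equal weights are all assumptions made for illustration.

```python
import math


def box_iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2) corners."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


def binary_cross_entropy(prob, target, eps=1e-7):
    """Binary cross-entropy for one probability, clipped to avoid log(0)."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def detection_loss(pred_box, true_box, pred_conf, pred_class_prob, weights=(1.0, 1.0, 1.0)):
    """Weighted sum of IoU, confidence, and class terms for one matched prediction."""
    iou_loss = 1.0 - box_iou(pred_box, true_box)
    conf_loss = binary_cross_entropy(pred_conf, 1.0)         # an object is present
    class_loss = binary_cross_entropy(pred_class_prob, 1.0)  # probability of the true class
    w_iou, w_conf, w_cls = weights
    return w_iou * iou_loss + w_conf * conf_loss + w_cls * class_loss


print(box_iou((0, 0, 2, 2), (1, 1, 3, 3)))
print(detection_loss((0, 0, 2, 2), (1, 1, 3, 3), pred_conf=0.8, pred_class_prob=0.6))
```

Expressing the same steps as tensor operations inside a PyTorch module's `forward` is what allows them to be exported to ONNX together with the model, as suggested elsewhere in this thread.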