
[Training] This training fails for Gemm node #18344

Closed
elephantpanda opened this issue Nov 8, 2023 · 9 comments
Labels
ep:CUDA issues related to the CUDA execution provider training issues related to ONNX Runtime training; typically submitted using template

Comments

@elephantpanda

Describe the issue

I'm trying to create ONNX Runtime training artifacts from this model:
https://www.dropbox.com/scl/fi/uveuro68v8d2epb2awijt/Walker2.onnx?rlkey=xx8oxntdo5dh8c3lhyby3bx9q&dl=0

Using Python:

from onnxruntime.training import artifacts
import onnxruntime.training.onnxblock as onnxblock
import onnx


model = onnx.load("Walker2.onnx")

path_to_output_artifact_directory = "output"

requires_grad = []
frozen_params = []

graph = model.graph

for initializer in graph.initializer:
    print(initializer.name + " " + str(initializer.data_type))
    # Assume every float32 initializer is a trainable parameter
    if initializer.data_type == onnx.TensorProto.FLOAT:
        requires_grad.append(initializer.name)
        print("**trainable**")

# Generate the training artifacts


class WeightedAverageLoss(onnxblock.Block):
    """Loss = 0.4 * MSE(a, target1) + 0.6 * MSE(b, target2)."""

    def __init__(self):
        super().__init__()  # initialize the onnxblock.Block base class
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()

    def build(self, a, b):
        # a and b are the names of the two model outputs fed into the loss
        return self._add(
            self._mul(self._w1(), self._loss1(a, target_name="target1")),
            self._mul(self._w2(), self._loss2(b, target_name="target2"))
        )

my_custom_loss = WeightedAverageLoss()

# do_constant_folding, opset_version and training are torch.onnx.export() options,
# not documented generate_artifacts() parameters, so they are omitted here.
artifacts.generate_artifacts(
    model,
    optimizer=None,  # or artifacts.OptimType.AdamW
    loss=my_custom_loss,  # or artifacts.LossType.MSELoss
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory=path_to_output_artifact_directory,
)
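As a cross-check on the WeightedAverageLoss block above, the value it should produce can be computed directly in NumPy. This is a minimal sketch, assuming both onnxblock.loss.MSELoss instances use mean reduction (their default setting, as far as I know); weighted_average_loss is a hypothetical helper, not part of onnxblock.

```python
import numpy as np


def weighted_average_loss(a, b, target1, target2, w1=0.4, w2=0.6):
    """NumPy reference for w1 * MSE(a, target1) + w2 * MSE(b, target2).

    Assumes both MSE terms use mean reduction over all elements.
    """
    mse1 = np.mean((np.asarray(a, dtype=np.float64) - np.asarray(target1, dtype=np.float64)) ** 2)
    mse2 = np.mean((np.asarray(b, dtype=np.float64) - np.asarray(target2, dtype=np.float64)) ** 2)
    return float(w1 * mse1 + w2 * mse2)


# Toy values: MSE(a, t1) = (1 + 4) / 2 = 2.5 and MSE(b, t2) = 4,
# so the weighted loss is about 0.4 * 2.5 + 0.6 * 4 = 3.4.
a = np.array([[1.0, 2.0]])
t1 = np.array([[0.0, 0.0]])
b = np.array([[3.0]])
t2 = np.array([[1.0]])
print(weighted_average_loss(a, b, t1, t2))
```

Feeding the same inputs and targets to the generated training model and comparing its loss output with this value is a quick way to confirm the weights were wired as intended.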

Background: this ONNX model was created by Unity ML-Agents, and I'm attempting to turn it into a train-on-device example. I deleted some extra nodes that weren't connected to the main graph.

Any idea what's causing this error?

RuntimeError: C:\a\_work\1\s\orttraining\orttraining\python\orttraining_pybind_state.cc:841 onnxruntime::python::addObjectMethodsForTraining::<lambda_ac677f721119089b105e2d6a6620788a>::operator () [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("Gemm_13_Grad/Gemm_1", Gemm, "", -1) : ("37_grad": tensor(float),"36": tensor(float),) -> ("action_model._continuous_distribution.mu.weight_grad": tensor(float),) , Error Node (Gemm_13_Grad/Gemm_1) has input size 2 not in range [min=3, max=3].

To reproduce

See the code above.

Urgency

No response

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

PyTorch Version

1.13.1

Execution Provider

CUDA

Execution Provider Library Version

No response

@elephantpanda elephantpanda added the training issues related to ONNX Runtime training; typically submitted using template label Nov 8, 2023
@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label Nov 8, 2023
@xadupre
Member

xadupre commented Nov 8, 2023

Can you run the following code?

from onnxruntime import InferenceSession
sess = InferenceSession("Walker2.onnx", providers=["CPUExecutionProvider"])

If that fails, the issue probably comes from your model. If that succeeds, the issue probably comes from artifacts.generate_artifacts.

@elephantpanda
Author

elephantpanda commented Nov 8, 2023

Yes, InferenceSession works perfectly:

import onnxruntime
from onnxruntime import InferenceSession
import numpy as np
import torch
sess = InferenceSession("Walker2.onnx", providers=["CPUExecutionProvider"])


x = onnxruntime.OrtValue.ortvalue_from_numpy(np.array([[1]*243],np.float32))

output = sess.run(["continuous_actions","deterministic_continuous_actions"],{"obs_0":x})

print(output)

This works as expected.

BTW, here is the relevant part of the model:

[Image: screenshot of the relevant part of the model graph]

BTW, I tried replacing the RandomNormal node with Identity and still got the same problem. The model was created with a reinforcement learning algorithm.

@elephantpanda
Author

elephantpanda commented Nov 8, 2023

I tried simplifying the model using onnx-modifier to just this:

[Image: simplified model graph]

But I'm still getting the error:

[Image: screenshot of the error]

Here is the onnx: https://www.dropbox.com/scl/fi/w9t0yeihnlul1ep1ufly8/Walker4.onnx?rlkey=ktcsl4qi7eqoueefxhl3649jb&dl=0

@xadupre
Member

xadupre commented Nov 8, 2023

Thanks for the additional information. I'll have a look tomorrow.

@elephantpanda
Author

elephantpanda commented Nov 8, 2023

Thank you! Greatly appreciated.

The opset is v6. I tried converting it to opset v7 with onnx.version_converter, but now I get the error "Error No Op registered for Expand with domain_version_7" when trying to make the training artifacts.

Perhaps it's just a very old ONNX model?

(When I create a Linear layer in torch, it works fine, so it's strange that this example doesn't work.)

@baicenxiao

baicenxiao commented Dec 3, 2023

I got the same error when trying to generate training artifacts from a toy ONNX network converted from PyTorch:

import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime
from onnxruntime.training import artifacts

class RegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(3, 10)  # First hidden layer
        self.layer2 = nn.Linear(10, 10) # Second hidden layer
        self.logits = nn.Linear(10, 1)  # Output layer

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = self.logits(x)
        return x

# Instantiate the model
model = RegressionModel()

onnx_model_path = 'regression_model.onnx'
sample_input = torch.randn(1, 3, requires_grad=True)
torch.onnx.export(model, sample_input, onnx_model_path, export_params=True, opset_version=10, 
                  do_constant_folding=True, input_names=['input'], output_names=['output'],
                 dynamic_axes={'input' : {0: 'batch'},    # variable length axes
                                'output' : {0: 'batch'}})

The ONNX checker and inference both work:

model = onnx.load("regression_model.onnx")
onnx.checker.check_model(model) # no error here

ort_session = onnxruntime.InferenceSession("regression_model.onnx", providers=["CPUExecutionProvider"])

ort_inputs = {ort_session.get_inputs()[0].name: sample_input.detach().numpy()}

ort_outs = ort_session.run(None, ort_inputs) # works and gives correct result

When I use the code below to generate training artifacts, I see this error:

RuntimeError: /Users/runner/work/1/s/orttraining/orttraining/python/orttraining_pybind_state.cc:897 auto onnxruntime::python::addObjectMethodsForTraining(py::module &, onnxruntime::python::ExecutionProviderRegistrationFn)::(anonymous class)::operator()(onnxruntime::python::PyGradientGraphBuilderContext *) const [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("/logits/Gemm_Grad/Gemm_0", Gemm, "", -1) : ("output_grad": tensor(float),"logits.weight": tensor(float),) -> ("/Relu_1_output_0_grad": tensor(float),) , Error Node (/logits/Gemm_Grad/Gemm_0) has input size 2 not in range [min=3, max=3].

import onnxruntime.training.onnxblock as onnxblock  # needed for onnxblock.loss.MSELoss below

path_to_forward_only_onnx_model = 'regression_model.onnx'

# Load the forward-only ONNX model
model = onnx.load(path_to_forward_only_onnx_model)

# Extract model's parameters
all_params = [param.name for param in model.graph.initializer]

# Layer names must match the attribute names in RegressionModel ('logits' is the output layer)
trainable_layers = ['layer2', 'logits']
requires_grad = [param for param in all_params if any(layer in param for layer in trainable_layers)]
frozen_params = [param for param in all_params if param not in requires_grad]

# Generate the training artifacts
path_to_output_artifact_directory = 'training_artifacts'
artifacts.generate_artifacts(model,
                             requires_grad=requires_grad,
                             frozen_params=frozen_params,
                             loss=onnxblock.loss.MSELoss(),  # Adjust the loss type if needed
                             optimizer=artifacts.OptimType.AdamW,
                             artifact_directory=path_to_output_artifact_directory)
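Substring matching on initializer names fails silently when an entry in trainable_layers matches nothing in the model: that layer's parameters simply end up frozen. This is a minimal sketch of a stricter split that raises on unmatched names and matches on the "name." prefix instead of any substring, so that, for example, layer1 cannot also match a hypothetical layer10. split_params is a hypothetical helper, and the parameter names assume PyTorch's default "module.weight" naming, as in the logits.weight input in the error above.

```python
def split_params(all_params, trainable_layers):
    """Split initializer names into (requires_grad, frozen_params) by layer prefix.

    Raises ValueError if any entry in trainable_layers matches no initializer,
    which usually means a layer name is misspelled or was renamed on export.
    """
    def belongs_to(param, layer):
        return param.startswith(layer + ".")

    unmatched = [layer for layer in trainable_layers
                 if not any(belongs_to(p, layer) for p in all_params)]
    if unmatched:
        raise ValueError(f"no initializer matches trainable layer(s): {unmatched}")

    requires_grad = [p for p in all_params
                     if any(belongs_to(p, layer) for layer in trainable_layers)]
    frozen_params = [p for p in all_params if p not in requires_grad]
    return requires_grad, frozen_params


params = ["layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias",
          "logits.weight", "logits.bias"]
requires_grad, frozen = split_params(params, ["layer2", "logits"])
print(requires_grad)
print(frozen)
```

With the real model, all_params would come from [p.name for p in model.graph.initializer] as in the snippet above.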

Here is the network's visualization:
[Image: Screenshot 2023-12-02 at 4 19 10 PM, network visualization]

My torch version is 2.1.0, onnx version is 1.14.1, and onnxruntime version is 1.16.3. I am using a MacBook with an Apple M1 Pro.
Any ideas for a solution or workaround?

@baicenxiao

I just resolved my error by updating the opset from 10 to 17:

torch.onnx.export(model, sample_input, onnx_model_path, export_params=True, opset_version=17, 
                  do_constant_folding=True, input_names=['input'], output_names=['output'],
                 dynamic_axes={'input' : {0: 'batch'},    # variable length axes
                                'output' : {0: 'batch'}})

@elephantpanda
Author

elephantpanda commented Dec 3, 2023

Yes, I thought that was the problem. I think the ONNX model I was using was just too old to work. I'll have to try to recreate it from the Python source. It would be nice if it worked with opset 10, though.

@baijumeswani
Contributor

Closing this issue and marking it as resolved. Please reopen if there are other questions.


4 participants