
[Training] [ShapeInferenceError] Dimension could not be inferred: incompatible shapes #21327

Closed
srijanie03 opened this issue Jul 11, 2024 · 3 comments
Labels: training (issues related to ONNX Runtime training; typically submitted using template)


srijanie03 commented Jul 11, 2024

Describe the issue

I am trying to get an ONNX training graph for the Llama2_7b model. I can get the forward graph with no problem, but the issue occurs when I use generate_artifacts.

(screenshot of the traceback: [ShapeInferenceError] Dimension could not be inferred: incompatible shapes)

I don't get this error when I run a single-layer Transformer block (attention + MLP) with similar dimensions. What is causing the issue here?

Additionally, the loss function is throwing an error as well: expected 2 but got 66 arguments. Please explain. Thank you!

To reproduce

import torch
from onnxruntime.training import artifacts
from torchtune.models.llama2 import llama2_7b  # assumed: torchtune model builder, per the file name below

base_model = llama2_7b()

# Two tokenized prompts, each five tokens long.
batch = torch.tensor([[    1,  7569,  7225, 16229,   366],
                      [    1,  7569,  2462,  8640,   263]])

model_outputs = base_model(batch)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]

input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

torch.onnx.export(
    base_model,
    batch,
    "torchtune_llama2.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

# Partition the parameters into trainable and frozen sets for artifact generation.
requires_grad = [name for name, param in base_model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in base_model.named_parameters() if not param.requires_grad]

artifacts.generate_artifacts(
    "torchtune_llama2.onnx",
    # optimizer=artifacts.OptimType.AdamW,
    # loss=artifacts.LossType.CrossEntropyLoss,
    # loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="llama2",
    additional_output_names=["output"],
)

Urgency

Urgent

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.19.0

PyTorch Version

2.4.0

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.4

srijanie03 added the training label Jul 11, 2024
github-actions bot added the ep:CUDA label Jul 11, 2024
sophies927 removed the ep:CUDA label Jul 11, 2024
carzh self-assigned this Jul 15, 2024
github-actions bot added the ep:CUDA label Jul 22, 2024
carzh removed the ep:CUDA label Jul 26, 2024
carzh (Contributor) commented Jul 26, 2024

Hi, do you have the full script for exporting & generating artifacts? Or could you provide the forward graph ONNX file?

Generally, we see these errors when the generated forward graph is incorrect, especially for the loss function, which expects a certain number of graph outputs. For LLMs in particular, unless the base Torch model being exported is in training mode, it usually uses a key-value cache (an inference-only optimization that adds extra inputs and outputs to the graph).

The Torch model passed to torch.onnx.export (the base_model) must be in training mode (i.e., you should be able to train with it), and the input and output names passed to the export function should match the input and output names of the Torch model.

If you have a working PyTorch training script for Llama2_7b, you can use that to determine the correct input names and output names, and what inputs you need to pass in for it to be in training mode.
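For illustration, a minimal sketch of that pattern (the torchtune llama2_7b builder and the vocabulary size are assumptions, not details from this thread): put the model in training mode before export so the inference-only KV-cache path is not traced, and keep constant folding off so the parameters remain trainable.

import torch
from torchtune.models.llama2 import llama2_7b  # assumed model builder

model = llama2_7b()
model.train()  # training mode: avoids tracing the inference-only KV-cache path

dummy_batch = torch.randint(0, 32000, (2, 5))  # token IDs; the 32000 vocab size is an assumption

torch.onnx.export(
    model,
    dummy_batch,
    "llama2_train.onnx",
    input_names=["input"],    # must match the model's actual inputs
    output_names=["output"],  # must match the model's actual outputs
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,  # keep parameters as initializers rather than folded constants
    export_params=True,
)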

srijanie03 (Author) commented

Hi @carzh. Thanks a lot for the comment. Yes, that is exactly what I did, and I was able to resolve the issue. There was a mismatch in the input dimensions that was generating the error.
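(For anyone hitting a similar mismatch: a general way to check the dimensions the exporter actually recorded, not a step from this thread, is to print the declared graph inputs with the onnx package.)

import onnx

model = onnx.load("torchtune_llama2.onnx")
for inp in model.graph.input:
    # dim_param is set for symbolic axes (e.g. "batch_size"), dim_value for fixed ones
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)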

carzh (Contributor) commented Jul 30, 2024

Glad to hear it! Closing as resolved -- feel free to reopen if you run into further issues.

carzh closed this as completed Jul 30, 2024