
Onnx Model run failed in a loop #21213

Closed
inisis opened this issue Jun 30, 2024 · 2 comments

Comments

@inisis
Contributor

inisis commented Jun 30, 2024

Describe the issue

For large language models, it is common practice to run the model in a loop, with the KV cache growing longer after each iteration. However, ONNX Runtime raised "Shape mismatch attempting to re-use buffer":

2024-06-30 01:50:54.266368919 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Concat node. Name:'/block/attn/Concat_9' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,16,128} != {1,2,16,128}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.
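The two shapes in the message are consistent with the attention cache growing by one token per step: the output of the Concat node is larger on the second run than the buffer ONNX Runtime pre-allocated on the first. A minimal numpy sketch (not from the original report; shapes taken from the log) of what that Concat produces:

```python
import numpy as np

# First-run cache holds one token's key/value; the Concat appends the
# current token's key/value along the sequence axis.
past = np.zeros([1, 1, 16, 128], dtype=np.float32)  # {1,1,16,128}
new = np.zeros([1, 1, 16, 128], dtype=np.float32)
present = np.concatenate([past, new], axis=1)       # {1,2,16,128}
print(present.shape)  # (1, 2, 16, 128)
```

A buffer sized for the first iteration's {1,1,16,128} output cannot hold the second iteration's {1,2,16,128} result, hence the error.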

To reproduce

The model is provided.

import numpy as np
import onnxruntime as rt

seq_len = 1

hidden_states = np.random.rand(seq_len, 1, 2048).astype(np.float32)
# Causal (lower-triangular) attention mask.
attention_mask = np.tril(np.ones([1, 1, seq_len, seq_len], dtype=np.bool_))
position_ids = np.arange(seq_len, dtype=np.int64)
# Empty KV cache: (2, batch, past_seq_len=0, num_heads, head_dim).
past_key_values = np.zeros([2, 1, 0, 16, 128], dtype=np.float32)

def onnx_runner(model_path):
    return rt.InferenceSession(model_path)

runner = onnx_runner('model.onnx')

def forward(runner, *args):
    input_names = [inp.name for inp in runner.get_inputs()]
    output_names = [out.name for out in runner.get_outputs()]
    kwargs = {name: arg for name, arg in zip(input_names, args)}
    return runner.run(output_names, kwargs)

for i in range(2):
    print(hidden_states.shape, attention_mask.shape, position_ids.shape, past_key_values.shape)
    token_id, past_key_values = forward(runner, hidden_states, attention_mask, position_ids, past_key_values)
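Note that the repro feeds the same fixed-size `attention_mask` and `position_ids` on every iteration, even though the cache grows. A minimal sketch (not from the original report) of how the per-step input shapes would evolve, assuming the cache layout `(2, batch, past_seq_len, num_heads, head_dim)` used by the repro above:

```python
import numpy as np

past_len = 0
for step in range(3):
    seq_len = 1  # one new token per decode step
    hidden_states = np.random.rand(seq_len, 1, 2048).astype(np.float32)
    # The mask must cover both the cached tokens and the new one.
    attention_mask = np.ones([1, 1, seq_len, past_len + seq_len], dtype=np.bool_)
    # Positions continue from where the cache left off.
    position_ids = np.arange(past_len, past_len + seq_len, dtype=np.int64)
    past_key_values = np.zeros([2, 1, past_len, 16, 128], dtype=np.float32)
    # session.run(...) would go here; the returned cache is longer by seq_len:
    past_len += seq_len
print(past_len)  # 3
```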

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.18.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@inisis
Contributor Author

inisis commented Jun 30, 2024

The problem seems to be with past_key_values: ONNX Runtime allocates a buffer with a static shape on the first run, then tries to reuse it on the next iteration, finds the shape no longer matches, and raises the error.


@tianleiwu
Contributor

The input/output shapes are not correct.

The correct shapes are as follows (batch_size=1, hidden_size=2048, num_heads=16, head_dim=128 for this model):

input_embds: (batch_size, seq_len, hidden_size)
past_key_value: (2, batch_size, past_seq_len, num_heads, head_dim)
attention_mask: (batch_size, 1, seq_len, past_seq_len + seq_len)
hidden_states: (batch_size, seq_len, hidden_size)
presents: (2, batch_size, past_seq_len + seq_len, num_heads, head_dim)
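These shape relations can be checked with concrete numbers. A sketch (not from the original thread) using batch_size=1, num_heads=16, head_dim=128 as stated for this model, with past_seq_len=1 and seq_len=1:

```python
import numpy as np

batch, heads, dim = 1, 16, 128
past_seq, seq = 1, 1

past_kv = np.zeros([2, batch, past_seq, heads, dim], dtype=np.float32)
new_kv = np.zeros([2, batch, seq, heads, dim], dtype=np.float32)
# presents = past concatenated with the new step along the sequence axis.
presents = np.concatenate([past_kv, new_kv], axis=2)
print(presents.shape)  # (2, 1, 2, 16, 128)

# The mask covers past_seq_len + seq_len key positions.
attention_mask = np.ones([batch, 1, seq, past_seq + seq], dtype=np.bool_)
print(attention_mask.shape)  # (1, 1, 1, 2)
```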

@inisis inisis closed this as completed Jun 30, 2024