
# Remove internal enforce for IO binding inputs #18266

## Conversation

**@kunal-vaishnavi** (Contributor) commented on Nov 3, 2023

### Description

This PR removes an internal `ORT_ENFORCE` that is triggered when binding `torch.tensor` inputs with IO binding in end-to-end scripts.

### Motivation and Context

In merged exports of PyTorch models to ONNX, each past key and past value in the past KV cache has an input shape of `(batch_size, num_heads, past_sequence_length, head_size)`. In the first pass through the model, which processes the prompt, `past_sequence_length = 0`, so each of these inputs has shape `(batch_size, num_heads, 0, head_size)`. In subsequent passes, `past_sequence_length > 0`.

When binding a `torch.tensor` of shape `(batch_size, num_heads, 0, head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must be passed. A `torch.tensor` of this shape holds zero elements, so its `data_ptr()` returns 0; the existing `ORT_ENFORCE` check then fails and an error is raised. With the internal `ORT_ENFORCE` removed, no error is raised and the model runs successfully.
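As a minimal sketch of the scenario (the `model.onnx` path and the CUDA provider are illustrative assumptions; `bind_input` and its arguments are the standard onnxruntime Python IO binding API, and the input name and shape are taken from the LLaMA-2 example below), binding an empty past-KV tensor looks like this:

```python
import numpy as np
import onnxruntime as ort
import torch

# Illustrative model path; any merged-export decoder with past-KV inputs works.
session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
io_binding = session.io_binding()

# Empty past-KV tensor for the first (prompt-processing) pass:
# past_sequence_length == 0, so the tensor holds zero elements and
# torch reports data_ptr() == 0 for it.
past_key = torch.zeros((1, 32, 0, 128), dtype=torch.float32, device="cuda")

# Before this PR, the internal ORT_ENFORCE rejected the null buffer pointer;
# with the enforce removed, this call succeeds.
io_binding.bind_input(
    name="past_key_values.0.key",  # input name from the LLaMA-2 example below
    device_type="cuda",
    device_id=0,
    element_type=np.float32,
    shape=tuple(past_key.shape),
    buffer_ptr=past_key.data_ptr(),  # 0 for the zero-element tensor
)
```

The zero-element tensor is a legitimate input for the prompt pass, so rejecting its null `data_ptr()` was overly strict; the table below shows the concrete values observed for LLaMA-2.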

LLaMA-2 Example:

| Input Name | Input Size | Device | Device ID | Torch Dtype | `data_ptr()` |
| --- | --- | --- | --- | --- | --- |
| input_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561842688 |
| attention_mask | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561843200 |
| position_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561844224 |
| past_key_values.0.key | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 |
| past_key_values.0.value | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 |
| ... | ... | ... | ... | ... | ... |

**@kunal-vaishnavi** merged commit 08eaa1c into microsoft:main on Nov 3, 2023 (84 of 86 checks passed).
tianleiwu pushed a commit that referenced this pull request Nov 7, 2023
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024