Remove internal enforce for IO binding inputs #18266
Merged
Description
This PR removes an internal `ORT_ENFORCE` when binding `torch.tensor` inputs using IO binding for end-to-end scripts.

Motivation and Context
In merged exports of PyTorch models to ONNX, each past key and past value in the past KV cache has an input shape of `(batch_size, num_heads, past_sequence_length, head_size)`. In the first pass through the model to process the prompt, `past_sequence_length = 0`, so each of these inputs has shape `(batch_size, num_heads, 0, head_size)`. In subsequent passes, `past_sequence_length > 0`.

When binding a `torch.tensor` of shape `(batch_size, num_heads, 0, head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must be passed. For a `torch.tensor` of this shape, `data_ptr()` returns 0. Because the pointer is 0, the condition checked by the existing `ORT_ENFORCE` evaluates to false and an error is raised. By removing the internal `ORT_ENFORCE`, no error is raised and the model runs successfully.

LLaMA-2 Example: