Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove internal enforce for IO binding inputs (#18266)
### Description This PR removes an internal `ORT_ENFORCE` when binding `torch.tensor` inputs using IO binding for end-to-end scripts. ### Motivation and Context In merged exports of PyTorch models to ONNX, each past key and past value in the past KV cache has an input shape of `(batch_size, num_heads, past_sequence_length, head_size)`. In the first pass through the model to process the prompt, `past_sequence_length = 0`. Therefore, each of these inputs is of shape `(batch_size, num_heads, 0, head_size)`. In subsequent passes, `past_sequence_length > 0`. When binding a `torch.tensor` of shape `(batch_size, num_heads, 0, head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must be passed. For a `torch.tensor` of this shape, its `data_ptr()` returns 0. Because it returns 0, the existing `ORT_ENFORCE` is therefore false and an error is raised. By removing the internal `ORT_ENFORCE`, no error is raised and the model runs successfully. LLaMA-2 Example: Input Name | Input Size | Device | Device ID | Torch Dtype | data_ptr() ------------- | ----------- | ------- | ----------- | ------------- | ----------- input_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561842688 attention_mask | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561843200 position_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561844224 past_key_values.0.key | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 past_key_values.0.value | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 ... | ... | ... | ... | ... | ...
- Loading branch information