Skip to content

Commit

Permalink
Remove internal enforce for IO binding inputs (#18266)
Browse files Browse the repository at this point in the history
### Description
This PR removes an internal `ORT_ENFORCE` when binding `torch.tensor`
inputs using IO binding for end-to-end scripts.



### Motivation and Context
In merged exports of PyTorch models to ONNX, each past key and past
value in the past KV cache has an input shape of `(batch_size,
num_heads, past_sequence_length, head_size)`. In the first pass through
the model to process the prompt, `past_sequence_length = 0`. Therefore,
each of these inputs is of shape `(batch_size, num_heads, 0,
head_size)`. In subsequent passes, `past_sequence_length > 0`.

When binding a `torch.tensor` of shape `(batch_size, num_heads, 0,
head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must
be passed. For a `torch.tensor` of this shape, its `data_ptr()` returns
0. Because it returns 0, the existing `ORT_ENFORCE` is therefore false
and an error is raised. By removing the internal `ORT_ENFORCE`, no error
is raised and the model runs successfully.

LLaMA-2 Example:
Input Name | Input Size | Device | Device ID | Torch Dtype | data_ptr()
------------- | ----------- | ------- | ----------- | ------------- |
-----------
input_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 |
140639561842688
attention_mask | torch.Size([1, 11]) | cuda | 7 | torch.int64 |
140639561843200
position_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 |
140639561844224
past_key_values.0.key | torch.Size([1, 32, 0, 128]) | cuda | 7 |
torch.float32 | 0
past_key_values.0.value | torch.Size([1, 32, 0, 128]) | cuda | 7 |
torch.float32 | 0
... | ... | ... | ... | ... | ...
  • Loading branch information
kunal-vaishnavi authored and tianleiwu committed Nov 6, 2023
1 parent ad7cecb commit 726e175
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) {
})
// This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid");

PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
throw std::runtime_error("Not a valid numpy type");
Expand Down

0 comments on commit 726e175

Please sign in to comment.