From 726e1758740a19cd3350ef73a2a14a9ee6fc71c5 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:12:32 -0700 Subject: [PATCH] Remove internal enforce for IO binding inputs (#18266) ### Description This PR removes an internal `ORT_ENFORCE` when binding `torch.tensor` inputs using IO binding for end-to-end scripts. ### Motivation and Context In merged exports of PyTorch models to ONNX, each past key and past value in the past KV cache has an input shape of `(batch_size, num_heads, past_sequence_length, head_size)`. In the first pass through the model to process the prompt, `past_sequence_length = 0`. Therefore, each of these inputs is of shape `(batch_size, num_heads, 0, head_size)`. In subsequent passes, `past_sequence_length > 0`. When binding a `torch.tensor` of shape `(batch_size, num_heads, 0, head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must be passed. For a `torch.tensor` of this shape, its `data_ptr()` returns 0. Because it returns 0, the existing `ORT_ENFORCE` is therefore false and an error is raised. By removing the internal `ORT_ENFORCE`, no error is raised and the model runs successfully. LLaMA-2 Example: Input Name | Input Size | Device | Device ID | Torch Dtype | data_ptr() ------------- | ----------- | ------- | ----------- | ------------- | ----------- input_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561842688 attention_mask | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561843200 position_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561844224 past_key_values.0.key | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 past_key_values.0.value | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 ... | ... | ... | ... | ... | ... --- onnxruntime/python/onnxruntime_pybind_iobinding.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index 7638a12bb820c..59d5a77bfbea3 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) { }) // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector& shape, int64_t data_ptr) -> void { - ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid"); - PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { throw std::runtime_error("Not a valid numpy type");