From 4a39ae0b1de05601c8a33f4a13c244bdd016db24 Mon Sep 17 00:00:00 2001
From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Date: Tue, 29 Oct 2024 10:20:31 +0100
Subject: [PATCH] Create token type ids when not provided (#2081)

* create token type ids when needed

* add test
---
 optimum/onnxruntime/modeling_ort.py | 19 ++++++++++++++++++-
 tests/onnxruntime/test_modeling.py  | 12 ++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index ce1d68536a..8e5a814b68 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -931,7 +931,6 @@ def _prepare_onnx_inputs(
         self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
     ) -> Dict[str, np.ndarray]:
         onnx_inputs = {}
-        # converts pytorch inputs into numpy inputs for onnx
         for input_name in self.input_names.keys():
             onnx_inputs[input_name] = inputs.pop(input_name)
 
@@ -1086,6 +1085,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1241,6 +1243,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1330,6 +1335,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1437,6 +1445,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1527,6 +1538,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1610,6 +1624,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 33243da278..da450b8e31 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -2192,6 +2192,18 @@ def test_compare_to_io_binding(self, model_arch):
 
         gc.collect()
 
+    def test_default_token_type_ids(self):
+        model_id = MODEL_NAMES["bert"]
+        model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer("this is a simple input", return_tensors="np")
+        self.assertTrue("token_type_ids" in model.input_names)
+        token_type_ids = tokens.pop("token_type_ids")
+        outs = model(token_type_ids=token_type_ids, **tokens)
+        outs_without_token_type_ids = model(**tokens)
+        self.assertTrue(np.allclose(outs.last_hidden_state, outs_without_token_type_ids.last_hidden_state))
+        gc.collect()
+
 
 class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
     # Multiple Choice tests are conducted on different models due to mismatch size in model's classifier