Create token type ids when not provided (#2081)
* create token type ids when needed

* add test
echarlaix authored Oct 29, 2024
1 parent 2e637be commit 4a39ae0
Showing 2 changed files with 30 additions and 1 deletion.
19 changes: 18 additions & 1 deletion optimum/onnxruntime/modeling_ort.py
@@ -931,7 +931,6 @@ def _prepare_onnx_inputs(
         self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
     ) -> Dict[str, np.ndarray]:
         onnx_inputs = {}
-
         # converts pytorch inputs into numpy inputs for onnx
         for input_name in self.input_names.keys():
             onnx_inputs[input_name] = inputs.pop(input_name)
@@ -1086,6 +1085,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1241,6 +1243,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1330,6 +1335,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1437,6 +1445,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1527,6 +1538,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
@@ -1610,6 +1624,9 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
 
+        if token_type_ids is None and "token_type_ids" in self.input_names:
+            token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
+
         if self.device.type == "cuda" and self.use_io_binding:
             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 input_ids,
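
Note: the guard added to each forward method is identical. If the exported ONNX graph lists token_type_ids among its inputs but the caller passed none, an all-zeros tensor (or NumPy array) shaped like input_ids is substituted. A minimal standalone sketch of that fallback follows; the helper name default_token_type_ids is hypothetical and not part of this diff.

import numpy as np
import torch

def default_token_type_ids(input_ids, token_type_ids, input_names):
    # Mirror of the added guard: only fill in zeros when the ONNX model
    # actually declares a "token_type_ids" input and none was provided.
    use_torch = isinstance(input_ids, torch.Tensor)
    if token_type_ids is None and "token_type_ids" in input_names:
        token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)
    return token_type_ids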
12 changes: 12 additions & 0 deletions tests/onnxruntime/test_modeling.py
@@ -2192,6 +2192,18 @@ def test_compare_to_io_binding(self, model_arch):
 
         gc.collect()
 
+    def test_default_token_type_ids(self):
+        model_id = MODEL_NAMES["bert"]
+        model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer("this is a simple input", return_tensors="np")
+        self.assertTrue("token_type_ids" in model.input_names)
+        token_type_ids = tokens.pop("token_type_ids")
+        outs = model(token_type_ids=token_type_ids, **tokens)
+        outs_without_token_type_ids = model(**tokens)
+        self.assertTrue(np.allclose(outs.last_hidden_state, outs_without_token_type_ids.last_hidden_state))
+        gc.collect()
+
 
 class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
     # Multiple Choice tests are conducted on different models due to mismatch size in model's classifier
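
Note: the new test checks that omitting token_type_ids gives the same hidden states as passing explicit zeros. A usage sketch along the same lines, assuming a BERT checkpoint that declares token_type_ids among its ONNX inputs (the checkpoint name below is an assumption, not taken from the diff):

from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

model = ORTModelForFeatureExtraction.from_pretrained("bert-base-uncased", export=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("this is a simple input", return_tensors="np")
inputs.pop("token_type_ids")  # omit the segment ids on purpose
outputs = model(**inputs)     # zeros are filled in for token_type_ids automatically
print(outputs.last_hidden_state.shape)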
