Signed-off-by: Xavier Dupre <[email protected]>
@@ -0,0 +1,239 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
from __future__ import annotations

from typing import Any, Sequence

import torch

import onnxscript.tools.transformers_models


def _prepare_config_and_inputs(
    batch_size: int,
    seq_length: int,
    vocab_size: int,
    type_sequence_label_size: int = 2,
    type_vocab_size: int = 16,
    num_labels: int = 3,
    num_choices: int = 4,
    use_input_mask: bool = False,
    use_token_type_ids: bool = False,
    use_labels: bool = False,
) -> tuple[Any, ...]:
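    # Builds random input_ids plus optional attention mask, token type ids,
    # and label tensors, mirroring the fixtures used by the transformers
    # unit tests.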
    input_ids = onnxscript.tools.transformers_models.ids_tensor(
        [batch_size, seq_length], vocab_size
    )

    input_mask = None
    if use_input_mask:
        input_mask = torch.tril(torch.ones(batch_size, seq_length))

    token_type_ids = None
    if use_token_type_ids:
        assert type_vocab_size > 0, "type_vocab_size must be positive"
        token_type_ids = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size, seq_length], type_vocab_size
        )

    sequence_labels = None
    token_labels = None
    choice_labels = None
    if use_labels:
        assert type_sequence_label_size > 0, "type_sequence_label_size must be positive"
        assert num_labels > 0, "num_labels must be positive"
        assert num_choices > 0, "num_choices must be positive"
        sequence_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size], type_sequence_label_size
        )
        token_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size, seq_length], num_labels
        )
        choice_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size], num_choices
        )

    return (
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    )


def get_mistral_model(
    input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=99,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
    sliding_window=4096,
    _attn_implementation="eager",  # value needed to avoid graph breaks
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a Mistral model, example inputs, and dynamic shapes.
    See `MistralConfig
    <https://huggingface.co/docs/transformers/main/en/model_doc/mistral#transformers.MistralConfig>`_.
    The parameters are chosen for a unit test configuration.
    """
    from transformers import MistralConfig
    from transformers.models.mistral.modeling_mistral import MistralModel

    config = MistralConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        sliding_window=sliding_window,
    )

    # Input 0 (input_ids) and, when a mask is used, input 1 (attention_mask)
    # both have dynamic batch and sequence-length dimensions.
    dynamic_shapes = {0: {0: "batch", 1: "length"}}
    if with_mask:
        dynamic_shapes.update({1: {0: "batch", 1: "length"}})

    if _attn_implementation:
        config._attn_implementation = _attn_implementation  # pylint: disable=protected-access

    def generate_example_inputs(batch: int, seq: int, vocab_size: int, with_mask: bool):
        (
            input_ids,
            _,  # token_type_ids
            input_mask,
            _,  # sequence_labels
            _,  # token_labels
            _,  # choice_labels
        ) = _prepare_config_and_inputs(
            batch_size=batch,
            seq_length=seq,
            vocab_size=vocab_size,
            use_input_mask=with_mask,
        )
        if with_mask:
            return input_ids, input_mask
        return (input_ids,)

    if with_mask:

        class MistralModelWrapperWithMask(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.model = MistralModel(config)

            def forward(self, input_ids, attention_mask):
                model_output = self.model(input_ids, attention_mask=attention_mask)
                return model_output.to_tuple()

        example_args_collection = []
        for b, s in input_dims:
            example_args_collection.append(
                generate_example_inputs(b, s, vocab_size, with_mask)
            )

        return MistralModelWrapperWithMask(config), example_args_collection, dynamic_shapes

    class MistralModelWrapper(torch.nn.Module):
        def __init__(self, config):
            super().__init__()
            self.model = MistralModel(config)

        def forward(self, input_ids):
            model_output = self.model(input_ids)
            return model_output.to_tuple()

    example_args_collection = []
    for b, s in input_dims:
        example_args_collection.append(generate_example_inputs(b, s, vocab_size, with_mask))

    return MistralModelWrapper(config), example_args_collection, dynamic_shapes


def get_mistral_model_from_config(
    warmup: int = 5,
    repeat: int = 10,
    config: str = "small",
    num_hidden_layers: int = 1,
    implementation: str = "eager",
    dynamic_shapes: bool = False,
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a Mistral model to test or benchmark.

    Args:
        warmup: Number of inputs to generate.
        repeat: Number of inputs to generate for repeat.
        config: small, medium or large
        num_hidden_layers: number of hidden layers
        implementation: eager or sdpa
        dynamic_shapes: dynamic shapes or not
        with_mask: if True, the model takes input_ids and attention_mask,
            otherwise input_ids only
    Returns:
        Model and list of inputs.
    """
    if config == "small":
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=32,
            num_hidden_layers=num_hidden_layers,
            vocab_size=99,
            intermediate_size=16,
            max_position_embeddings=512,
            num_attention_heads=4,
            num_key_value_heads=2,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    elif config == "medium":
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=1024,
            num_hidden_layers=num_hidden_layers,
            vocab_size=1024,
            intermediate_size=1024,
            num_attention_heads=4,
            num_key_value_heads=4,
            max_position_embeddings=1024,
            sliding_window=4096,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    elif config in ("large", "default"):
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=4096,
            num_hidden_layers=num_hidden_layers,
            vocab_size=32000,
            intermediate_size=14336,
            num_attention_heads=32,
            num_key_value_heads=8,
            max_position_embeddings=131072,
            sliding_window=4096,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    else:
        raise ValueError(f"Unexpected configuration {config!r}.")

    return get_mistral_model(**conf_dict)  # type: ignore[arg-type]
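
For reference, a minimal usage sketch (not part of the commit) showing how these helpers combine with export_to_onnx and onnxruntime, both of which are exercised in the test file below; the "small" configuration keeps the model tiny:

    import onnxruntime
    import onnxscript.tools.transformers_models
    import onnxscript.tools.transformers_models.mistral

    # Build a tiny Mistral model with matching example inputs.
    model, inputs_many, _ = (
        onnxscript.tools.transformers_models.mistral.get_mistral_model_from_config(config="small")
    )
    inputs = inputs_many[0]

    # Export to ONNX and run the result with onnxruntime on CPU.
    proto = onnxscript.tools.transformers_models.export_to_onnx(model, *inputs)
    sess = onnxruntime.InferenceSession(
        proto.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    feeds = dict(zip([i.name for i in proto.graph.input], [x.numpy() for x in inputs]))
    results = sess.run(None, feeds)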

@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable

import copy
import sys
import unittest

import numpy as np
import onnxruntime
import torch

import onnxscript.optimizer
import onnxscript.rewriter
import onnxscript.tools.training_helper
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.mistral
from onnxscript._internal.version_utils import has_transformers, torch_older_than


class TestExportMistral(unittest.TestCase):
    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
    def test_mistral_export_cpu(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)
        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
        names = [i.name for i in proto.graph.input]
        np_input_tensors = [x.numpy() for x in input_tensors]
        feeds = dict(zip(names, np_input_tensors))
        sess = onnxruntime.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        results = sess.run(None, feeds)
        np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5)

    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    def test_mistral_export_cuda(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors_cpu = input_tensors_many[0]
        model = model.to("cuda")
        input_tensors = [i.to("cuda") for i in input_tensors_cpu]
        expected = model(*input_tensors)
        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
        names = [i.name for i in proto.graph.input]
        np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors]
        feeds = dict(zip(names, np_input_tensors))
        sess = onnxruntime.InferenceSession(
            proto.SerializeToString(), providers=["CUDAExecutionProvider"]
        )
        results = sess.run(None, feeds)
        np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5)

    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    def test_mistral_dort_static(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

        # Compile the model through the ONNX Runtime backend and check that
        # outputs and gradients match eager mode.
        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
            copy.deepcopy(model),
            backend=local_aot_ort,
            dynamic=False,
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    unittest.main(verbosity=2)
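
Since the file defines a __main__ block, the suite can be run directly with Python (the file's actual path is not shown in the commit, so the name below is assumed):

    python onnxscript/tools/transformers_models/mistral_test.py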