diff --git a/onnxscript/tools/benchmark/export_model_test.py b/onnxscript/tools/benchmark/export_model_test.py
index c8a2dc229..dc124cb99 100644
--- a/onnxscript/tools/benchmark/export_model_test.py
+++ b/onnxscript/tools/benchmark/export_model_test.py
@@ -37,6 +37,34 @@ def test_export_model_phi_cpu_eager(self):
         out = f.getvalue()
         self.assertIn(":repeat_time,", out)
 
+    @unittest.skipIf(not has_transformers(), reason="transformers missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    @unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
+    def test_export_model_mistral_cpu_dynamo_llama0(self):
+        """Runs the benchmark for mistral (medium, cpu, dynamo, llama0 optimizations)."""
+        args = [
+            "--verbose",
+            "1",
+            "--config",
+            "medium",
+            "--dtype",
+            "float32",
+            "--device",
+            "cpu",
+            "--exporter",
+            "dynamo",
+            "--optimization",
+            "rewrite,optimize,inline,llama0",
+            "--model",
+            "mistral",
+        ]
+        f = io.StringIO()
+        with contextlib.redirect_stdout(f):
+            onnxscript.tools.benchmark.export_model.main(args)
+
+        out = f.getvalue()
+        self.assertIn(":repeat_time,", out)
+
     @unittest.skipIf(not has_transformers(), reason="transformers missing")
     def test_export_model_llama_cpu_eager(self):
         args = [
diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py
index 1340d544b..330131605 100644
--- a/onnxscript/tools/transformers_models/__init__.py
+++ b/onnxscript/tools/transformers_models/__init__.py
@@ -114,6 +114,19 @@ def get_model_and_inputs(
             config=config,
         )
 
+    elif model == "mistral":
+        import onnxscript.tools.transformers_models.mistral as m_mistral
+
+        tmodel, inputs, dynamic_shapes_def = m_mistral.get_mistral_model_from_config(
+            warmup=warmup,
+            repeat=repeat,
+            implementation=implementation,
+            with_mask=with_mask,
+            num_hidden_layers=num_hidden_layers,
+            dynamic_shapes=dynamic_shapes,
+            config=config,
+        )
+
     elif model == "phi":
         import onnxscript.tools.transformers_models.phi as m_phi
 
diff --git a/onnxscript/tools/transformers_models/mistral.py b/onnxscript/tools/transformers_models/mistral.py
new file mode 100644
index 000000000..9b4992794
--- /dev/null
+++ b/onnxscript/tools/transformers_models/mistral.py
@@ -0,0 +1,253 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=import-outside-toplevel
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+import torch
+
+import onnxscript.tools.transformers_models
+
+
+def _prepare_config_and_inputs(
+    batch_size: int,
+    seq_length: int,
+    vocab_size: int,
+    type_sequence_label_size: int = 2,
+    type_vocab_size: int = 16,
+    num_labels: int = 3,
+    num_choices: int = 4,
+    use_input_mask: bool = False,
+    use_token_type_ids: bool = False,
+    use_labels: bool = False,
+) -> tuple[Any, ...]:
+    """
+    Builds the dummy tensors used as example inputs.
+
+    Every tensor except the mask comes from ``ids_tensor``; a tensor that is
+    not requested through its ``use_*`` flag is returned as ``None``.
+
+    Returns:
+        ``(input_ids, token_type_ids, input_mask, sequence_labels,
+        token_labels, choice_labels)``
+    """
+    input_ids = onnxscript.tools.transformers_models.ids_tensor(
+        [batch_size, seq_length], vocab_size
+    )
+
+    input_mask = None
+    if use_input_mask:
+        input_mask = torch.tril(torch.ones(batch_size, seq_length))
+
+    token_type_ids = None
+    if use_token_type_ids:
+        assert type_vocab_size > 0, "type_vocab_size is null"
+        token_type_ids = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size, seq_length], type_vocab_size
+        )
+
+    sequence_labels = None
+    token_labels = None
+    choice_labels = None
+    if use_labels:
+        assert type_sequence_label_size > 0, "type_sequence_label_size is null"
+        assert num_labels > 0, "num_labels is null"
+        assert num_choices > 0, "num_choices is null"
+        sequence_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size], type_sequence_label_size
+        )
+        token_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size, seq_length], num_labels
+        )
+        choice_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size], num_choices
+        )
+
+    return (
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+    )
+
+
+def get_mistral_model(
+    input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
+    hidden_size=32,
+    num_hidden_layers=2,
+    vocab_size=99,
+    intermediate_size=16,
+    max_position_embeddings=512,
+    num_attention_heads=2,
+    num_key_value_heads=2,
+    sliding_window=4096,
+    _attn_implementation="eager",  # needed value to remove graph breaks
+    with_mask: bool = True,
+) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
+    """
+    Returns a Mistral model, its example inputs and their dynamic shapes.
+    See ``MistralConfig``:
+    https://huggingface.co/docs/transformers/main/en/model_doc/mistral#transformers.MistralConfig
+    The parameters are chosen for a unit test configuration.
+
+    Args:
+        input_dims: ``(batch, sequence length)`` of every example input
+        with_mask: the model takes ``(input_ids, attention_mask)``
+            instead of ``(input_ids,)``
+
+    Returns:
+        The wrapped model, one tuple of inputs per entry of *input_dims*
+        and the dynamic shapes of these inputs.
+    """
+    from transformers import MistralConfig
+    from transformers.models.mistral.modeling_mistral import MistralModel
+
+    config = MistralConfig(
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        max_position_embeddings=max_position_embeddings,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=num_key_value_heads,
+        sliding_window=sliding_window,
+    )
+
+    dynamic_shapes = {0: {0: "batch", 1: "length"}}
+    if with_mask:
+        dynamic_shapes.update({1: {0: "batch", 1: "length"}})
+
+    if _attn_implementation:
+        config._attn_implementation = _attn_implementation
+
+    def generate_example_inputs(batch: int, seq: int, vocab_size: int, with_mask: bool):
+        (
+            input_ids,
+            _,  # token_type_ids,
+            input_mask,
+            _,  # sequence_labels,
+            _,  # token_labels,
+            _,  # choice_labels,
+        ) = _prepare_config_and_inputs(
+            batch_size=batch,
+            seq_length=seq,
+            vocab_size=vocab_size,
+            use_input_mask=with_mask,
+        )
+        if with_mask:
+            return input_ids, input_mask
+        return (input_ids,)
+
+    # The example inputs are built the same way for both wrappers: do it once.
+    example_args_collection = [
+        generate_example_inputs(b, s, vocab_size, with_mask) for b, s in input_dims
+    ]
+
+    if with_mask:
+
+        class MistralModelWrapperWithMask(torch.nn.Module):
+            def __init__(self, config):
+                super().__init__()
+                self.model = MistralModel(config)
+
+            def forward(self, input_ids, attention_mask):
+                model_output = self.model(input_ids, attention_mask=attention_mask)
+                return model_output.to_tuple()
+
+        return MistralModelWrapperWithMask(config), example_args_collection, dynamic_shapes
+
+    class MistralModelWrapper(torch.nn.Module):
+        def __init__(self, config):
+            super().__init__()
+            self.model = MistralModel(config)
+
+        def forward(self, input_ids):
+            model_output = self.model(input_ids)
+            return model_output.to_tuple()
+
+    return MistralModelWrapper(config), example_args_collection, dynamic_shapes
+
+
+def get_mistral_model_from_config(
+    warmup: int = 5,
+    repeat: int = 10,
+    config: str = "small",
+    num_hidden_layers: int = 1,
+    implementation: str = "eager",
+    dynamic_shapes: bool = False,
+    with_mask: bool = True,
+) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
+    """
+    Returns a Mistral model to test or benchmark.
+
+    Args:
+        warmup: Number of inputs to generate.
+        repeat: Number of inputs to generate for repeat.
+        config: small, medium, large or default (same as large)
+        num_hidden_layers: number of hidden layers
+        implementation: eager or sdpa
+        with_mask: One or two inputs.
+        dynamic_shapes: dynamic shapes or not
+
+    Returns:
+        Model, list of inputs and dynamic shapes.
+
+    Raises:
+        ValueError: if *config* is not one of the configurations above.
+    """
+    if config == "small":
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=32,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=99,
+            intermediate_size=16,
+            max_position_embeddings=512,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    elif config == "medium":
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=1024,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=1024,
+            intermediate_size=1024,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            max_position_embeddings=1024,
+            sliding_window=4096,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    elif config in ("large", "default"):
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=4096,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=32000,
+            intermediate_size=14336,
+            num_attention_heads=32,
+            num_key_value_heads=8,
+            max_position_embeddings=131072,
+            sliding_window=4096,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    else:
+        raise ValueError(f"Unexpected configuration {config!r}.")
+
+    return get_mistral_model(**conf_dict)  # type: ignore[arg-type]
diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py
new file mode 100644
index 000000000..a919f3c69
--- /dev/null
+++ b/onnxscript/tools/transformers_models/mistral_test.py
@@ -0,0 +1,89 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=not-callable
+
+import copy
+import sys
+import unittest
+
+import numpy as np
+import onnxruntime
+import torch
+
+import onnxscript.optimizer
+import onnxscript.rewriter
+import onnxscript.tools.training_helper
+import onnxscript.tools.transformers_models
+import onnxscript.tools.transformers_models.mistral
+from onnxscript._internal.version_utils import has_transformers, torch_older_than
+
+
+class TestExportMistral(unittest.TestCase):
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    def test_mistral_export_cpu(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.mistral.get_mistral_model()
+        )
+        input_tensors = input_tensors_many[0]
+        expected = model(*input_tensors)
+        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
+        names = [i.name for i in proto.graph.input]
+        np_input_tensors = [x.numpy() for x in input_tensors]
+        feeds = dict(zip(names, np_input_tensors))
+        sess = onnxruntime.InferenceSession(
+            proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        results = sess.run(None, feeds)
+        np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5)
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    def test_mistral_export_cuda(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.mistral.get_mistral_model()
+        )
+        input_tensors_cpu = input_tensors_many[0]
+        model = model.to("cuda")
+        input_tensors = [i.to("cuda") for i in input_tensors_cpu]
+        expected = model(*input_tensors)
+        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
+        names = [i.name for i in proto.graph.input]
+        np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors]
+        feeds = dict(zip(names, np_input_tensors))
+        sess = onnxruntime.InferenceSession(
+            proto.SerializeToString(), providers=["CUDAExecutionProvider"]
+        )
+        results = sess.run(None, feeds)
+        np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5)
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    def test_mistral_dort_static(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.mistral.get_mistral_model()
+        )
+        input_tensors = input_tensors_many[0]
+        expected = model(*input_tensors)
+
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        results = compiled_model(*input_tensors)
+        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+
+        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
+        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
+        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)