diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 4e918552b..3ff22e1c7 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -99,12 +99,6 @@ jobs:
         with:
           name: Error reports (${{ matrix.name }}-${{ matrix.os }})
           path: error_reports
-      - name: Upload IR profiling results
-        if: matrix.name == 'py311' || matrix.name == 'py311-onnx-weekly'
-        uses: actions/upload-artifact@v3
-        with:
-          name: IR profiling results
-          path: tests/ir/serde_test_profiles
 
   dort:
     strategy:
diff --git a/docs/api/tools.md b/docs/api/tools.md
index 459e6ac54..9f565d613 100644
--- a/docs/api/tools.md
+++ b/docs/api/tools.md
@@ -10,6 +10,10 @@
 .. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_from_config
 ```
 
+```{eval-rst}
+.. autofunction:: onnxscript.tools.transformers_models.phi3.get_phi3_model_from_config
+```
+
 ```{eval-rst}
 .. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_from_config
 ```
diff --git a/onnxscript/tools/benchmark/export_model_test.py b/onnxscript/tools/benchmark/export_model_test.py
index b68550227..6806e3135 100644
--- a/onnxscript/tools/benchmark/export_model_test.py
+++ b/onnxscript/tools/benchmark/export_model_test.py
@@ -6,12 +6,15 @@
 import unittest
 
 import onnxscript.tools.benchmark.export_model
+import onnxscript.tools.transformers_models.phi3
 from onnxscript._internal.version_utils import (
     has_transformers,
     is_onnxruntime_training,
     torch_older_than,
 )
 
+has_phi3 = onnxscript.tools.transformers_models.phi3.has_phi3
+
 
 class BenchmarkTest(unittest.TestCase):
     @unittest.skipIf(not has_transformers(), reason="transformers missing")
@@ -140,6 +143,36 @@ def test_export_model_phi_cpu_dynamo_llama0(self):
         out = f.getvalue()
         self.assertIn(":repeat_time,", out)
 
+    @unittest.skipIf(not has_transformers(), reason="transformers missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="Fails to export with torch<2.4")
+    @unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
+    @unittest.skipIf(
+        not has_phi3(), reason="transformers is not recent enough to contain the phi3 model"
+    )
+    def test_export_model_phi3_cpu_dynamo_llama0(self):
+        args = [
+            "--verbose",
+            "1",
+            "--config",
+            "medium",
+            "--dtype",
+            "float32",
+            "--device",
+            "cpu",
+            "--exporter",
+            "dynamo",
+            "--optimization",
+            "rewrite,optimize,inline,llama0",
+            "--model",
+            "phi3",
+        ]
+        f = io.StringIO()
+        with contextlib.redirect_stdout(f):
+            onnxscript.tools.benchmark.export_model.main(args)
+
+        out = f.getvalue()
+        self.assertIn(":repeat_time,", out)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py
index 1340d544b..ca9a77a3c 100644
--- a/onnxscript/tools/transformers_models/__init__.py
+++ b/onnxscript/tools/transformers_models/__init__.py
@@ -87,7 +87,7 @@ def get_model_and_inputs(
     Returns a model and a couple of dummy inputs.
 
     Args:
-        model: model name, 'phi', 'llama', ...
+        model: model name, 'phi', 'llama', 'phi3', ...
         config: 'small', 'medium', 'large', ...
        dynamic_shapes: dynamic or static shapes
         device: 'cpu' or 'cuda'
@@ -127,6 +127,19 @@ def get_model_and_inputs(
             config=config,
         )
 
+    elif model == "phi3":
+        import onnxscript.tools.transformers_models.phi3 as m_phi3
+
+        tmodel, inputs, dynamic_shapes_def = m_phi3.get_phi3_model_from_config(
+            warmup=warmup,
+            repeat=repeat,
+            implementation=implementation,
+            with_mask=with_mask,
+            num_hidden_layers=num_hidden_layers,
+            dynamic_shapes=dynamic_shapes,
+            config=config,
+        )
+
     else:
         raise ValueError(f"Model {model!r} is unknown.")
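For orientation, a minimal usage sketch of the dispatcher with the new branch. This assumes the `model`, `config`, `dynamic_shapes`, and `device` keywords documented in the docstring above, and that `get_model_and_inputs` returns the `(tmodel, inputs, dynamic_shapes_def)` triple each branch constructs; neither is shown in full in this patch:

```python
import onnxscript.tools.transformers_models as transformers_models

# Builds the wrapped torch module and example inputs; the third element
# describes which input axes are dynamic when dynamic_shapes=True.
model, example_inputs, dynamic_shapes = transformers_models.get_model_and_inputs(
    model="phi3",
    config="small",
    dynamic_shapes=False,
    device="cpu",
)
```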
diff --git a/onnxscript/tools/transformers_models/phi3.py b/onnxscript/tools/transformers_models/phi3.py
new file mode 100644
index 000000000..ad8be3eeb
--- /dev/null
+++ b/onnxscript/tools/transformers_models/phi3.py
@@ -0,0 +1,257 @@
+# Copyright (c) Microsoft Corporation
+# Licensed under the MIT License.
+# pylint: disable=import-outside-toplevel
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+import torch
+
+import onnxscript.tools.transformers_models
+
+
+def has_phi3() -> bool:
+    """Tells whether the installed *transformers* package contains the Phi3 model."""
+    try:
+        from transformers import Phi3Config
+
+        assert Phi3Config  # mark the import as used
+    except ImportError:
+        return False
+    return True
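+
+# Illustrative guard (an editorial sketch, not part of the upstream API
+# surface): callers are expected to check has_phi3() before anything
+# Phi3-specific, so older transformers releases skip instead of raising
+# ImportError, e.g.
+#
+#     if has_phi3():
+#         model, inputs, dyn = get_phi3_model_from_config(config="small")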
+
+
+def _prepare_config_and_inputs(
+    batch_size: int,
+    seq_length: int,
+    vocab_size: int,
+    type_sequence_label_size: int = 2,
+    type_vocab_size: int = 16,
+    num_labels: int = 3,
+    num_choices: int = 4,
+    use_input_mask: bool = False,
+    use_token_type_ids: bool = False,
+    use_labels: bool = False,
+) -> tuple[Any, ...]:
+    input_ids = onnxscript.tools.transformers_models.ids_tensor(
+        [batch_size, seq_length], vocab_size
+    )
+
+    input_mask = None
+    if use_input_mask:
+        input_mask = torch.tril(torch.ones(batch_size, seq_length))
+
+    token_type_ids = None
+    if use_token_type_ids:
+        assert type_vocab_size > 0, "type_vocab_size is null"
+        token_type_ids = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size, seq_length], type_vocab_size
+        )
+
+    sequence_labels = None
+    token_labels = None
+    choice_labels = None
+    if use_labels:
+        assert type_sequence_label_size > 0, "type_sequence_label_size is null"
+        assert num_labels > 0, "num_labels is null"
+        assert num_choices > 0, "num_choices is null"
+        sequence_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size], type_sequence_label_size
+        )
+        token_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size, seq_length], num_labels
+        )
+        choice_labels = onnxscript.tools.transformers_models.ids_tensor(
+            [batch_size], num_choices
+        )
+
+    return (
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+    )
+
+
+def get_phi3_model(
+    input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
+    hidden_size: int = 32,
+    num_hidden_layers: int = 2,
+    vocab_size: int = 99,
+    intermediate_size: int = 16,
+    max_position_embeddings: int = 512,
+    num_attention_heads: int = 4,
+    num_key_value_heads: int = 2,
+    _attn_implementation: str = "eager",  # needed to avoid graph breaks
+    with_mask: bool = True,
+) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
+    """
+    Returns a model.
+    See the ``Phi3Config`` class in *transformers*.
+    The parameters are chosen for a unit test configuration, following the
+    pattern of ``test_modeling_phi.py`` in the *transformers* test suite.
+    """
+    from transformers import Phi3Config, Phi3Model
+
+    dynamic_shapes = {0: {0: "batch", 1: "length"}}
+    if with_mask:
+        dynamic_shapes.update({1: {0: "batch", 1: "length"}})
+
+    config = Phi3Config(
+        hidden_size=hidden_size,
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size,
+        intermediate_size=intermediate_size,
+        max_position_embeddings=max_position_embeddings,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=num_key_value_heads,
+        pad_token_id=min(32000, vocab_size - 1),
+    )
+    if _attn_implementation:
+        config._attn_implementation = _attn_implementation  # pylint: disable=protected-access
+
+    if with_mask:
+
+        class Phi3ModelWrapperWithMask(torch.nn.Module):
+            def __init__(self, config):
+                super().__init__()
+                self.model = Phi3Model(config)
+
+            def forward(self, input_ids, attention_mask):
+                model_output = self.model(input_ids, attention_mask=attention_mask)
+                return model_output.to_tuple()
+
+        def generate_example_inputs_with_mask(batch: int, seq: int, vocab_size: int):
+            (
+                input_ids,
+                _,  # token_type_ids
+                input_mask,
+                _,  # sequence_labels
+                _,  # token_labels
+                _,  # choice_labels
+            ) = _prepare_config_and_inputs(
+                batch_size=batch,
+                seq_length=seq,
+                vocab_size=vocab_size,
+                use_input_mask=True,
+            )
+            return input_ids, input_mask
+
+        example_args_collection = []
+        for b, s in input_dims:
+            example_args_collection.append(
+                generate_example_inputs_with_mask(b, s, vocab_size)
+            )
+
+        return Phi3ModelWrapperWithMask(config), example_args_collection, dynamic_shapes
+
+    # no mask
+
+    class Phi3ModelWrapper(torch.nn.Module):
+        def __init__(self, config):
+            super().__init__()
+            self.model = Phi3Model(config)
+
+        def forward(self, input_ids):
+            model_output = self.model(input_ids)
+            return model_output.to_tuple()
+
+    def generate_example_inputs(batch: int, seq: int, vocab_size: int):
+        (
+            input_ids,
+            *_,  # token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ) = _prepare_config_and_inputs(
+            batch_size=batch,
+            seq_length=seq,
+            vocab_size=vocab_size,
+            use_input_mask=False,  # the mask is not an input in this variant
+        )
+        return (input_ids,)
+
+    example_args_collection = []
+    for b, s in input_dims:
+        example_args_collection.append(generate_example_inputs(b, s, vocab_size))
+
+    return Phi3ModelWrapper(config), example_args_collection, dynamic_shapes
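+
+# How the pieces fit together (illustrative comment, not executed code): the
+# triple returned above feeds the exporter, and `dynamic_shapes` maps input
+# position -> {axis index: symbolic dimension name}, e.g.
+# {0: {0: "batch", 1: "length"}} marks both axes of input 0 as dynamic.
+#
+#     model, args_collection, dyn = get_phi3_model(with_mask=True)
+#     input_ids, attention_mask = args_collection[0]
+#     outputs = model(input_ids, attention_mask)  # a tuple, via to_tuple()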
+
+
+def get_phi3_model_from_config(
+    warmup: int = 5,
+    repeat: int = 10,
+    config: str = "small",
+    num_hidden_layers: int = 1,
+    implementation: str = "eager",
+    dynamic_shapes: bool = False,
+    with_mask: bool = True,
+) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
+    """
+    Returns a Phi3 model to test or benchmark.
+
+    Args:
+        warmup: number of inputs to generate for warmup
+        repeat: number of inputs to generate for the repeated runs
+        config: 'small', 'medium' or 'large'
+        num_hidden_layers: number of hidden layers
+        implementation: 'eager' or 'sdpa'
+        dynamic_shapes: dynamic shapes or not
+        with_mask: if True, the model takes two inputs (input_ids and
+            attention_mask); otherwise only input_ids
+
+    Returns:
+        Model, list of example input tuples, and dynamic shape definition.
+    """
+    if config == "small":
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=32,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=99,
+            intermediate_size=16,
+            max_position_embeddings=512,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    elif config == "medium":
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=1024,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=1024,
+            intermediate_size=1024,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            max_position_embeddings=1024,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    elif config in ("large", "default"):
+        conf_dict = dict(
+            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
+                dynamic_shapes, warmup, repeat
+            ),
+            hidden_size=2048,
+            num_hidden_layers=num_hidden_layers,
+            vocab_size=51200,
+            intermediate_size=8192,
+            num_attention_heads=32,
+            num_key_value_heads=None,
+            max_position_embeddings=2048,
+            _attn_implementation=implementation,
+            with_mask=with_mask,
+        )
+    else:
+        raise ValueError(f"Unexpected configuration {config!r}.")
+
+    return get_phi3_model(**conf_dict)  # type: ignore[arg-type]
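The presets above only vary the width-related hyperparameters. A hedged sketch of calling the new entry point directly; all names come from the signatures in this file, and the values shown are its documented defaults:

```python
from onnxscript.tools.transformers_models.phi3 import (
    get_phi3_model_from_config,
    has_phi3,
)

if has_phi3():
    # "small" (hidden_size=32, vocab_size=99) keeps the graph tiny, which is
    # enough to exercise the export path without long runtimes.
    model, example_inputs, dynamic_shapes = get_phi3_model_from_config(
        config="small",
        num_hidden_layers=1,
        with_mask=True,
    )
    input_ids, attention_mask = example_inputs[0]
    outputs = model(input_ids, attention_mask)
```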
diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py
new file mode 100644
index 000000000..62bb6faf8
--- /dev/null
+++ b/onnxscript/tools/transformers_models/phi3_test.py
@@ -0,0 +1,98 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=not-callable
+
+import copy
+import sys
+import unittest
+
+import numpy as np
+import onnxruntime
+import torch
+
+import onnxscript.optimizer
+import onnxscript.rewriter
+import onnxscript.tools.training_helper
+import onnxscript.tools.transformers_models
+import onnxscript.tools.transformers_models.phi3
+from onnxscript._internal.version_utils import has_transformers, torch_older_than
+
+has_phi3 = onnxscript.tools.transformers_models.phi3.has_phi3
+
+
+class TestExportPhi3(unittest.TestCase):
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    def test_phi3_export_cpu(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.phi3.get_phi3_model()
+        )
+        input_tensors = input_tensors_many[0]
+        expected = model(*input_tensors)
+        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
+        names = [i.name for i in proto.graph.input]
+        np_input_tensors = [x.numpy() for x in input_tensors]
+        feeds = dict(zip(names, np_input_tensors))
+        sess = onnxruntime.InferenceSession(
+            proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        results = sess.run(None, feeds)
+        np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5)
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough")
+    def test_phi3_export_cuda(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.phi3.get_phi3_model()
+        )
+        input_tensors_cpu = input_tensors_many[0]
+        model = model.to("cuda")
+        input_tensors = [i.to("cuda") for i in input_tensors_cpu]
+        expected = model(*input_tensors)
+        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
+        names = [i.name for i in proto.graph.input]
+        np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors]
+        feeds = dict(zip(names, np_input_tensors))
+        sess = onnxruntime.InferenceSession(
+            proto.SerializeToString(), providers=["CUDAExecutionProvider"]
+        )
+        results = sess.run(None, feeds)
+        np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5)
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough")
+    @unittest.skipIf(
+        True,
+        reason="disabled: without the flash-attention implementation the results differ numerically",
+    )
+    def test_phi3_dort_static(self):
+        model, input_tensors_many, _ = (
+            onnxscript.tools.transformers_models.phi3.get_phi3_model()
+        )
+        input_tensors = input_tensors_many[0]
+        expected = model(*input_tensors)
+
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        results = compiled_model(*input_tensors)
+        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+
+        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
+        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
+        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
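To run just the new test module, a sketch using only the standard `unittest` API; it assumes the package and the optional dependencies guarded by the decorators above are importable:

```python
# Loads and runs only the new Phi3 test module, matching the verbosity the
# file itself uses under __main__.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "onnxscript.tools.transformers_models.phi3_test"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```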