Add Phi3 to the list of test models (#1624)
Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
xadupre and justinchuby authored Jun 18, 2024
1 parent 677ba7f commit 7f7fd74
Showing 6 changed files with 406 additions and 7 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/main.yaml
@@ -99,12 +99,6 @@ jobs:
with:
name: Error reports (${{ matrix.name }}-${{ matrix.os }})
path: error_reports
- name: Upload IR profiling results
if: matrix.name == 'py311' || matrix.name == 'py311-onnx-weekly'
uses: actions/upload-artifact@v3
with:
name: IR profiling results
path: tests/ir/serde_test_profiles

dort:
strategy:
4 changes: 4 additions & 0 deletions docs/api/tools.md
@@ -10,6 +10,10 @@
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_from_config
```

```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.phi3.get_phi3_model_from_config
```

```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_from_config
```
33 changes: 33 additions & 0 deletions onnxscript/tools/benchmark/export_model_test.py
@@ -6,12 +6,15 @@
import unittest

import onnxscript.tools.benchmark.export_model
import onnxscript.tools.transformers_models.phi3
from onnxscript._internal.version_utils import (
has_transformers,
is_onnxruntime_training,
torch_older_than,
)

has_phi3 = onnxscript.tools.transformers_models.phi3.has_phi3


class BenchmarkTest(unittest.TestCase):
@unittest.skipIf(not has_transformers(), reason="transformers missing")
@@ -140,6 +143,36 @@ def test_export_model_phi_cpu_dynamo_llama0(self):
out = f.getvalue()
self.assertIn(":repeat_time,", out)

@unittest.skipIf(not has_transformers(), reason="transformers missing")
@unittest.skipIf(torch_older_than("2.4"), reason="Fails to export with torch<2.4")
@unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
@unittest.skipIf(
not has_phi3(), reason="transformers is not recent enough to contain the phi3 model"
)
def test_export_model_phi3_cpu_dynamo_llama0(self):
args = [
"--verbose",
"1",
"--config",
"medium",
"--dtype",
"float32",
"--device",
"cpu",
"--exporter",
"dynamo",
"--optimization",
"rewrite,optimize,inline,llama0",
"--model",
"phi3",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
onnxscript.tools.benchmark.export_model.main(args)

out = f.getvalue()
self.assertIn(":repeat_time,", out)


if __name__ == "__main__":
unittest.main(verbosity=2)
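The new test can be exercised on its own by loading it through its dotted name. This is a hedged sketch, not part of the commit; it assumes the test module above is importable from a source checkout with the optional dependencies (transformers, onnxruntime-training) installed, otherwise the skip decorators take effect.

```python
# Sketch: run only the new Phi-3 benchmark test from a Python session.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "onnxscript.tools.benchmark.export_model_test"
    ".BenchmarkTest.test_export_model_phi3_cpu_dynamo_llama0"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```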
15 changes: 14 additions & 1 deletion onnxscript/tools/transformers_models/__init__.py
@@ -87,7 +87,7 @@ def get_model_and_inputs(
Returns a model and a couple of dummy inputs.
Args:
model: model name, 'phi', 'llama', ...
model: model name, 'phi', 'llama', 'phi3', ...
config: 'small', 'medium', 'large', ...
dynamic_shapes: dynamic or static shapes
device: 'cpu' or 'cuda'
@@ -127,6 +127,19 @@ def get_model_and_inputs(
config=config,
)

elif model == "phi3":
import onnxscript.tools.transformers_models.phi3 as m_phi3

tmodel, inputs, dynamic_shapes_def = m_phi3.get_phi3_model_from_config(
warmup=warmup,
repeat=repeat,
implementation=implementation,
with_mask=with_mask,
num_hidden_layers=num_hidden_layers,
dynamic_shapes=dynamic_shapes,
config=config,
)

else:
raise ValueError(f"Model {model!r} is unknown.")

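For context, the new branch is reached by passing `model="phi3"` to the shared dispatcher. The following is a minimal, hedged sketch (not part of the diff): it assumes the remaining parameters of `get_model_and_inputs` have defaults and that the function returns the model, the dummy inputs, and the dynamic-shape definitions in the same order in which the phi3 branch above unpacks them.

```python
# Hypothetical usage of the shared dispatcher with the new "phi3" model name.
import onnxscript.tools.transformers_models as tfm

model, example_inputs, dynamic_shapes = tfm.get_model_and_inputs(
    model="phi3",      # dispatches to get_phi3_model_from_config
    config="small",    # 'small', 'medium' or 'large'
    dynamic_shapes=False,
    device="cpu",
)
```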
257 changes: 257 additions & 0 deletions onnxscript/tools/transformers_models/phi3.py
@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.
# pylint: disable=import-outside-toplevel

from __future__ import annotations

from typing import Any, Sequence

import torch

import onnxscript.tools.transformers_models


def has_phi3() -> bool:
"""Tells if package *transformers* contains the phi3 model."""
try:
from transformers import Phi3Config

assert Phi3Config
except ImportError:
return False
return True


def _prepare_config_and_inputs(
batch_size: int,
seq_length: int,
vocab_size: int,
type_sequence_label_size: int = 2,
type_vocab_size: int = 16,
num_labels: int = 3,
num_choices: int = 4,
use_input_mask: bool = False,
use_token_type_ids: bool = False,
use_labels: bool = False,
) -> tuple[Any, ...]:
input_ids = onnxscript.tools.transformers_models.ids_tensor(
[batch_size, seq_length], vocab_size
)

input_mask = None
if use_input_mask:
input_mask = torch.tril(torch.ones(batch_size, seq_length))

token_type_ids = None
if use_token_type_ids:
assert type_vocab_size > 0, "type_vocab_size is null"
token_type_ids = onnxscript.tools.transformers_models.ids_tensor(
[batch_size, seq_length], type_vocab_size
)

sequence_labels = None
token_labels = None
choice_labels = None
if use_labels:
assert type_sequence_label_size > 0, "type_sequence_label_size is null"
assert num_labels > 0, "num_labels is null"
assert num_choices > 0, "num_choices is null"
sequence_labels = onnxscript.tools.transformers_models.ids_tensor(
[batch_size], type_sequence_label_size
)
token_labels = onnxscript.tools.transformers_models.ids_tensor(
[batch_size, seq_length], num_labels
)
choice_labels = onnxscript.tools.transformers_models.ids_tensor(
[batch_size], num_choices
)

return (
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
)


def get_phi3_model(
input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
hidden_size: int = 32,
num_hidden_layers: int = 2,
vocab_size: int = 99,
intermediate_size: int = 16,
max_position_embeddings: int = 512,
num_attention_heads: int = 4,
num_key_value_heads: int = 2,
_attn_implementation: str = "eager",  # value needed to avoid graph breaks
with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
"""
Returns a Phi-3 model built from ``transformers.Phi3Config``.
The parameters follow the unit-test configuration used for the Phi family; see `PhiConfig
<https://huggingface.co/docs/transformers/main/en/model_doc/phi#transformers.PhiConfig>`_
and `test_modeling_phi.py
<https://github.com/huggingface/transformers/blob/main/tests/models/phi/test_modeling_phi.py>`_.
"""
from transformers import Phi3Config, Phi3Model

dynamic_shapes = {0: {0: "batch", 1: "length"}}
if with_mask:
dynamic_shapes.update({1: {0: "batch", 1: "length"}})

config = Phi3Config(
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
vocab_size=vocab_size,
intermediate_size=intermediate_size,
max_position_embeddings=max_position_embeddings,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
pad_token_id=min(32000, vocab_size - 1),
)
if _attn_implementation:
config._attn_implementation = _attn_implementation # pylint: disable=protected-access

if with_mask:

class Phi3ModelWrapperWithMask(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.model = Phi3Model(config)

def forward(self, input_ids, attention_mask):
model_output = self.model(input_ids, attention_mask=attention_mask)
return model_output.to_tuple()

def generate_example_inputs_with_mask(batch: int, seq: int, vocab_size: int):
(
input_ids,
_, # token_type_ids,
input_mask,
_, # sequence_labels,
_, # token_labels,
_, # choice_labels,
) = _prepare_config_and_inputs(
batch_size=batch,
seq_length=seq,
vocab_size=vocab_size,
use_input_mask=True,
)
return input_ids, input_mask

example_args_collection = []
for b, s in input_dims:
example_args_collection.append(generate_example_inputs_with_mask(b, s, vocab_size))

return Phi3ModelWrapperWithMask(config), example_args_collection, dynamic_shapes

# no mask

class Phi3ModelWrapper(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.model = Phi3Model(config)

def forward(self, input_ids):
model_output = self.model(input_ids)
return model_output.to_tuple()

def generate_example_inputs(batch: int, seq: int, vocab_size: int):
(
input_ids,
*_,
# token_type_ids,
# input_mask,
# sequence_labels,
# token_labels,
# choice_labels,
) = _prepare_config_and_inputs(
batch_size=batch,
seq_length=seq,
vocab_size=vocab_size,
use_input_mask=True,
)
return (input_ids,)

example_args_collection = []
for b, s in input_dims:
example_args_collection.append(generate_example_inputs(b, s, vocab_size))

return Phi3ModelWrapper(config), example_args_collection, dynamic_shapes


def get_phi3_model_from_config(
warmup: int = 5,
repeat: int = 10,
config: str = "small",
num_hidden_layers: int = 1,
implementation: str = "eager",
dynamic_shapes: bool = False,
with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
"""
Returns a Phi-3 model to test or benchmark.
Args:
warmup: Number of inputs to generate for warmup.
repeat: Number of inputs to generate for the measured repetitions.
config: 'small', 'medium' or 'large'
num_hidden_layers: number of hidden layers
implementation: 'eager' or 'sdpa'
with_mask: whether the model takes an attention mask as a second input
dynamic_shapes: dynamic shapes or not
Returns:
Model, list of example inputs, and dynamic shape definitions.
"""
if config == "small":
conf_dict = dict(
input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
dynamic_shapes, warmup, repeat
),
hidden_size=32,
num_hidden_layers=num_hidden_layers,
vocab_size=99,
intermediate_size=16,
max_position_embeddings=512,
num_attention_heads=4,
num_key_value_heads=2,
_attn_implementation=implementation,
with_mask=with_mask,
)
elif config == "medium":
conf_dict = dict(
input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
dynamic_shapes, warmup, repeat
),
hidden_size=1024,
num_hidden_layers=num_hidden_layers,
vocab_size=1024,
intermediate_size=1024,
num_attention_heads=4,
num_key_value_heads=4,
max_position_embeddings=1024,
_attn_implementation=implementation,
with_mask=with_mask,
)
elif config in ("large", "default"):
conf_dict = dict(
input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
dynamic_shapes, warmup, repeat
),
hidden_size=2048,
num_hidden_layers=num_hidden_layers,
vocab_size=51200,
intermediate_size=8192,
num_attention_heads=32,
num_key_value_heads=None,
max_position_embeddings=2048,
_attn_implementation=implementation,
with_mask=with_mask,
)
else:
raise ValueError(f"Unexpected configuration {config!r}.")

return get_phi3_model(**conf_dict) # type: ignore[arg-type]
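A short usage sketch (not part of the new file): it builds the small test configuration and runs one eager forward pass on the first set of dummy inputs, guarded by `has_phi3()` for transformers releases that predate the Phi-3 model.

```python
# Sketch: build the small Phi-3 test model and run it once on dummy inputs.
import onnxscript.tools.transformers_models.phi3 as phi3

if phi3.has_phi3():
    model, example_inputs, dynamic_shapes = phi3.get_phi3_model_from_config(
        config="small", num_hidden_layers=1, with_mask=True
    )
    # With with_mask=True the wrapper expects (input_ids, attention_mask).
    outputs = model(*example_inputs[0])
```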