Add mistral to the list of test models
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Jun 17, 2024
1 parent dc31a6e commit f50e5d6
Showing 4 changed files with 368 additions and 0 deletions.
27 changes: 27 additions & 0 deletions onnxscript/tools/benchmark/export_model_test.py
@@ -37,6 +37,33 @@ def test_export_model_phi_cpu_eager(self):
        out = f.getvalue()
        self.assertIn(":repeat_time,", out)

    @unittest.skipIf(not has_transformers(), reason="transformers missing")
    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
    @unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
    def test_export_model_mistral_cpu_dynamo_llama0(self):
        args = [
            "--verbose",
            "1",
            "--config",
            "medium",
            "--dtype",
            "float32",
            "--device",
            "cpu",
            "--exporter",
            "dynamo",
            "--optimization",
            "rewrite,optimize,inline,llama0",
            "--model",
            "mistral",
        ]
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            onnxscript.tools.benchmark.export_model.main(args)

        out = f.getvalue()
        self.assertIn(":repeat_time,", out)

    @unittest.skipIf(not has_transformers(), reason="transformers missing")
    def test_export_model_llama_cpu_eager(self):
        args = [
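The new test drives the benchmark end to end through onnxscript.tools.benchmark.export_model.main. For reference, a minimal standalone sketch mirroring its arguments (the ":repeat_time," marker asserted above is part of the benchmark's printed output):

import contextlib
import io

import onnxscript.tools.benchmark.export_model

# Same arguments as test_export_model_mistral_cpu_dynamo_llama0 above.
args = [
    "--verbose", "1",
    "--config", "medium",
    "--dtype", "float32",
    "--device", "cpu",
    "--exporter", "dynamo",
    "--optimization", "rewrite,optimize,inline,llama0",
    "--model", "mistral",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
    onnxscript.tools.benchmark.export_model.main(args)
print(f.getvalue())  # benchmark metrics, including ":repeat_time,"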
13 changes: 13 additions & 0 deletions onnxscript/tools/transformers_models/__init__.py
@@ -114,6 +114,19 @@ def get_model_and_inputs(
            config=config,
        )

    elif model == "mistral":
        import onnxscript.tools.transformers_models.mistral as m_mistral

        tmodel, inputs, dynamic_shapes_def = m_mistral.get_mistral_model_from_config(
            warmup=warmup,
            repeat=repeat,
            implementation=implementation,
            with_mask=with_mask,
            num_hidden_layers=num_hidden_layers,
            dynamic_shapes=dynamic_shapes,
            config=config,
        )

    elif model == "phi":
        import onnxscript.tools.transformers_models.phi as m_phi

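For context, a minimal sketch of reaching the new branch through the dispatcher. The keyword names mirror those forwarded above, but the full signature of get_model_and_inputs is not visible in this hunk, so the exact call form is an assumption:

import onnxscript.tools.transformers_models as transformers_models

# Hypothetical call form; keywords mirror those forwarded to
# get_mistral_model_from_config in the hunk above.
tmodel, inputs, dynamic_shapes_def = transformers_models.get_model_and_inputs(
    model="mistral",
    config="small",
    num_hidden_layers=1,
    implementation="eager",
    with_mask=True,
    dynamic_shapes=False,
    warmup=5,
    repeat=10,
)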
239 changes: 239 additions & 0 deletions onnxscript/tools/transformers_models/mistral.py
@@ -0,0 +1,239 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
from __future__ import annotations

from typing import Any, Sequence

import torch

import onnxscript.tools.transformers_models


def _prepare_config_and_inputs(
    batch_size: int,
    seq_length: int,
    vocab_size: int,
    type_sequence_label_size: int = 2,
    type_vocab_size: int = 16,
    num_labels: int = 3,
    num_choices: int = 4,
    use_input_mask: bool = False,
    use_token_type_ids: bool = False,
    use_labels: bool = False,
) -> tuple[Any, ...]:
    input_ids = onnxscript.tools.transformers_models.ids_tensor(
        [batch_size, seq_length], vocab_size
    )

    input_mask = None
    if use_input_mask:
        input_mask = torch.tril(torch.ones(batch_size, seq_length))

    token_type_ids = None
    if use_token_type_ids:
        assert type_vocab_size > 0, "type_vocab_size must be positive"
        token_type_ids = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size, seq_length], type_vocab_size
        )

    sequence_labels = None
    token_labels = None
    choice_labels = None
    if use_labels:
        assert type_sequence_label_size > 0, "type_sequence_label_size must be positive"
        assert num_labels > 0, "num_labels must be positive"
        assert num_choices > 0, "num_choices must be positive"
        sequence_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size], type_sequence_label_size
        )
        token_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size, seq_length], num_labels
        )
        choice_labels = onnxscript.tools.transformers_models.ids_tensor(
            [batch_size], num_choices
        )

    return (
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    )


def get_mistral_model(
    input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=99,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
    sliding_window=4096,
    _attn_implementation="eager",  # needed to avoid graph breaks
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a model.
    See `MistralConfig
    <https://huggingface.co/docs/transformers/main/en/model_doc/mistral#transformers.MistralConfig>`_.
    The parameters are chosen for a unit test configuration.
    """
    from transformers import MistralConfig
    from transformers.models.mistral.modeling_mistral import MistralModel

    config = MistralConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        sliding_window=sliding_window,
    )

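    # Maps positional input index -> {axis index: dynamic axis name}:
    # input 0 is input_ids; input 1, present only with_mask, is the attention mask.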
    dynamic_shapes = {0: {0: "batch", 1: "length"}}
    if with_mask:
        dynamic_shapes.update({1: {0: "batch", 1: "length"}})

    if _attn_implementation:
        config._attn_implementation = _attn_implementation  # pylint: disable=protected-access

    def generate_example_inputs(batch: int, seq: int, vocab_size: int, with_mask: bool):
        (
            input_ids,
            _,  # token_type_ids
            input_mask,
            _,  # sequence_labels
            _,  # token_labels
            _,  # choice_labels
        ) = _prepare_config_and_inputs(
            batch_size=batch,
            seq_length=seq,
            vocab_size=vocab_size,
            use_input_mask=with_mask,
        )
        if with_mask:
            return input_ids, input_mask
        return (input_ids,)

    if with_mask:

        class MistralModelWrapperWithMask(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.model = MistralModel(config)

            def forward(self, input_ids, attention_mask):
                model_output = self.model(input_ids, attention_mask=attention_mask)
                return model_output.to_tuple()

        example_args_collection = []
        for b, s in input_dims:
            example_args_collection.append(
                generate_example_inputs(b, s, vocab_size, with_mask)
            )

        return MistralModelWrapperWithMask(config), example_args_collection, dynamic_shapes

    class MistralModelWrapper(torch.nn.Module):
        def __init__(self, config):
            super().__init__()
            self.model = MistralModel(config)

        def forward(self, input_ids):
            model_output = self.model(input_ids)
            return model_output.to_tuple()

    example_args_collection = []
    for b, s in input_dims:
        example_args_collection.append(generate_example_inputs(b, s, vocab_size, with_mask))

    return MistralModelWrapper(config), example_args_collection, dynamic_shapes


def get_mistral_model_from_config(
    warmup: int = 5,
    repeat: int = 10,
    config: str = "small",
    num_hidden_layers: int = 1,
    implementation: str = "eager",
    dynamic_shapes: bool = False,
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a Mistral model to test or benchmark.

    Args:
        warmup: number of inputs to generate for warmup
        repeat: number of inputs to generate for the measured repeats
        config: small, medium or large
        num_hidden_layers: number of hidden layers
        implementation: eager or sdpa
        with_mask: whether the example inputs include an attention mask
        dynamic_shapes: whether to use dynamic shapes

    Returns:
        Model and list of inputs.
    """
    if config == "small":
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=32,
            num_hidden_layers=num_hidden_layers,
            vocab_size=99,
            intermediate_size=16,
            max_position_embeddings=512,
            num_attention_heads=4,
            num_key_value_heads=2,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    elif config == "medium":
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=1024,
            num_hidden_layers=num_hidden_layers,
            vocab_size=1024,
            intermediate_size=1024,
            num_attention_heads=4,
            num_key_value_heads=4,
            max_position_embeddings=1024,
            sliding_window=4096,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    elif config in ("large", "default"):
        conf_dict = dict(
            input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
                dynamic_shapes, warmup, repeat
            ),
            hidden_size=4096,
            num_hidden_layers=num_hidden_layers,
            vocab_size=32000,
            intermediate_size=14336,
            num_attention_heads=32,
            num_key_value_heads=8,
            max_position_embeddings=131072,
            sliding_window=4096,
            _attn_implementation=implementation,
            with_mask=with_mask,
        )
    else:
        raise ValueError(f"Unexpected configuration {config!r}.")

    return get_mistral_model(**conf_dict)  # type: ignore[arg-type]

89 changes: 89 additions & 0 deletions onnxscript/tools/transformers_models/mistral_test.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable

import copy
import sys
import unittest

import numpy as np
import onnxruntime
import torch

import onnxscript.optimizer
import onnxscript.rewriter
import onnxscript.tools.training_helper
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.mistral
from onnxscript._internal.version_utils import has_transformers, torch_older_than


class TestExportMistral(unittest.TestCase):
    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
    def test_mistral_export_cpu(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)
        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
        names = [i.name for i in proto.graph.input]
        np_input_tensors = [x.numpy() for x in input_tensors]
        feeds = dict(zip(names, np_input_tensors))
        sess = onnxruntime.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        results = sess.run(None, feeds)
        np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5)

    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    def test_mistral_export_cuda(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors_cpu = input_tensors_many[0]
        model = model.to("cuda")
        input_tensors = [i.to("cuda") for i in input_tensors_cpu]
        expected = model(*input_tensors)
        proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
        names = [i.name for i in proto.graph.input]
        np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors]
        feeds = dict(zip(names, np_input_tensors))
        sess = onnxruntime.InferenceSession(
            proto.SerializeToString(), providers=["CUDAExecutionProvider"]
        )
        results = sess.run(None, feeds)
        np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5)

    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    def test_mistral_dort_static(self):
        model, input_tensors_many, _ = (
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

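        # make_aot_ort builds a torch.compile backend (DORT) that runs the model
        # through ONNX Runtime; the compiled forward outputs and the train_loop
        # gradients are both compared against eager PyTorch below.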
        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
            copy.deepcopy(model),
            backend=local_aot_ort,
            dynamic=False,
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    unittest.main(verbosity=2)