Skip to content

Commit

Permalink
Add llama model to benchmark (#1593)
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 14, 2024
1 parent caf22fa commit a1164f3
Show file tree
Hide file tree
Showing 17 changed files with 426 additions and 61 deletions.
4 changes: 4 additions & 0 deletions docs/api/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@
```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_config
```

```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_config
```
7 changes: 5 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"expecttest==0.1.6",
"hypothesis",
'numpy==1.24.4; python_version<"3.9"',
'numpy==1.26.0; python_version>="3.9"',
'numpy==1.26.4; python_version>="3.9"',
"packaging",
"parameterized",
"pyinstrument",
Expand All @@ -34,6 +34,7 @@
ONNX_RUNTIME = "onnxruntime==1.17.1"
PYTORCH = "torch==2.2.2"
TORCHVISON = "torchvision==0.17.2"
TRANSFORMERS = "transformers>=4.37.2"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
Expand All @@ -60,6 +61,7 @@ def test(session):
TORCHVISON,
ONNX,
ONNX_RUNTIME,
TRANSFORMERS,
)
session.install(".", "--no-deps")
session.run("pip", "list")
Expand All @@ -73,6 +75,7 @@ def test_torch_nightly(session):
session.install(
*COMMON_TEST_DEPENDENCIES,
ONNX_RUNTIME,
TRANSFORMERS,
)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt")
Expand All @@ -85,7 +88,7 @@ def test_torch_nightly(session):
@nox.session(tags=["test-onnx-weekly"])
def test_onnx_weekly(session):
"""Test with ONNX weekly (preview) build."""
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON)
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON, TRANSFORMERS)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
Expand Down
31 changes: 31 additions & 0 deletions onnxscript/_internal/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ def torch_older_than(version: str) -> bool:
)


def is_onnxruntime_training() -> bool:
    """Tell whether the installed onnxruntime is the onnxruntime-training build.

    Returns:
        False when ``onnxruntime.training`` or ``OrtValueVector`` cannot be
        imported; otherwise whether ``OrtValueVector`` exposes the
        ``push_back_batch`` attribute.
    """
    # pylint: disable=import-outside-toplevel
    try:
        from onnxruntime import training
    except ImportError:
        # Either onnxruntime is missing or it has no training submodule.
        return False
    # Referencing the name keeps linters from reporting an unused import.
    assert training

    try:
        from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
    except ImportError:
        return False

    return hasattr(OrtValueVector, "push_back_batch")


def onnxruntime_older_than(version: str) -> bool:
"""Returns True if the onnxruntime version is older than the given version."""
import onnxruntime # pylint: disable=import-outside-toplevel
Expand All @@ -43,3 +63,14 @@ def numpy_older_than(version: str) -> bool:
packaging.version.parse(numpy.__version__).release
< packaging.version.parse(version).release
)


def has_transformers() -> bool:
    """Tells if transformers is installed.

    Returns:
        True when ``import transformers`` succeeds, False when it raises
        ImportError.
    """
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return False
    # Referencing the module keeps linters from reporting an unused import.
    assert transformers
    return True
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3669,7 +3669,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("_operator::getitem")
@torch_op(("_operator::getitem", "aten::getitem"))
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
return op.SequenceAt(self, i)

Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def rewrite(
if function_rewrite_rules:
for rule_cls in function_rewrite_rules:
count, model_ir = rule_cls().apply_to_model(model_ir)
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if count > 0:
print(f"Applied {count} of rewrite rules.")
if pattern_rewrite_rules:
if not isinstance(pattern_rewrite_rules, RewriteRuleSet):
# Create a pattern rule-set using provided rules
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def rewrite(
if function_rules:
for rule_cls in function_rules:
count, model = rule_cls().apply_to_model(model)
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if count > 0:
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if pattern_rules:
count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,9 @@ def node(self, index: int) -> NodePattern:
def num_nodes(self) -> int:
return len(self._nodes)

def __len__(self) -> int:
    """Return the number of nodes, as reported by ``num_nodes()``."""
    return self.num_nodes()

@property
def inputs(self) -> Sequence[ValuePattern]:
return self._inputs
Expand Down
13 changes: 11 additions & 2 deletions onnxscript/tools/benchmark/benchmark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def apply_rule_sets(
if rule_set_name == "llama0":
rule_set = rules.llama_p0_rule_set()
else:
raise AssertionError(f"Unexpected rule_set name {rule_set_name!r}")
raise ValueError(f"Unexpected rule_set name {rule_set_name!r}")

begin = time.perf_counter()
rule_set.apply_to_model(ir_model)
Expand Down Expand Up @@ -380,6 +380,8 @@ def optimize_model_proto(
if verbose:
print(f"[optimize_model_proto] start {value}")

n_nodes = len(model_proto.graph.node)
n_functions = len(model_proto.functions)
begin = time.perf_counter()
if value == "optimize":
model_proto = onnxscript.optimizer.optimize(
Expand All @@ -405,10 +407,17 @@ def optimize_model_proto(
)

end = time.perf_counter() - begin
delta = len(model_proto.graph.node) - n_nodes
deltaf = len(model_proto.functions) - n_functions
if stats:
stats[f"opt_{value}_time"] = end
stats[f"opt_{value}_dnodes"] = delta
stats[f"opt_{value}_dfunctions"] = deltaf
if verbose:
print(f"[optimize_model_proto] {value} done in {end}")
print(
f"[optimize_model_proto] {value} done in {end} "
f"with +/- {delta} nodes, +/- {deltaf} functions"
)

return model_proto

Expand Down
6 changes: 5 additions & 1 deletion onnxscript/tools/benchmark/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ def main(args=None):
This script can be used to quickly evaluate the improvment made by a pattern optimization
for a particular model.
Example::
Example with a large phi model::
python -m onnxscript.tools.benchmark.export_model --model phi --device cuda --config large --num_hidden_layers=6 --dtype=float32 --dynamic=0 --verbose=1 --exporter=dynamo
Example with a medium llama model::
python -m onnxscript.tools.benchmark.export_model --model llama --device cuda --config large --num_hidden_layers=1 --dtype=float32 --dynamic=0 --verbose=1 --exporter=dynamo
"""
),
repeat=(10, "number of inferences to measure"),
Expand Down
43 changes: 41 additions & 2 deletions onnxscript/tools/benchmark/export_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import unittest

import onnxscript.tools.benchmark.export_model
from onnxscript.tools.transformers_models import has_transformers
from onnxscript._internal.version_utils import (
has_transformers,
is_onnxruntime_training,
torch_older_than,
)


class BenchmarkTest(unittest.TestCase):
Expand All @@ -23,6 +27,8 @@ def test_export_model_phi_cpu_eager(self):
"cpu",
"--exporter",
"eager",
"--model",
"phi",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
Expand All @@ -32,6 +38,30 @@ def test_export_model_phi_cpu_eager(self):
self.assertIn(":repeat_time,", out)

@unittest.skipIf(not has_transformers(), reason="transformers missing")
def test_export_model_llama_cpu_eager(self):
    """Run the export_model benchmark on a medium llama model (cpu, eager).

    The benchmark output is captured from stdout and must contain the
    ``:repeat_time,`` marker.
    """
    # Insertion order of the mapping fixes the order of the CLI arguments.
    options = {
        "--verbose": "1",
        "--config": "medium",
        "--dtype": "float32",
        "--device": "cpu",
        "--exporter": "eager",
        "--model": "llama",
    }
    args = [token for pair in options.items() for token in pair]

    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        onnxscript.tools.benchmark.export_model.main(args)

    self.assertIn(":repeat_time,", captured.getvalue())

@unittest.skipIf(not has_transformers(), reason="transformers missing")
@unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
def test_export_model_phi_cpu_dynamo(self):
args = [
"--verbose",
Expand All @@ -44,6 +74,8 @@ def test_export_model_phi_cpu_dynamo(self):
"cpu",
"--exporter",
"dynamo",
"--model",
"phi",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
Expand All @@ -53,6 +85,7 @@ def test_export_model_phi_cpu_dynamo(self):
self.assertIn(":repeat_time,", out)

@unittest.skipIf(not has_transformers(), reason="transformers missing")
@unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
def test_export_model_phi_cpu_script(self):
args = [
"--verbose",
Expand All @@ -65,6 +98,8 @@ def test_export_model_phi_cpu_script(self):
"cpu",
"--exporter",
"script",
"--model",
"phi",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
Expand All @@ -74,6 +109,8 @@ def test_export_model_phi_cpu_script(self):
self.assertIn(":repeat_time,", out)

@unittest.skipIf(not has_transformers(), reason="transformers missing")
@unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
@unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
def test_export_model_phi_cpu_dynamo_llama0(self):
args = [
"--verbose",
Expand All @@ -87,7 +124,9 @@ def test_export_model_phi_cpu_dynamo_llama0(self):
"--exporter",
"dynamo",
"--optimization",
"llama0",
"rewrite,optimize,inline,llama0",
"--model",
"phi",
]
f = io.StringIO()
with contextlib.redirect_stdout(f):
Expand Down
60 changes: 43 additions & 17 deletions onnxscript/tools/transformers_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,31 @@
import random
from typing import Any, Sequence

import onnx
import onnx.inliner
import torch

import onnxscript.optimizer
import onnxscript.rewriter

def has_transformers():
"""Tells if transformers is installed."""
try:
import transformers

assert transformers
return True # noqa
except ImportError:
return False
def export_to_onnx(model: Any, *args: Sequence[Any], optimize: bool = True) -> onnx.ModelProto:
    """Export a model to ONNX through ``torch.onnx.dynamo_export``.

    Args:
        model: Model handed to the dynamo exporter.
        *args: Positional inputs forwarded to the exporter.
        optimize: When True, the exported proto goes through
            ``onnxscript.optimizer.optimize``, ``onnxscript.rewriter.rewrite``
            and ``onnx.inliner.inline_local_functions``, in that order.

    Returns:
        The exported model proto, optimized when ``optimize`` is True.
    """
    exported = torch.onnx.dynamo_export(model, *args).model_proto
    if not optimize:
        return exported

    optimized = onnxscript.optimizer.optimize(
        exported,
        num_iterations=2,
        onnx_shape_inference=True,
    )
    rewritten = onnxscript.rewriter.rewrite(optimized)
    return onnx.inliner.inline_local_functions(rewritten)


def ids_tensor(
Expand Down Expand Up @@ -78,20 +91,33 @@ def get_model_and_inputs(
config: 'small', 'medium', 'large', ...
dynamic_shapes: dynamic or static shapes
device: 'cpu' or 'cuda'
num_hidden_layers: number of hidden layers
with_mask: one input or two inputs
num_hidden_layers: Number of hidden layers.
with_mask: One input or two inputs.
implementation: eager or sdpa
warmup: number of inputs to generate
repeat: number of inputs to generate for repeat
dtype: if specified, cast the model and the inputs into this type
warmup: Number of inputs to generate.
repeat: Number of inputs to generate for repeat.
dtype: If specified, cast the model and the inputs into this type.
Returns:
model and list of inputs
"""
if model == "phi":
import onnxscript.tools.transformers_models.phi as m
if model == "llama":
import onnxscript.tools.transformers_models.llama as m_llama

tmodel, inputs, dynamic_shapes_def = m_llama.get_llama_model_from_config(
warmup=warmup,
repeat=repeat,
implementation=implementation,
with_mask=with_mask,
num_hidden_layers=num_hidden_layers,
dynamic_shapes=dynamic_shapes,
config=config,
)

elif model == "phi":
import onnxscript.tools.transformers_models.phi as m_phi

tmodel, inputs, dynamic_shapes_def = m.get_phi_model_config(
tmodel, inputs, dynamic_shapes_def = m_phi.get_phi_model_from_config(
warmup=warmup,
repeat=repeat,
implementation=implementation,
Expand All @@ -102,7 +128,7 @@ def get_model_and_inputs(
)

else:
raise AssertionError(f"Model {model!r} is unknown.")
raise ValueError(f"Model {model!r} is unknown.")

if dtype is not None:
dt = getattr(torch, dtype)
Expand Down
Loading

0 comments on commit a1164f3

Please sign in to comment.