Skip to content

Commit

Permalink
Add CI to check with onnxruntime-traning and transformers (#1609)
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 18, 2024
1 parent d620466 commit 677ba7f
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 15 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,32 @@ jobs:
name: IR profiling results
path: tests/ir/serde_test_profiles

dort:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
transformers: ["4.37.2", "4.41.2"]
torch: ["release", "nightly"]
python_version: ["3.11"]
nox-tag: ["test-dort"]
name:
- dort
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python_version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
- name: Install nox
run: python -m pip install nox
- name: Pull Test Data
run: git lfs pull
- run: |
nox -t ${{ matrix.nox-tag }} --forcecolor -- ${{ matrix.torch }} ${{ matrix.transformers }}
name: Run tests
build_docs:
strategy:
fail-fast: false
Expand Down
30 changes: 29 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
'numpy==1.26.4; python_version>="3.9"',
"packaging",
"parameterized",
"pyinstrument",
"pytest-cov",
"pytest-randomly",
"pytest-subtests",
Expand Down Expand Up @@ -153,3 +152,32 @@ def test_experimental_torchlib_onnx_ir(session):
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
)


@nox.session(tags=["test-dort"])
def test_dort(session):
"""Test the conversion of a couple of models from transformers."""
session.install(
*COMMON_TEST_DEPENDENCIES,
)
torch_version, transformers_version = session.posargs

if torch_version == "nighly":
session.install(
"--pre",
"torch",
"torchvision",
"torchaudio",
"--index-url",
"https://download.pytorch.org/whl/nightly/cpu",
)
else:
session.install("torch", "torchvision", "torchaudio")

session.install("torch", "torchvision", "torchaudio")
session.install(f"transformers=={transformers_version}")
session.install("onnxruntime-training==1.17.1")

session.run("pip", "list")
session.run("pytest", "onnxscript")
session.run("pytest", "tests")
5 changes: 5 additions & 0 deletions onnxscript/tools/benchmark/export_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def test_export_model_llama_cpu_eager(self):

@unittest.skipIf(not has_transformers(), reason="transformers missing")
@unittest.skipIf(not is_onnxruntime_training(), reason="onnxruntime-training is needed")
@unittest.skipIf(
torch_older_than("2.4"),
reason="TypeError: _functionalize_sync(): "
"argument 't' (position 1) must be Tensor, not NoneType",
)
def test_export_model_phi_cpu_dynamo(self):
args = [
"--verbose",
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ beartype!=0.16.0
expecttest==0.1.6
hypothesis
parameterized
pyinstrument
pytest-cov
pytest-randomly
pytest-subtests
Expand Down
3 changes: 2 additions & 1 deletion tests/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import onnx
import onnxruntime
import torch

from onnxscript import optimizer
from onnxscript._legacy_ir import visitor
Expand All @@ -29,7 +30,7 @@ def skip_if_no_cuda(reason: str):
def skip_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if not onnxruntime.get_device() == "GPU":
if not torch.cuda.is_available() or not onnxruntime.get_device() == "GPU":
raise unittest.SkipTest(f"GPU is not available. {reason}")
return func(self, *args, **kwargs)

Expand Down
13 changes: 1 addition & 12 deletions tests/ir/serde_roundtrip_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations

import pathlib
Expand All @@ -8,7 +9,6 @@
import onnx
import onnx.backend.test
import parameterized
import pyinstrument

import onnxscript.testing
from onnxscript import ir
Expand All @@ -25,12 +25,6 @@


class SerdeTest(unittest.TestCase):
def setUp(self) -> None:
self.profiler = pyinstrument.Profiler()

def tearDown(self) -> None:
self.profiler.reset()

@parameterized.parameterized.expand(test_args)
def test_serialization_deserialization_produces_same_model(
self, _: str, model_path: pathlib.Path
Expand All @@ -41,13 +35,8 @@ def test_serialization_deserialization_produces_same_model(
onnx.checker.check_model(model)

# Profile the serialization and deserialization process
self.profiler.start()
ir_model = ir.serde.deserialize_model(model)
serialized = ir.serde.serialize_model(ir_model)
self.profiler.stop()
profile_path = pathlib.Path(__file__).parent / "serde_test_profiles"
profile_path.mkdir(exist_ok=True)
self.profiler.write_html(profile_path / f"{self.id().split('.')[-1]}.html")

onnxscript.testing.assert_onnx_proto_equal(serialized, model)
onnx.checker.check_model(serialized)
Expand Down

0 comments on commit 677ba7f

Please sign in to comment.