Skip to content

Commit

Permalink
Merge branch 'main' into rama/optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam authored Jul 9, 2024
2 parents 08592ae + 60f2d2c commit f3719cb
Show file tree
Hide file tree
Showing 38 changed files with 1,308 additions and 802 deletions.
10 changes: 1 addition & 9 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ jobs:
- py311-onnx-weekly
- py311-ort-nightly
- py311-experimental-torchlib-tracing
- py311-experimental-torchlib-onnx-ir
- py310
- py39
- py38
include:
- name: py311
python-version: "3.11"
Expand All @@ -45,9 +43,6 @@ jobs:
- name: py39
python-version: "3.9"
nox-tag: test
- name: py38
python-version: "3.8"
nox-tag: test
- name: py312-torch-nightly
python-version: "3.12"
nox-tag: test-torch-nightly
Expand All @@ -63,9 +58,6 @@ jobs:
- name: py311-experimental-torchlib-tracing
python-version: "3.11"
nox-tag: test-experimental-torchlib-tracing
- name: py311-experimental-torchlib-onnx-ir
python-version: "3.11"
nox-tag: test-experimental-torchlib-onnx-ir
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -105,7 +97,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
transformers: ["4.37.2", "4.41.2"]
transformers: ["4.37.2", "4.41.2", "4.42.3"]
torch: ["release", "nightly"]
python_version: ["3.11"]
nox-tag: ["test-dort"]
Expand Down
3 changes: 3 additions & 0 deletions docs/test/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def do_test_folder(self, folder):
if tested == 0:
raise RuntimeError(f"No example was tested in folder {folder}.")

@unittest.skipIf(
sys.platform != "linux", reason="No need to run the documentation on every OS."
)
def test_documentation_examples(self):
this = os.path.abspath(os.path.dirname(__file__))
onxc = os.path.normpath(os.path.join(this, "..", ".."))
Expand Down
29 changes: 4 additions & 25 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
'numpy==1.26.4; python_version>="3.9"',
"packaging",
"parameterized",
"psutil",
'psutil; sys_platform != "win32"',
"pytest-cov",
"pytest-randomly",
"pytest-subtests",
Expand All @@ -28,13 +28,13 @@
"pyyaml",
"types-PyYAML",
"typing_extensions",
"ml_dtypes",
"ml-dtypes",
)
ONNX = "onnx==1.16"
ONNX_RUNTIME = "onnxruntime==1.17.1"
PYTORCH = "torch==2.2.2"
TORCHVISON = "torchvision==0.17.2"
TRANSFORMERS = "transformers>=4.37.2"
TRANSFORMERS = "transformers==4.37.2"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
Expand Down Expand Up @@ -134,27 +134,6 @@ def test_experimental_torchlib_tracing(session):
)


@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
def test_experimental_torchlib_onnx_ir(session):
"""Test TorchLib using the ONNX IR to build graphs."""
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run(
"pytest",
"tests/function_libs/torch_lib/ops_test.py",
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
)


@nox.session(tags=["test-dort"])
def test_dort(session):
"""Test the conversion of a couple of models from transformers."""
Expand All @@ -163,7 +142,7 @@ def test_dort(session):
)
torch_version, transformers_version = session.posargs

if torch_version == "nighly":
if torch_version == "nightly":
session.install(
"--pre",
"torch",
Expand Down
42 changes: 42 additions & 0 deletions onnxscript/_internal/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
# Licensed under the MIT License.
"""Version utils for testing."""

from __future__ import annotations

import warnings
from typing import Callable, Sequence

import packaging.version


Expand All @@ -25,6 +30,19 @@ def torch_older_than(version: str) -> bool:
)


def transformers_older_than(version: str) -> bool | None:
"""Returns True if the transformers version is older than the given version."""
try:
import transformers # pylint: disable=import-outside-toplevel
except ImportError:
return None

return (
packaging.version.parse(transformers.__version__).release
< packaging.version.parse(version).release
)


def is_onnxruntime_training() -> bool:
"""Returns True if the onnxruntime is onnxruntime-training."""
try:
Expand Down Expand Up @@ -74,3 +92,27 @@ def has_transformers():
return True # noqa
except ImportError:
return False


def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type]
"""Catches warnings.
Args:
warns: warnings to ignore
Returns:
decorated function
"""

def wrapper(fct):
if warns is None:
raise AssertionError(f"warns cannot be None for '{fct}'.")

def call_f(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", warns) # type: ignore[arg-type]
return fct(self)

return call_f

return wrapper
41 changes: 26 additions & 15 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import dataclasses
import importlib
import os
import pathlib
import re
import sys
import unittest
from typing import Pattern

Expand Down Expand Up @@ -89,6 +91,17 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"),
)

if sys.platform == "win32":
SKIP_TESTS = (
*SKIP_TESTS,
skip(r"^test_gemm_beta", "cannot import module, import_module does not work"),
skip(
r"^test_averagepool_2d_default",
"cannot import module, import_module does not work",
),
skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"),
)


def load_function(obj):
return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",))
Expand All @@ -106,16 +119,24 @@ def run_function(obj, *inputs):
def extract_functions(name: str, content: str, test_folder: pathlib.Path):
if not test_folder.exists():
test_folder.mkdir(exist_ok=True, parents=True)
init = test_folder / "__init__.py"
init.touch(exist_ok=True)
file = test_folder / f"{name}.py"
file.write_text(content, encoding="utf-8")
init = str(test_folder / "__init__.py")
with open(init, "w", encoding="utf-8") as f:
f.write("\n")
filename = str(test_folder / f"{name}.py")
with open(filename, "w", encoding="utf-8") as f:
f.write(content + "\n")
assert os.path.exists(
filename
), f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
import_name = f"tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
except (SyntaxError, ImportError) as e:
raise AssertionError(
f"Unable to import {import_name!r} (file: {file!r})\n----\n{content}"
f"Unable to import {import_name!r} (e={e}) (file: {filename!r}, "
f"absolute path: {os.path.abspath(filename)!r}, "
f"current folder: {os.getcwd()}"
f"\n---- CONTENT --\n{content}"
) from e
functions = {
k: v for k, v in mod.__dict__.items() if isinstance(v, onnxscript.OnnxFunction)
Expand Down Expand Up @@ -265,16 +286,6 @@ def _load_function(_):
return session

def _run_function(obj, *inputs):
print(" run ONNX")
for i, inp in enumerate(inputs):
if inp is None:
print(f" input {i}: None")
else:
print(
f" input {i}: "
f"dtype={inp.dtype!r} shape={inp.shape!r}"
f"{inp.ravel().tolist()!r}"
)
try:
return run_function(obj, *inputs)
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ def _load_boolean_flag(
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_USE_IR",
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
deprecated=True,
)
Loading

0 comments on commit f3719cb

Please sign in to comment.