
Commit

lint
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Jul 9, 2024
2 parents a42d94e + 60f2d2c commit a28d067
Showing 22 changed files with 502 additions and 175 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/main.yaml
@@ -31,7 +31,6 @@ jobs:
- py311-onnx-weekly
- py311-ort-nightly
- py311-experimental-torchlib-tracing
- py311-experimental-torchlib-onnx-ir
- py310
- py39
include:
@@ -59,9 +58,6 @@
- name: py311-experimental-torchlib-tracing
python-version: "3.11"
nox-tag: test-experimental-torchlib-tracing
- name: py311-experimental-torchlib-onnx-ir
python-version: "3.11"
nox-tag: test-experimental-torchlib-onnx-ir
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
21 changes: 0 additions & 21 deletions noxfile.py
@@ -134,27 +134,6 @@ def test_experimental_torchlib_tracing(session):
)


@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
def test_experimental_torchlib_onnx_ir(session):
"""Test TorchLib using the ONNX IR to build graphs."""
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run(
"pytest",
"tests/function_libs/torch_lib/ops_test.py",
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
)


@nox.session(tags=["test-dort"])
def test_dort(session):
"""Test the conversion of a couple of models from transformers."""
1 change: 1 addition & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
@@ -54,4 +54,5 @@ def _load_boolean_flag(
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_USE_IR",
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
deprecated=True,
)
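
Note (not part of the diff): this flag is driven by an environment variable, and the deprecated=True addition marks it as deprecated. A minimal sketch of how such a boolean env flag can be read; the actual _load_boolean_flag internals are not shown in this diff, so the convention below is an assumption:

import os

def _read_boolean_env(name: str) -> bool:
    # Assumed convention: "1" enables the flag, anything else disables it.
    return os.environ.get(name, "0") == "1"

EXPERIMENTAL_USE_IR = _read_boolean_env("TORCHLIB_EXPERIMENTAL_USE_IR")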
1 change: 0 additions & 1 deletion onnxscript/function_libs/torch_lib/backward_test.py
@@ -3,7 +3,6 @@
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import os
import sys
import unittest

45 changes: 31 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4652,7 +4652,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -6393,25 +6393,18 @@ def aten_ones_like(
device: str = "",
pin_memory: bool = False,
) -> TTensor:
"""ones_like.
"""ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
before calling this function.
"""
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.
if dtype is None:
dtype = -1

if dtype == -1:
one = op.CastLike(1, self)
else:
one = op.Cast(1, to=dtype)
return _aten_ones_like_onnx(self, one)


@torch_op("aten::ones_like", private=True)
def _aten_ones_like_onnx(self: TTensor, one) -> TTensor:
shape = op.Shape(self)
return op.Expand(one, shape)
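
For reference (illustrative, not part of the diff): as the docstring note says, dtype here is an ONNX enum value rather than a torch.dtype, so callers are expected to convert first.

import onnx

# dtype is an ONNX TensorProto enum value, e.g. onnx.TensorProto.FLOAT == 1.
# aten_ones_like(x)                               -> ones matching x's dtype (CastLike)
# aten_ones_like(x, dtype=onnx.TensorProto.FLOAT) -> ones cast to FLOAT (Cast)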

@@ -6562,7 +6555,14 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow"))
@torch_op(
(
"aten::pow.Scalar",
"aten::pow.Tensor_Tensor",
"aten::pow.Tensor_Scalar",
"_operator::pow",
)
)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

@@ -6583,10 +6583,12 @@ def aten_prelu_backward(
raise NotImplementedError()


def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op(("aten::prod.dim_int"), trace_only=True)
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
# TODO: add a test for this function later
return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
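
For reference, a numpy sketch of what ReduceProd over a single axis computes (illustrative values, not part of the commit):

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])
np.prod(x, axis=1)                 # array([ 2., 12.])   ~ keepdims=False
np.prod(x, axis=1, keepdims=True)  # array([[ 2.], [12.]])  ~ keepdims=True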


def aten_promote_types(type1: int, type2: int) -> int:
@@ -7369,6 +7371,19 @@ def aten_scalar_tensor_sym_number(
return common_ops.cast_to(s, dtype=dtype)


@torch_op("aten::scatter.value", trace_only=True)
def aten_scatter(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
src: TReal,
) -> TReal:
"""scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""

update = op.Expand(src, op.Shape(index))
return op.ScatterElements(self, index, update, axis=dim)
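
A numpy sketch of the ONNX ScatterElements semantics this lowers to, with axis=0 (illustrative values mirroring the standard ONNX example; not part of the commit). The op.Expand over op.Shape(index) lets a scalar src, as in the aten::scatter.value overload, act as a full update tensor.

import numpy as np

data = np.zeros((3, 3))
index = np.array([[1, 0, 2], [0, 2, 1]])
update = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]])

out = data.copy()
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        out[index[i, j], j] = update[i, j]  # axis=0: the target row comes from index
# out == [[2.0, 1.1, 0.0],
#         [1.0, 0.0, 2.2],
#         [0.0, 2.1, 1.2]]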


@torch_op("aten::scatter_add")
def aten_scatter_add(
self: TReal,
@@ -8861,6 +8876,8 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.
if dtype is None:
dtype = -1

if dtype == -1:
zero = op.CastLike(0, self)
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -50,7 +50,7 @@ def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> Ten
raise NotImplementedError()


@torch_op(("aten::linalg_det", "aten::det"))
@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det"))
def aten_linalg_det(A: TFloat) -> TFloat:
"""linalg_det(Tensor A) -> Tensor"""

10 changes: 7 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
return op.Clip(self, min_val, max_val)


@torch_op("aten::hardtanh_backward", trace_only=True)
def aten_hardtanh_backward(
grad_output: TensorType, self: TensorType, min_val: float, max_val: float
) -> TensorType:
"""hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""

raise NotImplementedError()
max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
return op.Mul(op.Mul(grad_output, max_mask), min_mask)
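
A quick sanity sketch against torch autograd (an assumed check, not in the commit): the gradient should pass through unchanged strictly inside (min_val, max_val) and be zero outside.

import torch

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
torch.nn.functional.hardtanh(x, min_val=-1.0, max_val=1.0).sum().backward()
print(x.grad)  # tensor([0., 1., 1., 0.])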


def aten_huber_loss(
@@ -2046,10 +2049,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor
raise NotImplementedError()


def aten_silu(self: TensorType) -> TensorType:
@torch_op("aten::silu", traceable=True)
def aten_silu(self: TFloat) -> TFloat:
"""silu(Tensor self) -> Tensor"""

raise NotImplementedError()
return op.Mul(self, op.Sigmoid(self))
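
The identity being implemented, checked numerically with torch (illustrative only): SiLU is x * sigmoid(x).

import torch

x = torch.randn(4)
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))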


def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType:
26 changes: 26 additions & 0 deletions onnxscript/ir/_convenience.py
@@ -369,3 +369,29 @@ def tensor(
doc_string=name,
)
return tensor_


def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
"""Return a dictionary mapping names to values in the graph.
The mapping does not include values from subgraphs.
Args:
graph: The graph to extract the mapping from.
Returns:
A dictionary mapping names to values.
"""
values = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
if not input.name:
continue
values[input.name] = input
for node in graph:
for value in node.outputs:
if not value.name:
continue
values[value.name] = value
return values
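
Hypothetical usage sketch (the "model.onnx" path is a placeholder, and loading through ir.serde is an assumption; the helper itself lives in the private _convenience module):

import onnx
from onnxscript import ir
from onnxscript.ir._convenience import create_value_mapping

model = ir.serde.deserialize_model(onnx.load("model.onnx"))
values = create_value_mapping(model.graph)
# values maps every named initializer, graph input, and node output to its Value.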
2 changes: 1 addition & 1 deletion onnxscript/ir/_protocols.py
@@ -504,7 +504,7 @@ class TypeProtocol(Protocol):
elem_type: TypeProtocol | _enums.DataType
dtype: _enums.DataType

def __eq__(self, __value: object) -> bool: ...
def __eq__(self, value: object, /) -> bool: ...
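
Context for the signature change: the trailing / marks the parameter as positional-only, which matches how __eq__ is actually invoked, so the old __value name-mangling workaround is unnecessary. A small illustration:

class Example:
    def __eq__(self, value: object, /) -> bool:
        return isinstance(value, Example)

Example() == Example()       # fine: value is bound positionally
# Example().__eq__(value=1)  # TypeError: value is positional-only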


@typing.runtime_checkable
6 changes: 6 additions & 0 deletions onnxscript/tools/benchmark/__init__.py
@@ -5,13 +5,19 @@
from onnxscript.tools.benchmark.benchmark_helpers import (
common_export,
get_parsed_args,
make_configs,
make_dataframe_from_benchmark_data,
multi_run,
run_inference,
run_onnx_inference,
)

__all__ = [
"get_parsed_args",
"common_export",
"make_configs",
"multi_run",
"make_dataframe_from_benchmark_data",
"run_inference",
"run_onnx_inference",
]
86 changes: 85 additions & 1 deletion onnxscript/tools/benchmark/benchmark_helpers.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import itertools
import multiprocessing
import os
import platform
@@ -195,6 +196,52 @@ def run_benchmark(
return data


def measure_discrepancies(
expected: list[tuple[Any, ...]],
outputs: list[tuple[Any, ...]],
) -> tuple[float, float]:
"""
Computes the discrepancies.
Args:
expected: list of outputs coming from a torch model
outputs: list of outputs coming from an onnx model
Returns:
maximum absolute error, maximum relative error
"""

def _flatten(outputs):
flat = []
for tensor in outputs:
if isinstance(tensor, tuple):
flat.extend(_flatten(tensor))
else:
flat.append(tensor)
return tuple(flat)

abs_errs = []
rel_errs = []
for torch_outputs_mixed_types, onnx_outputs in zip(expected, outputs):
torch_outputs = _flatten(torch_outputs_mixed_types)
assert len(torch_outputs) == len(
onnx_outputs
), f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}"
for torch_tensor, onnx_tensor in zip(torch_outputs, onnx_outputs):
assert (
torch_tensor.dtype == onnx_tensor.dtype
), f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}"
assert (
torch_tensor.shape == onnx_tensor.shape
), f"Type mismatch {torch_tensor.shape} != {onnx_tensor.shape}"
diff = torch_tensor - onnx_tensor
abs_err = float(diff.abs().max())
rel_err = float((diff.abs() / torch_tensor).max())
abs_errs.append(abs_err)
rel_errs.append(rel_err)
return max(abs_errs), max(rel_errs)
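
An illustrative call with made-up tensors (not from the commit), showing how the per-tensor maxima are reported:

import torch

expected = [(torch.tensor([1.0, 2.0]),)]  # torch model outputs
got = [(torch.tensor([1.0, 2.1]),)]       # onnx model outputs
abs_err, rel_err = measure_discrepancies(expected, got)
print(abs_err, rel_err)  # ~0.1 and ~0.05 (0.1 / 2.0)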


def common_export(
model: Any,
inputs: Sequence[Any],
@@ -620,6 +667,7 @@ def run_onnx_inference(
repeat: int = 5,
verbose: int = 0,
ort_optimize: bool = True,
torch_model: Any | None = None,
) -> dict[str, Any]:
"""
Runs the same inference multiple times with onnxruntime.
@@ -631,6 +679,7 @@
repeat: number of iterations to repeat
verbose: verbosity
ort_optimize: enable or disable onnxruntime optimizations
torch_model: if provided, measure the discrepancies between its outputs and the onnx outputs
Returns:
statistics
@@ -667,16 +716,26 @@
print(f"[run_inference] created session in {end}")
print(f"[run_inference] start {warmup} warmup iterations")

if torch_model:
expected = [
torch_model(*example_inputs[i % len(example_inputs)]) for i in range(warmup)
]

got = []
iterations = []
begin = time.perf_counter()
for i in range(warmup):
t0 = time.perf_counter()
wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)])
got.append(wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)]))
iterations.append(time.perf_counter() - t0)
end = time.perf_counter() - begin
stats["warmup"] = warmup
stats["warmup_time"] = end / warmup
stats["warmup_iter"] = iterations
if torch_model:
abs_err, rel_err = measure_discrepancies(expected, got)
stats["discrepancies_abs"] = abs_err
stats["discrepancies_rel"] = rel_err

if verbose:
print(f"[run_inference] warmup done in {time.perf_counter() - begin}")
@@ -697,3 +756,28 @@
print(f"[run_inference] measure done in {time.perf_counter() - begin}")

return stats


def multi_run(kwargs: dict[str, Any]) -> bool:
"""Checks if multiple values were sent for one argument."""
return any(isinstance(v, str) and "," in v for v in kwargs.values())


def make_configs(kwargs: dict[str, Any]) -> list[dict[str, Any]]:
"""Creates all the configurations based on the command line arguments."""
print(kwargs)
args = []
for k, v in kwargs.items():
if isinstance(v, str):
args.append([(k, s) for s in v.split(",")])
else:
args.append([(k, v)])
configs = list(itertools.product(*args))
return [dict(c) for c in configs]
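
Illustrative fan-out (made-up arguments, not from the commit): comma-separated string values expand into the cross-product of configurations, and multi_run detects that a sweep was requested.

kwargs = {"model": "llama,phi", "batch": "1,2", "device": "cpu"}
assert multi_run(kwargs)
make_configs(kwargs)
# [{'model': 'llama', 'batch': '1', 'device': 'cpu'},
#  {'model': 'llama', 'batch': '2', 'device': 'cpu'},
#  {'model': 'phi', 'batch': '1', 'device': 'cpu'},
#  {'model': 'phi', 'batch': '2', 'device': 'cpu'}]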


def make_dataframe_from_benchmark_data(data: list[dict]) -> Any:
"""Creates a dataframe from the received data."""
import pandas

return pandas.DataFrame(data)