Merge branch 'main' into custom
gramalingam authored Jul 15, 2024
2 parents ee30f74 + c06e7ab commit 3b4d700
Showing 10 changed files with 53 additions and 44 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/main.yaml
@@ -31,7 +31,6 @@ jobs:
           - py311-onnx-weekly
           - py311-ort-nightly
           - py311-experimental-torchlib-tracing
-          - py311-experimental-torchlib-onnx-ir
           - py310
           - py39
         include:
@@ -59,9 +58,6 @@ jobs:
           - name: py311-experimental-torchlib-tracing
             python-version: "3.11"
             nox-tag: test-experimental-torchlib-tracing
-          - name: py311-experimental-torchlib-onnx-ir
-            python-version: "3.11"
-            nox-tag: test-experimental-torchlib-onnx-ir
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
21 changes: 0 additions & 21 deletions noxfile.py
@@ -134,27 +134,6 @@ def test_experimental_torchlib_tracing(session):
     )


-@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
-def test_experimental_torchlib_onnx_ir(session):
-    """Test TorchLib using the ONNX IR to build graphs."""
-    session.install(
-        *COMMON_TEST_DEPENDENCIES,
-        PYTORCH,
-        TORCHVISON,
-        ONNX,
-        *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
-    )
-    session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
-    session.install(".", "--no-deps")
-    session.run("pip", "list")
-    session.run(
-        "pytest",
-        "tests/function_libs/torch_lib/ops_test.py",
-        *session.posargs,
-        env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
-    )
-
-
 @nox.session(tags=["test-dort"])
 def test_dort(session):
     """Test the conversion of a couple of models from transformers."""
1 change: 1 addition & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
@@ -54,4 +54,5 @@ def _load_boolean_flag(
 EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
     "TORCHLIB_EXPERIMENTAL_USE_IR",
     this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
+    deprecated=True,
 )
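For context, this change marks the `TORCHLIB_EXPERIMENTAL_USE_IR` environment flag as deprecated. A minimal sketch of what a boolean env-flag loader with a `deprecated` switch could look like (an assumption for illustration; the real `_load_boolean_flag` in `_flags.py` may behave differently):

    import os
    import warnings

    def load_boolean_flag(name: str, *, this_will: str, deprecated: bool = False) -> bool:
        # Treat "1" as on; unset or anything else as off.
        state = os.getenv(name) == "1"
        if state and deprecated:
            warnings.warn(f"Flag {name} is deprecated", DeprecationWarning, stacklevel=2)
        elif state:
            print(f"Setting {name}: TorchLib will {this_will}.")
        return state

    USE_IR = load_boolean_flag(
        "TORCHLIB_EXPERIMENTAL_USE_IR",
        this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
        deprecated=True,
    )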
54 changes: 41 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4652,7 +4652,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
+@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
 def aten_le(self: TReal, other: TReal) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5986,16 +5986,12 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::native_dropout")
+@torch_op("aten::native_dropout", trace_only=True)
 def aten_native_dropout(
     input: TFloatOrBFloat16, p: float, train: bool = True
 ) -> Tuple[TFloatOrBFloat16, BOOL]:
     """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

-    # Python bool attributes need to be explicitly converted to BOOL
-    # because the underlying attribute type is int
-    # TODO(#872): Allow ONNX Script to handle this conversion
-    train = op.Cast(train, to=BOOL.dtype)
     result, mask = op.Dropout(input, p, train)
     return result, mask
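Since the function is now `trace_only`, it runs as ordinary Python during export, so the Python bool `train` can be passed to `op.Dropout` directly instead of first being Cast to a BOOL tensor. For intuition, a NumPy sketch of the Dropout semantics being relied on (illustrative reference code, not TorchLib's implementation):

    import numpy as np

    def dropout_reference(x, p, train, seed=0):
        # ONNX Dropout in training mode: zero each element with probability p
        # and scale the survivors by 1 / (1 - p); also return the keep-mask.
        if not train or p == 0.0:
            return x, np.ones_like(x, dtype=bool)
        rng = np.random.default_rng(seed)
        mask = rng.random(x.shape) >= p
        return x * mask / (1.0 - p), mask

    out, mask = dropout_reference(np.ones((2, 3), dtype=np.float32), p=0.5, train=True)
    assert np.allclose(out[mask], 2.0)  # kept elements are scaled by 1 / (1 - p)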

@@ -6555,17 +6551,33 @@ def aten_positive(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow"))
+@torch_op(
+    (
+        "aten::pow.Scalar",
+        "aten::pow.Tensor_Tensor",
+        "aten::pow.Tensor_Scalar",
+        "_operator::pow",
+    )
+)
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""

     return op.Pow(self, exponent)


-def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
+@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True)
+def aten_prelu(self: TReal, weight: TReal) -> TReal:
     """prelu(Tensor self, Tensor weight) -> Tensor"""

-    raise NotImplementedError()
+    zero = op.CastLike(0, self)
+    rank = len(self.shape)
+    if rank == 0:
+        # e.g. self: [], weight: [1]
+        weight = op.Squeeze(weight)
+    elif rank >= 2:
+        # e.g. self: [5,10,5], weight: [10]
+        weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2))
+    return op.Add(op.Max(self, zero), op.Mul(weight, op.Min(self, zero)))
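The new body decomposes PReLU as max(x, 0) + weight * min(x, 0), reshaping `weight` so it broadcasts along dimension 1 (the channel dimension). A NumPy sanity check of that decomposition (illustrative only):

    import numpy as np

    def prelu_reference(x, weight):
        # PReLU(x) = max(x, 0) + weight * min(x, 0), with weight broadcast
        # along dim 1, mirroring the Reshape in aten_prelu above.
        if x.ndim == 0:
            weight = weight.squeeze()
        elif x.ndim >= 2:
            weight = weight.reshape([1, -1] + [1] * (x.ndim - 2))
        return np.maximum(x, 0) + weight * np.minimum(x, 0)

    x = np.array([[-2.0, 3.0], [4.0, -5.0]])   # shape [2, 2]
    w = np.array([0.25, 0.5])                  # one slope per channel
    print(prelu_reference(x, w))               # [[-0.5  3. ] [ 4.  -2.5]]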


 def aten_prelu_backward(
@@ -6576,10 +6588,12 @@ def aten_prelu_backward(
     raise NotImplementedError()


-def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
     """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

-    raise NotImplementedError()
+    # TODO: add a test for this function later
+    return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
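For reference, `op.ReduceProd(self, axes=[dim], keepdims=keepdim)` computes the same reduction as `numpy.prod` over a single axis:

    import numpy as np

    x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    print(np.prod(x, axis=1, keepdims=False))  # [  6. 120.]
    print(np.prod(x, axis=1, keepdims=True))   # [[  6.] [120.]]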


 def aten_promote_types(type1: int, type2: int) -> int:
@@ -7362,6 +7376,19 @@ def aten_scalar_tensor_sym_number(
     return common_ops.cast_to(s, dtype=dtype)


+@torch_op("aten::scatter.value", trace_only=True)
+def aten_scatter(
+    self: TReal,
+    dim: int,  # dim is an attribute (ScatterElements' axis), so it must be a Python int
+    index: TInt,
+    src: TReal,
+) -> TReal:
+    """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
+
+    update = op.Expand(src, op.Shape(index))
+    return op.ScatterElements(self, index, update, axis=dim)
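`op.Expand` broadcasts `src` to the shape of `index` (which is what makes the `scatter.value` scalar overload work), and `op.ScatterElements` then writes each update to the slot that `index` names along `dim`. A small NumPy model of 2-D ScatterElements (illustrative only):

    import numpy as np

    def scatter_reference(data, dim, index, src):
        # For dim=0: out[index[i, j], j] = src[i, j]; for dim=1 the roles flip.
        out = data.copy()
        for i in range(index.shape[0]):
            for j in range(index.shape[1]):
                if dim == 0:
                    out[index[i, j], j] = src[i, j]
                else:
                    out[i, index[i, j]] = src[i, j]
        return out

    data = np.zeros((3, 3))
    index = np.array([[0, 1, 2], [2, 0, 1]])
    print(scatter_reference(data, 0, index, np.ones((2, 3))))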


 @torch_op("aten::scatter_add")
 def aten_scatter_add(
     self: TReal,
@@ -8370,10 +8397,11 @@ def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Where(is_negative, op.Neg(integer_parts), integer_parts)


-def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::type_as", traceable=True)
+def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
     """type_as(Tensor self, Tensor other) -> Tensor"""

-    raise NotImplementedError()
+    return op.CastLike(self, other)
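`op.CastLike(self, other)` casts `self` to the dtype of `other`, the ONNX counterpart of PyTorch's `x.type_as(y)`; in NumPy terms:

    import numpy as np

    x = np.array([1, 2, 3], dtype=np.int64)
    y = np.array([0.5], dtype=np.float32)
    print(x.astype(y.dtype))  # [1. 2. 3.], now float32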


 @torch_op("aten::unbind.int")
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -50,7 +50,7 @@ def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> Ten
     raise NotImplementedError()


-@torch_op(("aten::linalg_det", "aten::det"))
+@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det"))
 def aten_linalg_det(A: TFloat) -> TFloat:
     """linalg_det(Tensor A) -> Tensor"""

5 changes: 4 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
@@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
     return op.Clip(self, min_val, max_val)


+@torch_op("aten::hardtanh_backward", trace_only=True)
 def aten_hardtanh_backward(
     grad_output: TensorType, self: TensorType, min_val: float, max_val: float
 ) -> TensorType:
     """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""

-    raise NotImplementedError()
+    max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
+    min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
+    return op.Mul(op.Mul(grad_output, max_mask), min_mask)
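The backward rule lets the gradient through only where the input lies inside `[min_val, max_val]`; outside that range hardtanh is flat, so the gradient is zeroed. A NumPy check of the masking logic (illustrative only):

    import numpy as np

    def hardtanh_backward_reference(grad_output, x, min_val, max_val):
        # Zero the gradient wherever the forward pass clipped the input.
        max_mask = np.where(x > max_val, 0.0, 1.0)
        min_mask = np.where(x < min_val, 0.0, 1.0)
        return grad_output * max_mask * min_mask

    x = np.array([-2.0, -0.5, 0.5, 2.0])
    print(hardtanh_backward_reference(np.ones_like(x), x, -1.0, 1.0))  # [0. 1. 1. 0.]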


 def aten_huber_loss(
2 changes: 1 addition & 1 deletion onnxscript/ir/_protocols.py
@@ -504,7 +504,7 @@ class TypeProtocol(Protocol):
     elem_type: TypeProtocol | _enums.DataType
     dtype: _enums.DataType

-    def __eq__(self, __value: object) -> bool: ...
+    def __eq__(self, value: object, /) -> bool: ...
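The new signature uses the PEP 570 `/` marker to declare `value` positional-only, matching how `object.__eq__` is invoked and replacing the legacy double-underscore naming convention. A short illustration:

    from typing import Protocol

    class SupportsEq(Protocol):
        # `value` is positional-only, so implementers may name the parameter anything.
        def __eq__(self, value: object, /) -> bool: ...

    def same(a: SupportsEq, b: object) -> bool:
        return a == b  # __eq__ receives its argument positionally

    print(same(1, 1.0))  # True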


 @typing.runtime_checkable
4 changes: 2 additions & 2 deletions requirements/lintrunner/requirements.txt
@@ -1,9 +1,9 @@
 # This file is auto updated by dependabot
 lintrunner-adapters>=0.8.0
 # RUFF, RUFF-FIX
-ruff==0.4.7
+ruff==0.5.1
 # MYPY
-mypy==1.10.0
+mypy==1.10.1
 types-PyYAML==6.0.12.11
 # PYLINT
 pylint==2.17.6
3 changes: 2 additions & 1 deletion tests/function_libs/torch_lib/ops_test_common.py
@@ -34,6 +34,7 @@

 import onnxscript
 import onnxscript.evaluator
+from onnxscript import ir
 from onnxscript.function_libs.torch_lib import graph_building
 from tests.function_libs.torch_lib import error_reproduction

@@ -538,7 +539,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
         onnx.checker.check_model(onnx_model, full_check=True)
     except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
         raise AssertionError(
-            f"ONNX model is invalid. Model:\n{onnx.printer.to_text(onnx_model)}"
+            f"ONNX model is invalid. Model:\n{ir.serde.deserialize_model(onnx_model)}"
         ) from e

     try:
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1311,6 +1311,7 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
+    TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
     TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
     TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
     TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),
