Commit: Merge branch 'main' into rama/multi-pattern-input-reuse

gramalingam authored Jul 16, 2024
2 parents 8f5878e + 581e998 commit 268dc3b
Showing 19 changed files with 514 additions and 177 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/main.yaml
@@ -31,7 +31,6 @@ jobs:
- py311-onnx-weekly
- py311-ort-nightly
- py311-experimental-torchlib-tracing
- py311-experimental-torchlib-onnx-ir
- py310
- py39
include:
@@ -59,9 +58,6 @@ jobs:
- name: py311-experimental-torchlib-tracing
python-version: "3.11"
nox-tag: test-experimental-torchlib-tracing
- name: py311-experimental-torchlib-onnx-ir
python-version: "3.11"
nox-tag: test-experimental-torchlib-onnx-ir
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
21 changes: 0 additions & 21 deletions noxfile.py
@@ -134,27 +134,6 @@ def test_experimental_torchlib_tracing(session):
)


@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
def test_experimental_torchlib_onnx_ir(session):
"""Test TorchLib using the ONNX IR to build graphs."""
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run(
"pytest",
"tests/function_libs/torch_lib/ops_test.py",
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
)


@nox.session(tags=["test-dort"])
def test_dort(session):
"""Test the conversion of a couple of models from transformers."""
1 change: 1 addition & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
@@ -54,4 +54,5 @@ def _load_boolean_flag(
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_USE_IR",
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
deprecated=True,
)
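# A simplified sketch (an assumption, not the actual _load_boolean_flag helper) of how such a
# boolean flag is typically read from the environment; the removed noxfile session above enabled
# it by setting TORCHLIB_EXPERIMENTAL_USE_IR=1.
import os

def _read_boolean_flag(name: str) -> bool:
    # Treat "1" as enabled; anything else (including unset) as disabled.
    return os.getenv(name) == "1"

use_ir = _read_boolean_flag("TORCHLIB_EXPERIMENTAL_USE_IR")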
69 changes: 46 additions & 23 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4652,7 +4652,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5986,16 +5986,12 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
raise NotImplementedError()


@torch_op("aten::native_dropout")
@torch_op("aten::native_dropout", trace_only=True)
def aten_native_dropout(
input: TFloatOrBFloat16, p: float, train: bool = True
) -> Tuple[TFloatOrBFloat16, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

# Python bool attributes need to be explicitly converted to BOOL
# because the underlying attribute type is int
# TODO(#872): Allow ONNX Script to handle this conversion
train = op.Cast(train, to=BOOL.dtype)
result, mask = op.Dropout(input, p, train)
return result, mask
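# For reference: ONNX Dropout in training mode returns both the dropped-out tensor and a
# boolean keep-mask, scaling kept elements by 1/(1-p). A NumPy sketch of that contract
# (hypothetical values, independent of the op.Dropout call above):
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3)).astype(np.float32)
p = 0.5
mask = rng.random(x.shape) >= p            # True where an element is kept
out = np.where(mask, x / (1.0 - p), 0.0)   # kept values are rescaled by 1/(1-p)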

@@ -6393,25 +6389,18 @@ def aten_ones_like(
device: str = "",
pin_memory: bool = False,
) -> TTensor:
"""ones_like.
"""ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
before calling this function.
"""
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.
if dtype is None:
dtype = -1

if dtype == -1:
one = op.CastLike(1, self)
else:
one = op.Cast(1, to=dtype)
return _aten_ones_like_onnx(self, one)


@torch_op("aten::ones_like", private=True)
def _aten_ones_like_onnx(self: TTensor, one) -> TTensor:
shape = op.Shape(self)
return op.Expand(one, shape)
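# The pattern here builds ones_like by expanding a scalar one (CastLike to the input's dtype,
# or Cast to an explicit dtype) over Shape(self). A NumPy sketch of the same idea with a
# hypothetical input:
import numpy as np

x = np.zeros((2, 3), dtype=np.float16)    # hypothetical input
one = np.ones((), dtype=x.dtype)          # CastLike(1, self)
ones = np.broadcast_to(one, x.shape)      # Expand(one, Shape(self))
assert ones.shape == x.shape and ones.dtype == x.dtype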

@@ -6562,17 +6551,33 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow"))
@torch_op(
(
"aten::pow.Scalar",
"aten::pow.Tensor_Tensor",
"aten::pow.Tensor_Scalar",
"_operator::pow",
)
)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

return op.Pow(self, exponent)


def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True)
def aten_prelu(self: TReal, weight: TReal) -> TReal:
"""prelu(Tensor self, Tensor weight) -> Tensor"""

raise NotImplementedError()
zero = op.CastLike(0, self)
rank = len(self.shape)
if rank == 0:
# e.g. self: [], weight: [1]
weight = op.Squeeze(weight)
elif rank >= 2:
# e.g. self: [5,10,5], weight: [10]
weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2))
return op.Add(op.Max(self, zero), op.Mul(weight, op.Min(self, zero)))
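# A NumPy check of the decomposition above: PReLU(x) = max(x, 0) + weight * min(x, 0),
# with the weight reshaped so it broadcasts along the channel axis (hypothetical shapes):
import numpy as np

x = np.random.randn(5, 10, 5).astype(np.float32)   # e.g. self: [5, 10, 5]
w = np.random.randn(10).astype(np.float32)          # e.g. weight: [10]
w_b = w.reshape([1, -1] + [1] * (x.ndim - 2))       # Reshape(weight, [1, -1, 1])
prelu = np.maximum(x, 0) + w_b * np.minimum(x, 0)
assert np.allclose(prelu, np.where(x >= 0, x, w_b * x))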


def aten_prelu_backward(
@@ -6583,10 +6588,12 @@ def aten_prelu_backward(
raise NotImplementedError()


def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op(("aten::prod.dim_int"), trace_only=True)
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
# Todo: add test for this function later
return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
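# ReduceProd over a single axis, mimicked with NumPy (hypothetical values):
import numpy as np

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
np.prod(x, axis=1, keepdims=False)   # -> [6., 120.], matching ReduceProd(x, axes=[1], keepdims=0)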


def aten_promote_types(type1: int, type2: int) -> int:
@@ -7369,6 +7376,19 @@ def aten_scalar_tensor_sym_number(
return common_ops.cast_to(s, dtype=dtype)


@torch_op("aten::scatter.value", trace_only=True)
def aten_scatter(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
src: TReal,
) -> TReal:
"""scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""

update = op.Expand(src, op.Shape(index))
return op.ScatterElements(self, index, update, axis=dim)
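# What the Expand + ScatterElements pair computes for aten::scatter.value, mimicked with
# NumPy (hypothetical values; with dim=1, out[i, index[i, j]] = update[i, j]):
import numpy as np

data = np.zeros((3, 3), dtype=np.float32)
index = np.array([[0, 2]], dtype=np.int64)
src = 5.0                                        # scalar value being scattered
update = np.broadcast_to(src, index.shape)       # Expand(src, Shape(index))
out = data.copy()
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        out[i, index[i, j]] = update[i, j]
# out[0] is now [5., 0., 5.]; all other rows are untouched.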


@torch_op("aten::scatter_add")
def aten_scatter_add(
self: TReal,
@@ -8377,10 +8397,11 @@ def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Where(is_negative, op.Neg(integer_parts), integer_parts)


def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::type_as", traceable=True)
def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
"""type_as(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()
return op.CastLike(self, other)


@torch_op("aten::unbind.int")
@@ -8861,6 +8882,8 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.
if dtype is None:
dtype = -1

if dtype == -1:
zero = op.CastLike(0, self)
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -50,7 +50,7 @@ def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> Ten
raise NotImplementedError()


@torch_op(("aten::linalg_det", "aten::det"))
@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det"))
def aten_linalg_det(A: TFloat) -> TFloat:
"""linalg_det(Tensor A) -> Tensor"""

10 changes: 7 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
return op.Clip(self, min_val, max_val)


@torch_op("aten::hardtanh_backward", trace_only=True)
def aten_hardtanh_backward(
grad_output: TensorType, self: TensorType, min_val: float, max_val: float
) -> TensorType:
"""hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""

raise NotImplementedError()
max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
return op.Mul(op.Mul(grad_output, max_mask), min_mask)
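# The masking above zeroes the gradient wherever the forward input fell outside
# [min_val, max_val]. A NumPy check with hypothetical values and the PyTorch defaults:
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
grad_output = np.ones_like(x)
min_val, max_val = -1.0, 1.0
max_mask = np.where(x > max_val, 0.0, 1.0)
min_mask = np.where(x < min_val, 0.0, 1.0)
grad_input = grad_output * max_mask * min_mask   # -> [0., 1., 1., 1., 0.]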


def aten_huber_loss(
@@ -2046,10 +2049,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor
raise NotImplementedError()


def aten_silu(self: TensorType) -> TensorType:
@torch_op("aten::silu", traceable=True)
def aten_silu(self: TFloat) -> TFloat:
"""silu(Tensor self) -> Tensor"""

raise NotImplementedError()
return op.Mul(self, op.Sigmoid(self))
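# SiLU is x * sigmoid(x); a quick NumPy check of the one-liner above (hypothetical values):
import numpy as np

x = np.array([-1.0, 0.0, 2.0], dtype=np.float32)
silu = x * (1.0 / (1.0 + np.exp(-x)))   # matches Mul(self, Sigmoid(self))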


def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType:
26 changes: 26 additions & 0 deletions onnxscript/ir/_convenience.py
@@ -369,3 +369,29 @@ def tensor(
doc_string=name,
)
return tensor_


def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
"""Return a dictionary mapping names to values in the graph.
The mapping does not include values from subgraphs.
Args:
graph: The graph to extract the mapping from.
Returns:
A dictionary mapping names to values.
"""
values = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
if not input.name:
continue
values[input.name] = input
for node in graph:
for value in node.outputs:
if not value.name:
continue
values[value.name] = value
return values
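# One way the new helper might be used (a sketch; the model path and the use of
# ir.serde.deserialize_model as the ModelProto-to-IR entry point are assumptions):
import onnx
from onnxscript import ir
from onnxscript.ir import _convenience

model_proto = onnx.load("model.onnx")               # hypothetical model file
model = ir.serde.deserialize_model(model_proto)     # ModelProto -> IR
values = _convenience.create_value_mapping(model.graph)
# Initializers, graph inputs, and node outputs are all reachable by name; entries added
# later (node outputs) overwrite earlier ones when names collide.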
2 changes: 1 addition & 1 deletion onnxscript/ir/_protocols.py
@@ -504,7 +504,7 @@ class TypeProtocol(Protocol):
elem_type: TypeProtocol | _enums.DataType
dtype: _enums.DataType

def __eq__(self, __value: object) -> bool: ...
def __eq__(self, value: object, /) -> bool: ...


@typing.runtime_checkable
6 changes: 6 additions & 0 deletions onnxscript/tools/benchmark/__init__.py
@@ -5,13 +5,19 @@
from onnxscript.tools.benchmark.benchmark_helpers import (
common_export,
get_parsed_args,
make_configs,
make_dataframe_from_benchmark_data,
multi_run,
run_inference,
run_onnx_inference,
)

__all__ = [
"get_parsed_args",
"common_export",
"make_configs",
"multi_run",
"make_dataframe_from_benchmark_data",
"run_inference",
"run_onnx_inference",
]