Implement the experimental evaluator for folding branches and castlikes | feat(torchlib) (#1178)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #1178


As part of the effort described in #1095, this PR:

- Implements the experimental evaluator for folding branches and
castlikes so that they are eagerly evaluated when possible.
- Updates the implementation of `addr` so that it is traceable.
- Conditionally enables previously xfailed tests.

Tested in CI with `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` set.
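For illustration, here is a minimal sketch of how the flag can be exercised locally. The import-time reading of the environment variable and the `torch.onnx.dynamo_export` entry point are assumptions for this sketch, not something this PR adds:

```python
import os

# Assumption: the flag is read when torchlib's _flags module is imported,
# so set it before importing torch/onnxscript.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"

import torch


class ClampMin(torch.nn.Module):
    def forward(self, x, other):
        return torch.clamp_min(x, other)


# Hypothetical export call; the dynamo-based exporter dispatches aten ops to
# torchlib, where the experimental evaluator can fold branches eagerly.
torch.onnx.dynamo_export(
    ClampMin(),
    torch.randint(0, 10, (5, 10, 5), dtype=torch.int32),
    torch.randint(0, 10, (10, 5), dtype=torch.int32),
)
```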

For example, `clamp_min` now becomes:

```
<
   ir_version: 8,
   opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9) 
   <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8>
{
   _val_2 = Size (input_0)
   _val_3 = Shape <start: int = 0> (input_1)
   _val_4 = Size (_val_3)
   _val_5 = Constant <value: tensor = int64 {0}> ()
   _val_6 = Equal (_val_2, _val_5)
   _val_7 = Constant <value: tensor = int64 {0}> ()
   _val_8 = Equal (_val_4, _val_7)
   _val_9 = Max (input_0, input_1)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```
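To double-check the folding on an exported model, one can assert that no `If` nodes remain (hypothetical snippet; `model.onnx` is a placeholder for whatever the exporter produced):

```python
import onnx

# With the branches folded at trace time, the main graph contains only plain
# data-flow nodes; the Equal/Size nodes that computed the conditions stay
# behind as unused values, but no If node is emitted.
model = onnx.load("model.onnx")
op_types = [node.op_type for node in model.graph.node]
assert "If" not in op_types
print(op_types)  # e.g. ['Size', 'Shape', 'Size', 'Constant', 'Equal', 'Constant', 'Equal', 'Max']
```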
justinchuby authored Nov 29, 2023
1 parent 744cabd commit 5ba7efa
Showing 3 changed files with 63 additions and 1 deletion.
57 changes: 57 additions & 0 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -294,6 +294,29 @@ def graph(self) -> TorchScriptGraph:
        return self._graph

    def eval(self, schema, inputs, attributes):
        if _flags.EXPERIMENTAL_PREFER_TRACING:
            if schema.name == "CastLike":
                assert len(inputs) == 2
                # Skip CastLike if the input and output types are the same
                src_input = inputs[0]
                target_input = inputs[1]
                dtypes_available = (
                    isinstance(src_input, TorchScriptTensor)
                    and isinstance(target_input, TorchScriptTensor)
                    and src_input.dtype is not None
                    and target_input.dtype is not None
                )
                if dtypes_available:
                    if src_input.dtype == target_input.dtype:
                        # Same type. No cast needed
                        return src_input
                    else:
                        # Create a Cast node
                        return self._graph.add_op_call(
                            onnx.defs.get_schema("Cast"),
                            (src_input,),
                            {"to": target_input.onnx_dtype},
                        )
        return self._graph.add_op_call(schema, inputs, attributes)

    @runtime_typing.checked
@@ -303,6 +326,40 @@ def eval_function( # type: ignore[override]
        args: Sequence[ValidArgumentType],
        kwargs: Mapping[str, ValidArgumentType],
    ):
        if _flags.EXPERIMENTAL_PREFER_TRACING:
            # Special cases for handling IsScalar and Rank
            if function.name == "IsScalar":
                if len(args) != 1:
                    raise TypeError(
                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
                    )
                if isinstance(args[0], TorchScriptTensor):
                    if args[0].rank is not None:
                        return args[0].rank == 0
                    else:
                        # Fall to call add_function_call
                        pass
                else:
                    # Python constants are scalars
                    return True
            if function.name == "Rank":
                if len(args) != 1:
                    raise TypeError(
                        f"Expected 1 positional argument for function '{function}', got {len(args)}."
                    )
                if isinstance(args[0], TorchScriptTensor):
                    if args[0].rank is not None:
                        return args[0].rank
                    else:
                        # Fall to call add_function_call
                        pass
                else:
                    # Python constants are scalars
                    return 0
            elif function.experimental_traceable:
                # Trace the function call instead of adding the function as a node
                return function.function(*args, **kwargs)

        # args/kwargs are TorchScriptTensor/python built-in based
        param_schemas = function.param_schemas()
        (
2 changes: 2 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -268,9 +268,11 @@ def aten_addr(
    # https://github.com/pytorch/pytorch/blob/51664489ba6f6b2343bbec9af9ca99185e2a5dbc/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp#L53-L54
    # When beta == 0, values in self should be ignored,
    # nans and infs in self should not propagate.
    alpha = op.CastLike(alpha, outer)
    if beta == 0.0:
        result = op.Mul(alpha, outer)
    else:
        beta = op.CastLike(beta, outer)
        result = op.Add(op.Mul(beta, self), op.Mul(alpha, outer))

    return result
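For reference, `aten::addr` computes `beta * self + alpha * outer(vec1, vec2)`; because `beta` is compared against `0.0` in Python above, the branch is resolved while tracing and only one of the two formulas is emitted. A small, hypothetical sanity check of those semantics against eager PyTorch (not part of this PR):

```python
import torch

# addr semantics: beta * self + alpha * outer(vec1, vec2)
self_ = torch.randn(3, 4)
vec1 = torch.randn(3)
vec2 = torch.randn(4)
alpha, beta = 2.0, 0.5

expected = beta * self_ + alpha * torch.outer(vec1, vec2)
actual = torch.addr(self_, vec1, vec2, beta=beta, alpha=alpha)
torch.testing.assert_close(actual, expected)
```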
5 changes: 4 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -46,6 +46,7 @@
from typing_extensions import Self

from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib import _flags
from onnxscript.function_libs.torch_lib.ops import core as core_ops
from onnxscript.function_libs.torch_lib.ops import fft as fft_ops
from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
@@ -556,7 +557,7 @@ def _where_input_wrangler(
    TorchLibOpInfo(
        "addr",
        core_ops.aten_addr,
-       tolerance={torch.float16: (1e-3, 3e-3)},
+       tolerance={torch.float16: (3e-3, 4e-3)},
    ),
    TorchLibOpInfo(
        "amax",
@@ -990,6 +991,7 @@ def _where_input_wrangler(
        variant_name="reduction_with_dim",
        reason="fixme: ORT Graph attribute inferencing failed https://github.com/onnx/onnx/issues/4986",
        test_class_name="TestOutputConsistencyFullGraph",
        enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING,
    )
    .xfail(
        matcher=lambda sample: len(sample.args) == 0
@@ -1758,6 +1760,7 @@ def _where_input_wrangler(
        variant_name="reduction_with_dim",
        reason="fixme: ORT Graph attribute inferencing failed https://github.com/onnx/onnx/issues/4986",
        test_class_name="TestOutputConsistencyFullGraph",
        enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING,
    )
    .xfail(
        matcher=lambda sample: len(sample.args) == 0
