Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | test(torchlib) (#1180)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1178
* #1177
* #1176
* __->__ #1180

### Changes

- Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on
- Test with Python 3.11 as well
- Fixes #1061
- Fix aten::any.dims and aten::all.dims
justinchuby authored Nov 23, 2023
1 parent 77ef131 commit a89a2a9
Showing 6 changed files with 145 additions and 33 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yaml
@@ -31,6 +31,8 @@ jobs:
- py310-torch-nightly
- py310-onnx-weekly
- py310-ort-nightly
- py311-ort-nightly
- py310-experimental-torchlib-tracing
include:
- name: py310
python-version: "3.10"
@@ -50,6 +52,12 @@ jobs:
- name: py310-ort-nightly
python-version: "3.10"
nox-tag: test-ort-nightly
- name: py311-ort-nightly
python-version: "3.11"
nox-tag: test-ort-nightly
- name: py310-experimental-torchlib-tracing
python-version: "3.10"
nox-tag: test-experimental-torchlib-tracing
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
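The two new matrix entries follow the existing name/python-version/nox-tag pattern: py311-ort-nightly reuses the existing test-ort-nightly session on Python 3.11, while py310-experimental-torchlib-tracing points at the new test-experimental-torchlib-tracing session added to noxfile.py below.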
19 changes: 18 additions & 1 deletion noxfile.py
@@ -11,7 +11,7 @@

COMMON_TEST_DEPENDENCIES = (
"jinja2",
"numpy==1.23.5",
"numpy==1.24.4",
"typing_extensions",
"beartype!=0.16.0",
"types-PyYAML",
@@ -95,3 +95,20 @@ def test_ort_nightly(session):
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "onnxscript", *session.posargs)


@nox.session(tags=["test-experimental-torchlib-tracing"])
def test_experimental_torchlib_tracing(session):
"""Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run(
"pytest",
"onnxscript/tests/function_libs/torch_lib/ops_test.py",
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
)
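
To reproduce this job locally, the session can be invoked through its tag, e.g. `nox -t test-experimental-torchlib-tracing`, or the variable can be exported before running pytest directly. As a minimal sketch of how such a toggle is typically consumed (a hypothetical helper; the real torchlib flag handling may differ):

```python
import os

# Hypothetical reader for the toggle; treats "1", "true", or "yes"
# (case-insensitive) as enabled. Read once at import time.
def _read_bool_env(name: str) -> bool:
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes"}

EXPERIMENTAL_PREFER_TRACING = _read_bool_env("TORCHLIB_EXPERIMENTAL_PREFER_TRACING")
```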
17 changes: 10 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -365,7 +365,9 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
if not dim:
return aten_all_dims_no_dim(self, keepdim)
for d in dim:
self = aten_all_dim(self, d, keepdim)
self = aten_all_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self


@@ -488,7 +490,9 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
if not dim:
return aten_any_dims_no_dim(self, keepdim)
for d in dim:
self = aten_any_dim(self, d, keepdim)
self = aten_any_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self


@@ -7339,17 +7343,16 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
def aten_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
) -> TFloatOrBFloat16:
@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True)
def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.Softmax(self, axis=dim)
result = op.Cast(result, to=dtype)
if dtype != -1:
result = op.Cast(result, to=dtype)
if self_is_scalar:
# Convert to scalar when input is scalar
result = op.Squeeze(result)
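
The keepdim change above is the heart of the aten::all.dims/aten::any.dims fix: reducing one dim at a time with keepdim=False renumbers the remaining axes, so later entries in `dim` can point at the wrong (or an out-of-range) axis. A small PyTorch sketch of the failure mode and the fixed strategy (illustrative only, not the emitted ONNX):

```python
import torch

x = torch.rand(2, 3, 4) > 0.5

# Reference result: reduce the higher dim first so the lower index stays valid.
expected = x.any(dim=2).any(dim=0)

# Pre-fix pattern: passing keepdim=False through the loop renumbers the
# remaining axes, so the second iteration of `for d in (0, 2)` would index
# an axis that has already shifted on the rank-2 intermediate.

# Fixed pattern from the diff: keep every reduced dim during the loop,
# then squeeze them all at the end, so each index in `dim` stays valid.
kept = x
for d in (0, 2):
    kept = kept.any(dim=d, keepdim=True)  # rank stays 3 throughout
result = kept.squeeze(2).squeeze(0)       # analogous to op.Squeeze(self, list(dim))
assert torch.equal(result, expected)
```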
8 changes: 4 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/special.py
@@ -13,7 +13,6 @@

from typing import Optional, Sequence

from onnxscript import FLOAT
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
@@ -212,17 +211,18 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::log_softmax", "aten::special_log_softmax"))
@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True)
def aten_special_log_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = FLOAT.dtype
self: TFloatOrBFloat16, dim: int, dtype: int = -1
) -> TFloatOrBFloat16:
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.LogSoftmax(self, axis=dim)
result = op.Cast(result, to=dtype)
if dtype != -1:
result = op.Cast(result, to=dtype)
if self_is_scalar: # squeeze to scalar due to input is scalar
result = op.Squeeze(result)
return result
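
Both softmax rewrites share one pattern: the old signature defaulted dtype to FLOAT.dtype and cast unconditionally, which could force a float32 cast even when the caller passed no dtype. With trace_only=True the `if dtype != -1` branch runs in Python while the graph is built, so a Cast node is only emitted on request. A toy illustration of trace-time branching (a hypothetical builder, not onnxscript's tracing API):

```python
from onnx import TensorProto

NO_DTYPE = -1  # sentinel standing in for `ScalarType? dtype=None`

def trace_log_softmax(dim: int, dtype: int = NO_DTYPE) -> list:
    """Toy trace: returns the ops that would be recorded into the graph."""
    ops = [("LogSoftmax", {"axis": dim})]
    if dtype != NO_DTYPE:  # decided at trace time, never becomes a graph node
        ops.append(("Cast", {"to": dtype}))
    return ops

assert trace_log_softmax(dim=-1) == [("LogSoftmax", {"axis": -1})]
assert trace_log_softmax(dim=-1, dtype=TensorProto.FLOAT16) == [
    ("LogSoftmax", {"axis": -1}),
    ("Cast", {"to": TensorProto.FLOAT16}),
]
```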
4 changes: 4 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -161,6 +161,10 @@ def add_decorate_info(
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
if opinfo is None and not decorate_meta.enabled_if:
# If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
# because it could be an OpInfo that is in torch-nightly but not older versions.
continue
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
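
This guard appears to be what lets entries gated on newer torch, such as the arange `tensor_overload` xfail below (`enabled_if=not version_utils.torch_older_than("2.2")`), coexist with stable-torch runs where that OpInfo variant does not exist yet: a disabled entry whose OpInfo is missing is now skipped instead of tripping the assertion.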
122 changes: 101 additions & 21 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -471,21 +471,41 @@ def _where_input_wrangler(
),
TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
TorchLibOpInfo(
"ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, trace_only=True
).xfail(
"ops.aten._log_softmax_half",
core_ops.aten__log_softmax_half,
trace_only=True,
tolerance={torch.float16: (1e-3, 1e-3)},
)
.xfail(
reason="PyTorch does not implement _log_softmax for float16 on CPU",
dtypes=(torch.float16,),
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
),
TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
TorchLibOpInfo(
"ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
).xfail(
TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True)
.xfail(
reason="PyTorch does not implement _softmax for float16 on CPU",
dtypes=(torch.float16,),
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
),
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
matcher=lambda sample: not (len(sample.kwargs) > 0)
or isinstance(sample.kwargs.get("dim"), tuple),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
),
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
matcher=lambda sample: not (len(sample.kwargs) > 0),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
reason="this overload requires dim to be a tuple",
),
TorchLibOpInfo("allclose", core_ops.aten_allclose),
TorchLibOpInfo(
@@ -501,7 +521,11 @@ def _where_input_wrangler(
TorchLibOpInfo("acosh", core_ops.aten_acosh),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
TorchLibOpInfo(
"addbmm",
core_ops.aten_addbmm,
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
),
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
TorchLibOpInfo("addmm", core_ops.aten_addmm)
@@ -522,7 +546,7 @@ def _where_input_wrangler(
dtypes=(torch.int16, torch.int32, torch.int64),
reason="ONNX Runtime does not support int inputs to Gemm",
),
TorchLibOpInfo("addmv", core_ops.aten_addmv),
TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
TorchLibOpInfo(
"addr",
core_ops.aten_addr,
@@ -557,8 +581,13 @@ def _where_input_wrangler(
"any_dim",
core_ops.aten_any_dim,
).skip(
matcher=lambda sample: not (len(sample.kwargs) > 0),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
matcher=lambda sample: not (len(sample.kwargs) > 0)
or isinstance(sample.kwargs.get("dim"), tuple),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
),
TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
reason="this overload requires dim to be a tuple",
),
TorchLibOpInfo("asin", core_ops.aten_asin),
TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -640,7 +669,7 @@ def _where_input_wrangler(
"https://github.com/microsoft/onnxscript/issues/1007"
),
),
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
TorchLibOpInfo(
# This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +874,12 @@ def _where_input_wrangler(
dtypes=(torch.int64, torch.int32),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
)
.xfail(
variant_name="tensor_overload",
dtypes=(torch.int64, torch.int32, torch.float16),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
enabled_if=not version_utils.torch_older_than("2.2"),
)
.xfail(
dtypes=(torch.float16,),
reason="op 'Range' doesn't support float16.",
@@ -861,17 +896,35 @@ def _where_input_wrangler(
TorchLibOpInfo(
"log_softmax",
special_ops.aten_special_log_softmax,
trace_only=True,
tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
).xfail(
)
.xfail(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.xfail(
variant_name="with_dtype",
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: LogSoftMax does not support empty tensor as input",
)
.skip(
variant_name="with_dtype",
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: LogSoftMax does not support empty tensor as input",
),
TorchLibOpInfo("log2", core_ops.aten_log2),
TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
TorchLibOpInfo(
"logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
),
TorchLibOpInfo("logdet", core_ops.aten_logdet),
TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -884,7 +937,7 @@ def _where_input_wrangler(
"matmul",
core_ops.aten_matmul,
# Windows requires a more relaxed tolerance
tolerance={torch.float32: (2e-5, 2e-5)},
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)},
).skip(
matcher=lambda sample: torch.numel(sample.input) == 0,
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1339,12 +1392,28 @@ def _where_input_wrangler(
TorchLibOpInfo(
"softmax",
core_ops.aten_softmax,
trace_only=True,
tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
).xfail(
)
.xfail(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.xfail(
variant_name="with_dtype",
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: SoftMax does not support empty tensor as input",
)
.skip(
variant_name="with_dtype",
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: SoftMax does not support empty tensor as input",
),
TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
dtypes=(torch.float16,),
@@ -1700,7 +1769,12 @@ def _where_input_wrangler(
variant_name="empty_strides",
reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
),
TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
TorchLibOpInfo(
"native_batch_norm",
core_ops.aten_native_batch_norm,
trace_only=True,
tolerance={torch.float16: (9e-3, 7e-4)},
),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
),
@@ -1719,9 +1793,11 @@ def _where_input_wrangler(
"ops.aten.native_group_norm",
core_ops.aten_native_group_norm,
trace_only=True,
tolerance={torch.float16: (1e-2, 7e-3)},
).xfail(
dtypes=(torch.float16,),
reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly",
enabled_if=version_utils.torch_older_than("2.2"),
),
TorchLibOpInfo(
"native_layer_norm",
@@ -1809,7 +1885,11 @@ def _where_input_wrangler(
matcher=lambda sample: len(sample.args) != 1,
reason="this overload is implemented for bias=None",
),
TorchLibOpInfo("nn.functional.linear_bias", nn_ops.aten_linear_bias).skip(
TorchLibOpInfo(
"nn.functional.linear_bias",
nn_ops.aten_linear_bias,
tolerance={torch.float16: (2e-1, 4e-4)},
).skip(
# input: input, args: weight, bias; so len(args) == 2 means bias is provided
matcher=lambda sample: len(sample.args) != 2,
reason="this overload is implemented for bias!=None",
@@ -2059,8 +2139,8 @@ def _where_input_wrangler(
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
)

ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
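
For context on the final hunk: duplicate_opinfo clones a torch OpInfo under additional names so each TorchLib overload (here aten_all_dim vs. aten_all_dims, and likewise for any) gets its own test entry with independent skips and matchers. A simplified sketch of the idea (the real helper in ops_test_common also handles variants):

```python
import copy

# Simplified: clone the "all" OpInfo as "all_dim" and "all_dims" so
# aten_all_dim and aten_all_dims are tested and decorated independently.
def duplicate_opinfo(ops_db, name, new_names):
    source = next(info for info in ops_db if info.name == name)
    for new_name in new_names:
        clone = copy.deepcopy(source)
        clone.name = new_name
        ops_db.append(clone)
```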
