Mark more functions as traceable | feat(torchlib) (#1300)
- Mark functions like `aten::t` as traceable to eliminate more conditional
branches and casts during tracing. The experimental path is enabled with
`TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` (see the sketch below).
- Fix handling of Sequence inputs in IsScalar during trace mode.
- Fix tests
justinchuby authored Mar 14, 2024
1 parent 0de0d30 commit d681dbd
Showing 7 changed files with 58 additions and 28 deletions.
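For context, a minimal, self-contained sketch of the "traceable" pattern the diffs below apply. The `torch_op` stand-in, the `_TRACEABLE_OPS` registry, and the `aten_t` body here are toy illustrations rather than the onnxscript implementation, and the environment variable (named in the commit message) is assumed to be read when torch_lib is imported.

```python
# Toy sketch of marking a torch_lib op as traceable (assumptions noted inline).
import os

# Experimental flag from the commit message; assumed to be read at import time,
# so it must be set before onnxscript.function_libs.torch_lib is imported.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"

_TRACEABLE_OPS: set[str] = set()


def torch_op(name, *, traceable: bool = False, **_kwargs):
    """Toy stand-in for torch_lib's @torch_op decorator; the real registry
    records far more metadata (overloads, private/complex flags, etc.)."""

    def wrap(fn):
        if traceable:
            # When tracing is preferred, the tracer may execute this body
            # eagerly and record its ops inline instead of emitting an ONNX
            # function call, letting rank/scalar branches fold away.
            _TRACEABLE_OPS.add(name)
        return fn

    return wrap


@torch_op("aten::t", traceable=True)
def aten_t(x):
    """Illustration only: the real body is elided in the diff; shown as identity."""
    return x


print("aten::t traceable:", "aten::t" in _TRACEABLE_OPS)  # -> True
```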
4 changes: 4 additions & 0 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -349,6 +349,8 @@ def eval_function( # type: ignore[override]
else:
# Fall to call add_function_call
pass
elif isinstance(args[0], Sequence):
return False
else:
# Python constants are scalars
return True
@@ -363,6 +365,8 @@ def eval_function( # type: ignore[override]
else:
# Fall to call add_function_call
pass
elif isinstance(args[0], Sequence):
return False
else:
# Python constants are scalars
return 0
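A hedged, standalone sketch of the special-casing these hunks adjust (not the library code): when `IsScalar` is evaluated during tracing, a plain Python `Sequence` argument can now be answered eagerly as "not a scalar" instead of falling through to a symbolic function call. `None` below stands in for the "fall to `add_function_call`" path.

```python
# Illustration of the trace-time IsScalar decision mirrored from the hunk above.
from collections.abc import Sequence
from typing import Any, Optional


def is_scalar_traced(value: Any) -> Optional[bool]:
    rank = getattr(value, "rank", None)  # TorchScriptTensor-like inputs expose .rank
    if rank is not None:
        return rank == 0                 # known-rank tensor: scalar iff rank 0
    if isinstance(value, Sequence):
        return False                     # new branch: sequence inputs are not scalars
    if isinstance(value, (bool, int, float, complex)):
        return True                      # Python constants are scalars
    return None                          # unknown: fall back to emitting a graph call


print(is_scalar_traced(3.0))        # True  - a Python constant
print(is_scalar_traced([1, 2, 3]))  # False - a Sequence input
```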
21 changes: 12 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -721,7 +721,7 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals
return result


@torch_op("aten::argmin")
@torch_op("aten::argmin", traceable=True)
def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

@@ -734,7 +734,7 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
return result


@torch_op("aten::argmin")
@torch_op("aten::argmin", traceable=True)
def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

@@ -3238,7 +3238,7 @@ def aten_exp(self: TFloat) -> TFloat:
return op.Exp(self)


@torch_op("aten::exp2")
@torch_op("aten::exp2", traceable=True)
def aten_exp2(self: TFloat) -> TFloat:
"""exp2(Tensor self) -> Tensor"""

@@ -3257,7 +3257,7 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
return op.Expand(self, size)


@torch_op("aten::expand_as")
@torch_op("aten::expand_as", traceable=True)
def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
"""expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""

@@ -4906,7 +4906,10 @@ def aten_margin_ranking_loss(
raise NotImplementedError()


@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"))
@torch_op(
("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"),
traceable=True,
)
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
"""masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
# NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
@@ -5037,7 +5040,7 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
else:
if IsScalar(dim):
dim = op.Unsqueeze(dim, axes=0)
result = op.ReduceMean(self, axes=dim, keepdims=keepdim)
result = op.ReduceMean(self, dim, keepdims=keepdim)
return result


@@ -7482,7 +7485,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
return result


@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), traceable=True)
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

@@ -7887,7 +7890,7 @@ def aten_symeig(
raise NotImplementedError()


@torch_op("aten::t")
@torch_op("aten::t", traceable=True)
def aten_t(self: TTensor) -> TTensor:
"""t(Tensor(a) self) -> Tensor(a)"""

@@ -8063,7 +8066,7 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::topk")
@torch_op("aten::topk", traceable=True)
def aten_topk(
self: TReal, k: INT64, dim: int = -1, largest: bool = True, sorted: bool = True
) -> Tuple[TReal, INT64]:
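An aside on the `aten_mean_dim` hunk above, where `op.ReduceMean(self, axes=dim, ...)` becomes `op.ReduceMean(self, dim, ...)`: passing the dynamic `dim` tensor as an input rather than an attribute presumably matches the opset-18 `ReduceMean` signature (an assumption about the motivation; the numerical result is unchanged). A numpy sketch of the traced call:

```python
# Numpy stand-in for ReduceMean(self, dim, keepdims=keepdim) with a rank-1 dim.
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
dim = np.array([1], dtype=np.int64)          # already unsqueezed to rank 1 above

result = x.mean(axis=tuple(dim), keepdims=True)
print(result.shape)                          # (3, 1), matching keepdim=True
```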
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/fft.py
@@ -25,6 +25,7 @@
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
traceable=True,
)
def _fftn_onnx_normalization(
self,
@@ -35,7 +36,7 @@ def _fftn_onnx_normalization(
) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0)
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
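On the `_fftn_onnx_normalization` change above: subscripting the symbolic shape tensor (`self_shape[dims]`) is replaced with an explicit `op.Gather`, presumably because plain Python indexing is not available on traced values. The two forms are numerically equivalent, as this numpy sketch (toy shapes, not the library code) shows.

```python
# Numpy illustration that Gather(shape, dims) equals shape[dims] here.
import numpy as np

self_shape = np.array([2, 3, 8, 16], dtype=np.int64)  # stand-in for op.Shape(self)
dims = np.array([2, 3], dtype=np.int64)               # FFT dimensions

gathered = np.take(self_shape, dims, axis=0)          # what op.Gather computes
assert np.array_equal(gathered, self_shape[dims])

total_sample_count = gathered.prod()                  # op.ReduceProd(..., keepdims=0)
print(total_sample_count)                             # 128, the normalization factor n
```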
2 changes: 2 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -342,6 +342,7 @@ def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool

self = op.Abs(self)
ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
# TODO(justinchuby): Evaluate IsInf in trace mode
if op.IsInf(ord, detect_negative=0, detect_positive=1):
result = op.ReduceMax(self, keepdims=keepdim)
elif op.IsInf(ord, detect_negative=1, detect_positive=0):
@@ -373,6 +374,7 @@ def _aten_linalg_vector_norm_onnx(
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
self = op.Abs(self)
ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
# TODO(justinchuby): Evaluate IsInf in trace mode
if op.IsInf(ord, detect_negative=0, detect_positive=1):
result = op.ReduceMax(self, dim, keepdims=keepdim)
elif op.IsInf(ord, detect_negative=1, detect_positive=0):
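For reference, a hedged sketch of what the `IsInf` checks in these vector-norm helpers decide: `ord = +inf` selects `ReduceMax`, `ord = -inf` selects `ReduceMin`, and the new TODOs note that the check could be evaluated eagerly in trace mode. This is a partial numpy illustration, not the torch_lib code (it ignores the `ord == 0` case and reduction dims).

```python
# Partial numpy sketch of the p-norm dispatch guarded by op.IsInf above.
import numpy as np


def vector_norm(x: np.ndarray, ord: float) -> np.floating:
    x = np.abs(x)
    if np.isposinf(ord):   # op.IsInf(ord, detect_negative=0, detect_positive=1)
        return x.max()     # ReduceMax
    if np.isneginf(ord):   # op.IsInf(ord, detect_negative=1, detect_positive=0)
        return x.min()     # ReduceMin
    return (x ** ord).sum() ** (1.0 / ord)   # general Lp reduction


v = np.array([3.0, -4.0])
print(vector_norm(v, 2.0), vector_norm(v, np.inf), vector_norm(v, -np.inf))  # 5.0 4.0 3.0
```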
31 changes: 16 additions & 15 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -40,7 +40,7 @@
TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE])


@torch_op("aten::aten_adaptive_avg_pool1d")
@torch_op("aten::aten_adaptive_avg_pool1d", traceable=True)
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
"""adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor"""

@@ -58,7 +58,7 @@ def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
return result


@torch_op("aten::aten_adaptive_avg_pool2d")
@torch_op("aten::aten_adaptive_avg_pool2d", traceable=True)
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
"""adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor"""

@@ -76,7 +76,7 @@ def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
return result


@torch_op("aten::aten_adaptive_avg_pool3d")
@torch_op("aten::aten_adaptive_avg_pool3d", traceable=True)
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
"""adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor"""

@@ -350,7 +350,7 @@ def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
return op.Celu(self, alpha=alpha) # op.Celu only support float32


@torch_op("aten::celu")
@torch_op("aten::celu", traceable=True)
def aten_celu_type_promoted(
self: TFloatUnlessFloat32, alpha: float = 1.0
) -> TFloatUnlessFloat32:
@@ -409,7 +409,7 @@ def aten_conv_depthwise3d(
raise NotImplementedError()


@torch_op("aten::cross_entropy_loss")
@torch_op("aten::cross_entropy_loss", traceable=True)
def aten_cross_entropy_loss(
self: TFloatOrBFloat16,
target: IntType,
@@ -871,7 +871,7 @@ def aten_max_pool2d(
return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)


@torch_op("internal::max_pool", private=True)
@torch_op("internal::max_pool", private=True, traceable=True)
def _aten_max_pool_onnx(
self: TFloatOrUInt8,
kernel_shape: Sequence[int],
@@ -1003,7 +1003,7 @@ def aten_max_pool3d_with_indices(
)


@torch_op("internal::max_pool_with_indices", private=True)
@torch_op("internal::max_pool_with_indices", private=True, traceable=True)
def _aten_max_pool_with_indices_onnx(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
@@ -1159,7 +1159,7 @@ def aten_mkldnn_reorder_conv3d_weight(
raise NotImplementedError()


@torch_op("aten::mse_loss")
@torch_op("aten::mse_loss", traceable=True)
def aten_mse_loss(self: TReal, target: TReal, reduction: int = 1) -> TReal:
"""mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor"""
# FIXME: When reduction=0, the shape(result) will be different than other case
@@ -1235,7 +1235,7 @@ def aten_multilabel_margin_loss_forward(
raise NotImplementedError()


@torch_op("aten::nll_loss")
@torch_op("aten::nll_loss", traceable=True)
def aten_nll_loss(
self: TFloat,
target: INT64,
@@ -1248,7 +1248,7 @@ def aten_nll_loss(
if self_rank_is_1: # self rank should be at least 2
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

rank_target = op.Size(op.Shape(target))
rank_target = Rank(target)
if rank_target == 0: # target rank should be at least 1
target = op.Unsqueeze(target, op.Constant(value_ints=[0]))

@@ -1271,7 +1271,7 @@ def aten_nll_loss(
return result


@torch_op("aten::nll_loss")
@torch_op("aten::nll_loss", traceable=True)
def aten_nll_loss_weight(
self: TFloat,
target: INT64,
@@ -1282,10 +1282,11 @@ def aten_nll_loss_weight(
"""nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""

self_rank_is_1 = Rank(self) == 1
if self_rank_is_1: # self rank should be at least 2
if self_rank_is_1:
# self rank should be at least 2
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

rank_target = op.Size(op.Shape(target))
rank_target = Rank(target)
if rank_target == 0: # target rank should be at least 1
target = op.Unsqueeze(target, op.Constant(value_ints=[0]))

@@ -1490,7 +1491,7 @@ def aten_relu(self: TReal) -> TReal:
return op.Relu(self)


@torch_op("aten::relu6")
@torch_op("aten::relu6", traceable=True)
def aten_relu6(self: TReal) -> TReal:
"""relu6(Tensor self) -> Tensor"""

@@ -1778,7 +1779,7 @@ def aten__scaled_dot_product_flash_attention(
)


@torch_op("aten::_scaled_dot_product_efficient_attention", private=True)
@torch_op("aten::_scaled_dot_product_efficient_attention", private=True, traceable=True)
def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
query: TFloat,
compute_log_sumexp: bool,
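A note on the `aten_nll_loss` hunks above, where `op.Size(op.Shape(target))` becomes `Rank(target)`: both compute the rank of `target`, and (an assumption, following the commit message) the shared `Rank` helper can resolve to a Python int during tracing when the rank is statically known, so the `if rank_target == 0` adjustment folds away instead of emitting Size/Shape nodes. A numpy sketch of the equivalence:

```python
# Both expressions below compute the rank of `target`; only their form differs.
import numpy as np

target = np.zeros((8, 10), dtype=np.int64)           # toy stand-in for the target tensor

rank_via_ops = int(np.size(np.array(target.shape)))  # op.Size(op.Shape(target))
rank_via_helper = target.ndim                        # what Rank(target) amounts to
assert rank_via_ops == rank_via_helper == 2
print(rank_via_helper)
```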
23 changes: 21 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -481,6 +481,7 @@ def _where_input_wrangler(
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.17"),
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
@@ -493,6 +494,7 @@ def _where_input_wrangler(
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.17"),
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
@@ -544,6 +546,13 @@ def _where_input_wrangler(
"decomposed",
dtypes=(torch.int16, torch.int32, torch.int64),
reason="ONNX Runtime does not support int inputs to Gemm",
)
.xfail(
"decomposed",
matcher=lambda sample: torch.numel(sample.input) == 0
or torch.numel(sample.args[0]) == 0
or torch.numel(sample.args[1]) == 0,
reason="ONNX Runtime does not support zero sized inputs",
),
TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
TorchLibOpInfo(
@@ -1545,6 +1554,7 @@ def _where_input_wrangler(
"t",
core_ops.aten_t,
).xfail(
enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING,
reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986",
test_class_name="TestOutputConsistencyFullGraph",
),
@@ -1895,7 +1905,11 @@ def _where_input_wrangler(
"native_layer_norm",
core_ops.aten_native_layer_norm,
trace_only=True,
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (1e-1, 7e-4)},
).skip(
dtypes=(torch.float16,),
device_type="cpu",
reason="native_layer_norm outputs different dtypes on CPU and CUDA. Our implematation is based on that for CUDA",
),
TorchLibOpInfo(
"nn.functional.avg_pool1d",
@@ -2091,9 +2105,14 @@ def _where_input_wrangler(
# Output[0] is OK, but other outputs just have the same shape with zero values
nondeterministic=True,
compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8),
).skip(
)
.skip(
enabled_if=version_utils.torch_older_than("2.1"),
reason="The operator is not supported in older version.",
)
.skip(
device_type="cpu",
reason="_scaled_dot_product_flash_attention only supports CUDA",
),
TorchLibOpInfo(
"ops.aten._scaled_dot_product_efficient_attention",
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -1,7 +1,7 @@
setuptools>=61.0.0
numpy
onnx-weekly>=1.16.0.dev20231204
onnxruntime>=1.15.1
onnxruntime>=1.17.0
typing_extensions

# Docs site
