From 2ab70ca103adc7ac743f3d0cf6fd7a3daca60394 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 2 Nov 2023 12:04:08 -0700
Subject: [PATCH] Script embedding_renorm | fix(torchlib) (#1123)

Script embedding_renorm by correctly unpacking the outputs of `op.Unique`.
Removed the trace_only function.
---
 .../function_libs/torch_lib/ops/core.py      | 22 ++++------------------
 .../function_libs/torch_lib/ops_test_data.py |  5 ++---
 2 files changed, 6 insertions(+), 21 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index fe72facb9..a3abd04d3 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3003,28 +3003,14 @@ def aten_embedding_dense_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::embedding_renorm", trace_only=True)
+@torch_op("aten::embedding_renorm")
 def aten_embedding_renorm(
     weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
 ) -> TFloat:
     """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""
 
-    unique_indices = op.Unique(indices)
-    unique_indices_Y = op.SequenceAt(unique_indices, 0)
-    # using _onnx private function because op.SrquenceAt(unique_indices, 0) cannot pass module checker
-    # The error message is:
-    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
-    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
-    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
-    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
-    return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)
-
-
-@torch_op("aten::embedding_renorm", private=True)
-def aten_embedding_renorm_onnx(
-    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
-) -> TFloat:
-    partial_weight = op.Gather(weight, indices)
+    unique_indices, _, _, _ = op.Unique(indices)
+    partial_weight = op.Gather(weight, unique_indices)
     # partial_weight_norm = sum(|w|^p)^(1/p)
     if norm_type == 1.0:
         # This is not necessary, but op.ReduceL1 is faster than function list in 'else'
@@ -3050,7 +3036,7 @@ def aten_embedding_renorm_onnx(
     partial_weight_renorm = op.Where(
         op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
     )
-    value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
+    value = op.ScatterND(weight, op.Unsqueeze(unique_indices, [1]), partial_weight_renorm)
     return value
 
 
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 976b9a150..7af316b36 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -80,8 +80,8 @@ class TorchLibOpInfo:
     nondeterministic: bool = False
     # Whether to compare the shape only for the output[index]
     # For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
-    # We may be able to combine this with the nondeterminstic option
-    compare_shape_only_for_output: tuple[int] = ()
+    # We may be able to combine this with the nondeterministic option
+    compare_shape_only_for_output: tuple[int, ...] = ()
     # Whether the function is designed for complex inputs
     complex: bool = False
     # The acceptable tolerance of the inference result difference between PyTorch and ORT.
@@ -1073,7 +1073,6 @@ def _where_input_wrangler(
         "ops.aten.embedding_renorm",
         core_ops.aten_embedding_renorm,
         tolerance={torch.float16: (1e-2, 1e-2)},
-        trace_only=True,
         compare_shape_only_for_output=(1, 2, 3),
     ),
     TorchLibOpInfo(
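
Note: the scripted function relies on `op.Unique` returning four outputs (unique values, indices, inverse indices, counts); only the first is needed here, hence the tuple unpacking. The eager PyTorch sketch below is not part of the patch; the function name `embedding_renorm_reference` is made up for illustration. It mirrors what the scripted `aten_embedding_renorm` computes: gather the rows of `weight` referenced by the deduplicated `indices`, rescale any row whose p-norm exceeds `max_norm`, and scatter the result back into `weight`. It may differ slightly from PyTorch's in-place kernel, which adds a small epsilon to the divisor.

```python
# Illustrative reference only -- not part of the patch. Assumes a 2-D weight matrix.
import torch


def embedding_renorm_reference(
    weight: torch.Tensor, indices: torch.Tensor, max_norm: float, norm_type: float = 2.0
) -> torch.Tensor:
    unique_indices = torch.unique(indices)      # first output of op.Unique
    partial_weight = weight[unique_indices]     # op.Gather
    norms = torch.linalg.vector_norm(
        partial_weight, ord=norm_type, dim=1, keepdim=True
    )                                           # op.ReduceL1 / op.ReduceL2 / generic p-norm
    renormed = partial_weight * (max_norm / norms)
    # Only rows whose norm exceeds max_norm are rescaled (op.Where).
    updated = torch.where(norms > max_norm, renormed, partial_weight)
    out = weight.clone()
    out[unique_indices] = updated               # op.ScatterND
    return out


weight = torch.randn(10, 4)
indices = torch.tensor([1, 1, 3, 7])
renormed_weight = embedding_renorm_reference(weight, indices, max_norm=1.0)
```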