Script embedding_renorm | fix(torchlib) (#1123)
Script embedding_renorm by correctly unpacking the outputs of
`op.Unique`. Removed the trace_only function.
justinchuby authored Nov 2, 2023
1 parent 7494498 commit 2ab70ca
Showing 2 changed files with 6 additions and 21 deletions.
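
The fix hinges on the fact that ONNX Unique produces four outputs — Y (the deduplicated values), indices, inverse_indices, and counts — so the scripted function must tuple-unpack the call instead of indexing into it with SequenceAt. A minimal sketch of what those four outputs contain, using NumPy on made-up data (illustration only, not part of the commit):

import numpy as np

indices = np.array([3, 1, 3, 2, 1])
# np.unique mirrors ONNX Unique's output order: Y, indices, inverse_indices, counts
unique, first_occurrence, inverse, counts = np.unique(
    indices, return_index=True, return_inverse=True, return_counts=True
)
# unique           -> [1 2 3]      (Y: the deduplicated, sorted values)
# first_occurrence -> [1 3 0]      (position of each unique value's first occurrence)
# inverse          -> [2 0 2 1 0]  (maps every input element back into `unique`)
# counts           -> [2 1 2]      (how often each unique value appears)
# Only the first output is needed here, hence the scripted form:
#     unique_indices, _, _, _ = op.Unique(indices)
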
22 changes: 4 additions & 18 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3003,28 +3003,14 @@ def aten_embedding_dense_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::embedding_renorm", trace_only=True)
+@torch_op("aten::embedding_renorm")
 def aten_embedding_renorm(
     weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
 ) -> TFloat:
     """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""
 
-    unique_indices = op.Unique(indices)
-    unique_indices_Y = op.SequenceAt(unique_indices, 0)
-    # using _onnx private function because op.SrquenceAt(unique_indices, 0) cannot pass module checker
-    # The error message is:
-    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
-    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
-    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
-    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
-    return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)
-
-
-@torch_op("aten::embedding_renorm", private=True)
-def aten_embedding_renorm_onnx(
-    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
-) -> TFloat:
-    partial_weight = op.Gather(weight, indices)
+    unique_indices, _, _, _ = op.Unique(indices)
+    partial_weight = op.Gather(weight, unique_indices)
     # partial_weight_norm = sum(|w|^p)^(1/p)
     if norm_type == 1.0:
         # This is not necessary, but op.ReduceL1 is faster than function list in 'else'
@@ -3050,7 +3036,7 @@ def aten_embedding_renorm_onnx(
     partial_weight_renorm = op.Where(
         op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
     )
-    value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
+    value = op.ScatterND(weight, op.Unsqueeze(unique_indices, [1]), partial_weight_renorm)
     return value
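
For reference, the scripted graph gathers the rows of weight selected by the unique indices, computes each row's p-norm (sum(|w|^p)^(1/p)), rescales only the rows whose norm exceeds max_norm, and scatters the result back into weight. A rough eager-mode PyTorch sketch of that behavior (illustrative only; the helper name and the 1e-7 epsilon are assumptions mirroring PyTorch's own renorm guard, not the torchlib graph itself):

import torch

def embedding_renorm_reference(weight, indices, max_norm, norm_type=2.0):
    weight = weight.clone()
    unique_indices = torch.unique(indices)            # Unique
    partial = weight[unique_indices]                  # Gather
    # Per-row p-norm: sum(|w|^p)^(1/p)
    norms = partial.abs().pow(norm_type).sum(dim=-1).pow(1.0 / norm_type)
    # Candidate rescaled rows for entries whose norm is too large
    renormed = partial * (max_norm / (norms + 1e-7)).unsqueeze(-1)
    # Keep the original row unless its norm exceeds max_norm (Where/Greater)
    partial = torch.where((norms > max_norm).unsqueeze(-1), renormed, partial)
    weight[unique_indices] = partial                  # ScatterND
    return weight

# e.g. embedding_renorm_reference(torch.randn(10, 4), torch.tensor([1, 3, 1]), max_norm=1.0)
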


5 changes: 2 additions & 3 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -80,8 +80,8 @@ class TorchLibOpInfo:
     nondeterministic: bool = False
     # Whether to compare the shape only for the output[index]
     # For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
-    # We may be able to combine this with the nondeterminstic option
-    compare_shape_only_for_output: tuple[int] = ()
+    # We may be able to combine this with the nondeterministic option
+    compare_shape_only_for_output: tuple[int, ...] = ()
     # Whether the function is designed for complex inputs
     complex: bool = False
     # The acceptable tolerance of the inference result difference between PyTorch and ORT.
@@ -1073,7 +1073,6 @@ def _where_input_wrangler(
         "ops.aten.embedding_renorm",
         core_ops.aten_embedding_renorm,
         tolerance={torch.float16: (1e-2, 1e-2)},
-        trace_only=True,
         compare_shape_only_for_output=(1, 2, 3),
     ),
     TorchLibOpInfo(
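
The annotation fix in the hunk above matters because tuple[int] describes a tuple of exactly one int, whereas the field holds any number of output indices; tuple[int, ...] is the homogeneous, variable-length spelling. A short illustration (hypothetical variable names; Python 3.9+ built-in generics assumed):

shape_only_old: tuple[int] = (1, 2, 3)       # a type checker flags this: tuple[int] means a 1-tuple
shape_only_new: tuple[int, ...] = (1, 2, 3)  # OK: any length, as used for embedding_renorm
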
