diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e3484523a..47dfbeb10 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward( raise NotImplementedError() +@torch_op("aten::embedding_renorm", trace_only=True) +def aten_embedding_renorm( + weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0 +) -> TFloat: + """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor""" + + unique_indices = op.Unique(indices) + unique_indices_Y = op.SequenceAt(unique_indices, 0) + # using _onnx private function because op.SrquenceAt(unique_indices, 0) cannot pass module checker + # The error message is: + # onnx.onnx_cpp2py_export.shape_inference.InferenceError: + # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm, + # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt, + # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64) + return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type) + + +@torch_op("aten::embedding_renorm", private=True) +def aten_embedding_renorm_onnx( + weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0 +) -> TFloat: + partial_weight = op.Gather(weight, indices) + # partial_weight_norm = sum(|w|^p)^(1/p) + if norm_type == 1.0: + # This is not necessary, but op.ReduceL1 is faster than function list in 'else' + partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True) + elif norm_type == 2.0: + # This is not necessary, but op.ReduceL2 is faster than function list in 'else' + partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True) + else: + # Abs -> Pow -> ReduceSum -> Pow -> Pow + partial_weight_abs = op.Abs(partial_weight) + partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type)) + partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True) + pow_value = op.CastLike(1.0 / norm_type, weight) + partial_weight_norm = op.Pow(partial_weight_norm, pow_value) + + max_norm = op.CastLike(op.Constant(value_float=max_norm), weight) + # This is to avoid weight is zero + err = op.CastLike(op.Constant(value_float=1e-7), weight) + partial_weight_norm_ = op.Add(partial_weight_norm, err) + scales = op.Div(max_norm, partial_weight_norm_) + partial_weight_renorm = op.Mul(partial_weight, scales) + # Set values to renormed values where weight_norm > max_norm, but keep the original values where weight_norm <= max_norm + partial_weight_renorm = op.Where( + op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight + ) + value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm) + return value + + def aten_embedding_sparse_backward( grad: TensorType, indices: TensorType, diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 066e87f45..352413b7c 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -818,6 +818,36 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + def make_input(shape): + return common_methods_invocations.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + + def make_long_input(shape, *, low, high, noncontiguous=False): + return common_methods_invocations.make_tensor( + shape, + device=device, + dtype=torch.long, + low=low, + high=high, + noncontiguous=noncontiguous, + ) + + for max_norm in (0.5, 1.0, 5.0): + for norm_type in (0.8, 1.0, 2.0, 2.5): + idx = make_long_input((6,), low=0, high=S) + weights = make_input((S, S)) * 2 + yield common_methods_invocations.SampleInput( + weights, + args=(idx,), + kwargs={"max_norm": max_norm, "norm_type": norm_type}, + ) + + def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1240,6 +1270,13 @@ def sample_inputs_scaled_dot_product_flash_attention( sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.embedding_renorm", + aten_name="embedding_renorm", + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=sample_inputs_embedding_renorm, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.conv3d", aten_name="conv3d", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c23729ee5..7c6a64b49 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1044,6 +1044,13 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), ), + TorchLibOpInfo( + "ops.aten.embedding_renorm", + core_ops.aten_embedding_renorm, + tolerance={torch.float16: (1e-2, 1e-2)}, + trace_only=True, + compare_shape_only_for_output=(1, 2, 3), + ), TorchLibOpInfo( "nn.functional.embedding", core_ops.aten_embedding,