Add embedding_renorm code | feat(torchlib) (#1098)
xiaowuhu authored Oct 24, 2023
1 parent 2a610c4 commit 0035390
Showing 3 changed files with 95 additions and 0 deletions.
51 changes: 51 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward(
    raise NotImplementedError()


@torch_op("aten::embedding_renorm", trace_only=True)
def aten_embedding_renorm(
weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
"""embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""

unique_indices = op.Unique(indices)
unique_indices_Y = op.SequenceAt(unique_indices, 0)
# using _onnx private function because op.SrquenceAt(unique_indices, 0) cannot pass module checker
# The error message is:
# onnx.onnx_cpp2py_export.shape_inference.InferenceError:
# [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
# node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
# node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)


@torch_op("aten::embedding_renorm", private=True)
def aten_embedding_renorm_onnx(
weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
partial_weight = op.Gather(weight, indices)
# partial_weight_norm = sum(|w|^p)^(1/p)
if norm_type == 1.0:
# This is not necessary, but op.ReduceL1 is faster than function list in 'else'
partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True)
elif norm_type == 2.0:
# This is not necessary, but op.ReduceL2 is faster than function list in 'else'
partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True)
else:
# Abs -> Pow -> ReduceSum -> Pow -> Pow
partial_weight_abs = op.Abs(partial_weight)
partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type))
partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True)
pow_value = op.CastLike(1.0 / norm_type, weight)
partial_weight_norm = op.Pow(partial_weight_norm, pow_value)

max_norm = op.CastLike(op.Constant(value_float=max_norm), weight)
# This is to avoid weight is zero
err = op.CastLike(op.Constant(value_float=1e-7), weight)
partial_weight_norm_ = op.Add(partial_weight_norm, err)
scales = op.Div(max_norm, partial_weight_norm_)
partial_weight_renorm = op.Mul(partial_weight, scales)
# Set values to renormed values where weight_norm > max_norm, but keep the original values where weight_norm <= max_norm
partial_weight_renorm = op.Where(
op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
)
value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
return value


def aten_embedding_sparse_backward(
    grad: TensorType,
    indices: TensorType,
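For readers skimming the diff, here is a minimal eager-mode sketch of the semantics the ONNX function above encodes (PyTorch assumed; the helper name embedding_renorm_reference is illustrative and not part of this commit): deduplicate the indices, compute the p-norm of each referenced row, and scale down only the rows whose norm exceeds max_norm.

import torch

def embedding_renorm_reference(weight, indices, max_norm, norm_type=2.0):
    """Illustrative reference only; not the aten kernel and not the torchlib function."""
    out = weight.clone()
    unique_idx = torch.unique(indices)  # each referenced row is renormalized at most once
    rows = out[unique_idx]
    # Per-row p-norm: sum(|w|^p)^(1/p), mirroring the Abs -> Pow -> ReduceSum -> Pow path
    norms = rows.abs().pow(norm_type).sum(dim=1, keepdim=True).pow(1.0 / norm_type)
    # Scale only rows whose norm exceeds max_norm; the 1e-7 mirrors the epsilon guard above
    scale = torch.where(norms > max_norm, max_norm / (norms + 1e-7), torch.ones_like(norms))
    out[unique_idx] = rows * scale
    return out

weight = torch.randn(5, 3)
indices = torch.tensor([0, 2, 2])
renormed = embedding_renorm_reference(weight, indices, max_norm=1.0)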
37 changes: 37 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -818,6 +818,36 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad
    yield opinfo_core.SampleInput(t, kwargs={"p": p})


def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs

    def make_input(shape):
        return common_methods_invocations.make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high, noncontiguous=False):
        return common_methods_invocations.make_tensor(
            shape,
            device=device,
            dtype=torch.long,
            low=low,
            high=high,
            noncontiguous=noncontiguous,
        )

    for max_norm in (0.5, 1.0, 5.0):
        for norm_type in (0.8, 1.0, 2.0, 2.5):
            idx = make_long_input((6,), low=0, high=S)
            weights = make_input((S, S)) * 2
            yield common_methods_invocations.SampleInput(
                weights,
                args=(idx,),
                kwargs={"max_norm": max_norm, "norm_type": norm_type},
            )


def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
@@ -1240,6 +1270,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
        sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.embedding_renorm",
        aten_name="embedding_renorm",
        dtypes=common_dtype.floating_types_and_half(),
        sample_inputs_func=sample_inputs_embedding_renorm,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "nn.functional.conv3d",
        aten_name="conv3d",
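As context for the sampled (weights, idx, max_norm, norm_type) combinations above, a small sketch of how they map onto the public eager API (only torch.nn.functional is assumed; S here stands in for the shape constant used by the generator): passing max_norm and norm_type to F.embedding renormalizes the referenced rows of the weight in place before the lookup.

import torch
import torch.nn.functional as F

S = 5  # stands in for the shape constant used by the sample generator above
weight = torch.randn(S, S) * 2
idx = torch.randint(0, S, (6,))

# F.embedding with max_norm renormalizes the referenced rows of `weight` in place
# and then performs the lookup; clone to keep the original weights intact.
out = F.embedding(idx, weight.clone(), max_norm=1.0, norm_type=2.0)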
7 changes: 7 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1044,6 +1044,13 @@ def _where_input_wrangler(
        tolerance={torch.float16: (1e-2, 1e-2)},
        compare_shape_only_for_output=(1, 2, 3),
    ),
    TorchLibOpInfo(
        "ops.aten.embedding_renorm",
        core_ops.aten_embedding_renorm,
        tolerance={torch.float16: (1e-2, 1e-2)},
        trace_only=True,
        compare_shape_only_for_output=(1, 2, 3),
    ),
    TorchLibOpInfo(
        "nn.functional.embedding",
        core_ops.aten_embedding,
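A standalone sanity check of the behavior this test entry exercises, offered only as a sketch (torch.embedding_renorm_ is the eager in-place aten entry point; the (1e-2, 1e-2) pair echoes the float16 tolerance registered above):

import torch

weight = torch.randn(5, 3)
indices = torch.tensor([0, 2, 2])

renormed = weight.clone()
torch.embedding_renorm_(renormed, indices, 1.0, 2.0)  # renormalizes rows 0 and 2 in place

# Rows not referenced by `indices` are left untouched.
assert torch.equal(renormed[[1, 3, 4]], weight[[1, 3, 4]])
# Referenced rows end up with an L2 norm no larger than max_norm, up to a small tolerance.
assert (renormed[[0, 2]].norm(dim=1) <= 1.0 + 1e-2).all()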
