Add embedding_renorm code | feat(torchlib) #1098

Merged · 7 commits · Oct 24, 2023
onnxscript/function_libs/torch_lib/ops/core.py (51 additions, 0 deletions)

@@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward(
    raise NotImplementedError()


@torch_op("aten::embedding_renorm", trace_only=True)
def aten_embedding_renorm(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""

    unique_indices = op.Unique(indices)
    unique_indices_Y = op.SequenceAt(unique_indices, 0)
Collaborator: Should we use slice here?

Contributor Author: Nope. It is an unequal-length list.

Collaborator: Fixed in #1123. The output is not a Sequence. That's why the checker complained.

    # Using the _onnx private function because op.SequenceAt(unique_indices, 0) cannot pass the model checker.
    # The error message is:
Collaborator: Should we report this to ONNX?

Contributor Author: Filed an issue: onnx/onnx#5698

    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
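    # (The first output of ONNX Unique is a regular tensor of the unique values with a
    # data-dependent length, not a Sequence, which is why SequenceAt trips shape inference here.)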
    return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)


@torch_op("aten::embedding_renorm", private=True)
def aten_embedding_renorm_onnx(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    partial_weight = op.Gather(weight, indices)
    # partial_weight_norm = sum(|w|^p)^(1/p)
    if norm_type == 1.0:
        # This branch is not strictly necessary, but op.ReduceL1 is faster than the composed ops in the 'else' branch
        partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True)
    elif norm_type == 2.0:
        # This branch is not strictly necessary, but op.ReduceL2 is faster than the composed ops in the 'else' branch
        partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True)
    else:
        # Abs -> Pow -> ReduceSum -> Pow
        partial_weight_abs = op.Abs(partial_weight)
        partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type))
        partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True)
        pow_value = op.CastLike(1.0 / norm_type, weight)
        partial_weight_norm = op.Pow(partial_weight_norm, pow_value)

    max_norm = op.CastLike(op.Constant(value_float=max_norm), weight)
    # Add a small epsilon to avoid dividing by zero when a row's norm is zero
    err = op.CastLike(op.Constant(value_float=1e-7), weight)
    partial_weight_norm_ = op.Add(partial_weight_norm, err)
    scales = op.Div(max_norm, partial_weight_norm_)
    partial_weight_renorm = op.Mul(partial_weight, scales)
    # Use the renormed values where weight_norm > max_norm; keep the original values where weight_norm <= max_norm
    partial_weight_renorm = op.Where(
        op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
    )
    value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
    return value


def aten_embedding_sparse_backward(
    grad: TensorType,
    indices: TensorType,
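For reference, the decomposition above follows the semantics of PyTorch's embedding renorm: among the rows of weight that indices actually select, any row whose norm_type-norm exceeds max_norm is scaled down to (approximately) max_norm, and every other row is left untouched. Below is a minimal eager-mode sketch of those semantics; it is an illustration only, not the torchlib implementation. The helper name embedding_renorm_reference is made up here, and the 1e-7 epsilon mirrors the guard used above.

import torch

def embedding_renorm_reference(weight, indices, max_norm, norm_type=2.0):
    # Renormalize only the rows that are actually looked up, each at most once.
    unique_idx = indices.unique()
    rows = weight[unique_idx]
    # Per-row p-norm: sum(|w|^p)^(1/p), matching the ReduceL1/ReduceL2/Pow branches above.
    norms = rows.abs().pow(norm_type).sum(dim=1, keepdim=True).pow(1.0 / norm_type)
    # Scale rows whose norm exceeds max_norm; keep the rest unchanged.
    scales = torch.where(norms > max_norm, max_norm / (norms + 1e-7), torch.ones_like(norms))
    out = weight.clone()
    out[unique_idx] = rows * scales
    return out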
onnxscript/tests/function_libs/torch_lib/extra_opinfo.py (37 additions, 0 deletions)

@@ -818,6 +818,36 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad
    yield opinfo_core.SampleInput(t, kwargs={"p": p})


def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs

    def make_input(shape):
        return common_methods_invocations.make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high, noncontiguous=False):
        return common_methods_invocations.make_tensor(
            shape,
            device=device,
            dtype=torch.long,
            low=low,
            high=high,
            noncontiguous=noncontiguous,
        )

    for max_norm in (0.5, 1.0, 5.0):
        for norm_type in (0.8, 1.0, 2.0, 2.5):
            idx = make_long_input((6,), low=0, high=S)
            weights = make_input((S, S)) * 2
            yield common_methods_invocations.SampleInput(
                weights,
                args=(idx,),
                kwargs={"max_norm": max_norm, "norm_type": norm_type},
            )
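Each sample yielded above pairs an (S, S) float weight matrix with a 1-D long index tensor and one (max_norm, norm_type) combination; S is the small-size constant shared by these opinfo helpers. As a rough illustration of what one such sample exercises in eager mode, here is a sketch that uses the in-place torch.embedding_renorm_ binding purely for illustration and assumes S = 5; the OpInfo itself targets ops.aten.embedding_renorm.

import torch

S = 5  # assumed value of the shared small-size test constant, for illustration only
weight = torch.randn(S, S) * 2
idx = torch.randint(0, S, (6,), dtype=torch.long)
# In-place renorm of the rows selected by idx: positional args are (weight, indices, max_norm, norm_type).
torch.embedding_renorm_(weight, idx, 1.0, 2.0)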


def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
@@ -1240,6 +1270,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
        sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.embedding_renorm",
        aten_name="embedding_renorm",
        dtypes=common_dtype.floating_types_and_half(),
        sample_inputs_func=sample_inputs_embedding_renorm,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "nn.functional.conv3d",
        aten_name="conv3d",
onnxscript/tests/function_libs/torch_lib/ops_test_data.py (7 additions, 0 deletions)

@@ -1044,6 +1044,13 @@ def _where_input_wrangler(
        tolerance={torch.float16: (1e-2, 1e-2)},
        compare_shape_only_for_output=(1, 2, 3),
    ),
    TorchLibOpInfo(
        "ops.aten.embedding_renorm",
        core_ops.aten_embedding_renorm,
        tolerance={torch.float16: (1e-2, 1e-2)},
        trace_only=True,
        compare_shape_only_for_output=(1, 2, 3),
    ),
    TorchLibOpInfo(
        "nn.functional.embedding",
        core_ops.aten_embedding,