Add embedding_renorm code | feat(torchlib) #1098
Merged
7 commits:
- c4d6b2d add embedding_renorm code (xiaowuhu)
- 72245a9 Merge branch 'main' into xiaowu/AddOp(embedding_renorm) (xiaowuhu)
- 8b9398b Merge branch 'main' into xiaowu/AddOp(embedding_renorm) (xiaowuhu)
- 200db0d Merge branch 'xiaowu/AddOp(embedding_renorm)' of https://github.com/m… (xiaowuhu)
- 1fe55e5 fix lint (xiaowuhu)
- 15cc5df Update core.py (xiaowuhu)
- 49fbcc6 Merge branch 'main' into xiaowu/AddOp(embedding_renorm) (xiaowuhu)
@@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward(
    raise NotImplementedError()


@torch_op("aten::embedding_renorm", trace_only=True)
def aten_embedding_renorm(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""

    unique_indices = op.Unique(indices)
    unique_indices_Y = op.SequenceAt(unique_indices, 0)
    # Using an _onnx private function because op.SequenceAt(unique_indices, 0) cannot pass the model checker.
    # The error message is:
[Review comment] Should we report this to ONNX?
[Reply] Filed an issue: onnx/onnx#5698
    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
    return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)
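The deduplication above matters because the private function later scatters one updated row per index; with duplicate indices the same embedding row would be gathered and written more than once. A minimal NumPy sketch (NumPy is used here only as a stand-in for the first output, `Y`, of the ONNX `Unique` op; the array values are made up for illustration):

```python
import numpy as np

# Indices as they might arrive from an embedding lookup, with duplicates
indices = np.array([3, 1, 3, 2, 1], dtype=np.int64)

# np.unique mirrors Unique's first output: the sorted unique values,
# so each embedding row is renormalized exactly once
unique_indices = np.unique(indices)
```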


@torch_op("aten::embedding_renorm", private=True)
def aten_embedding_renorm_onnx(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    partial_weight = op.Gather(weight, indices)
    # partial_weight_norm = sum(|w|^p)^(1/p)
    if norm_type == 1.0:
        # Not strictly necessary, but op.ReduceL1 is faster than the generic path in 'else'
        partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True)
    elif norm_type == 2.0:
        # Not strictly necessary, but op.ReduceL2 is faster than the generic path in 'else'
        partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True)
    else:
        # Abs -> Pow -> ReduceSum -> Pow
        partial_weight_abs = op.Abs(partial_weight)
        partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type))
        partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True)
        pow_value = op.CastLike(1.0 / norm_type, weight)
        partial_weight_norm = op.Pow(partial_weight_norm, pow_value)

    max_norm = op.CastLike(op.Constant(value_float=max_norm), weight)
    # Add a small epsilon so the division below never divides by a zero norm
    err = op.CastLike(op.Constant(value_float=1e-7), weight)
    partial_weight_norm_ = op.Add(partial_weight_norm, err)
    scales = op.Div(max_norm, partial_weight_norm_)
    partial_weight_renorm = op.Mul(partial_weight, scales)
    # Use the renormalized values where weight_norm > max_norm, but keep the
    # original values where weight_norm <= max_norm
    partial_weight_renorm = op.Where(
        op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
    )
    value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
    return value
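As a reference for the graph above, here is a NumPy re-implementation of the same gather / p-norm / rescale / scatter pipeline. This is a sketch only: the function name is mine, and NumPy operations stand in for the ONNX ops named in the comments.

```python
import numpy as np

def embedding_renorm_reference(weight, indices, max_norm, norm_type=2.0):
    """NumPy sketch of the graph: Gather -> p-norm -> Div -> Where -> ScatterND."""
    indices = np.unique(indices)              # dedupe, as in the trace_only wrapper
    partial = weight[indices]                 # op.Gather
    # sum(|w|^p)^(1/p) per row; keepdims so the norm broadcasts against the rows
    norms = np.sum(np.abs(partial) ** norm_type, axis=1, keepdims=True) ** (1.0 / norm_type)
    scales = max_norm / (norms + 1e-7)        # epsilon guards against zero norms
    # Rescale only the rows whose norm exceeds max_norm (op.Where)
    renormed = np.where(norms > max_norm, partial * scales, partial)
    out = weight.copy()
    out[indices] = renormed                   # op.ScatterND
    return out

weight = np.array([[3.0, 4.0], [0.3, 0.4], [6.0, 8.0]])
out = embedding_renorm_reference(weight, np.array([0, 2, 0]), max_norm=1.0)
```

After this call, rows 0 and 2 (norms 5 and 10) are scaled down to norm 1, while row 1 (norm 0.5) is left untouched.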


def aten_embedding_sparse_backward(
    grad: TensorType,
    indices: TensorType,
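One subtle step in the diff is `op.Unsqueeze(indices, [1])` before `op.ScatterND`: ScatterND addresses whole rows when each index tuple has length 1, so the 1-D index vector must become shape `(n, 1)`. A hedged NumPy emulation of that row-wise scatter (the array values are illustrative):

```python
import numpy as np

weight = np.arange(12, dtype=np.float64).reshape(4, 3)
indices = np.array([1, 3], dtype=np.int64)
updates = np.zeros((2, 3))

# op.Unsqueeze(indices, [1]) -> shape (2, 1); each inner [i] addresses row i
scatter_indices = indices[:, None]

# Emulate ONNX ScatterND with length-1 index tuples: full-row replacement
out = weight.copy()
for idx, upd in zip(scatter_indices, updates):
    out[tuple(idx)] = upd
```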
[Review comment] Should we use slice here?
[Reply] No. It is an unequal-length list.
[Reply] Fixed in #1123. The output is not a Sequence; that's why the checker complained.
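As a sanity check on the norm fast paths in the diff (`op.ReduceL1` / `op.ReduceL2` versus the generic Abs -> Pow -> ReduceSum -> Pow chain), the two formulations agree numerically. A NumPy sketch under that assumption:

```python
import numpy as np

x = np.array([[3.0, -4.0], [1.0, 2.0]])

def generic_pnorm(x, p):
    """Generic path from the diff: sum(|w|^p, axis=1)^(1/p)."""
    return np.sum(np.abs(x) ** p, axis=1) ** (1.0 / p)

# p=1 and p=2 match the dedicated reductions (op.ReduceL1 / op.ReduceL2)
l1 = np.sum(np.abs(x), axis=1)
l2 = np.sqrt(np.sum(x * x, axis=1))
```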