Check padding density by input of embedding module #19821

Merged
merged 6 commits into main from yangu/check_padding_input on Apr 10, 2024

Conversation

@guyang3532 (Contributor) commented Mar 7, 2024

Description

The PaddingElimination optimization should be enabled only when the density of the embedding input is less than 90%, so we need to check that density to decide whether to enable the optimization.
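For intuition, here is a minimal sketch of how such a density check can be computed directly from the embedding input; the helper name and the `padding_idx` handling are illustrative assumptions, not the actual implementation:

```python
import torch

def padding_density(input_ids: torch.Tensor, padding_idx: int) -> float:
    # Percentage of valid (non-padding) tokens in the embedding input.
    total = input_ids.numel()
    valid = (input_ids != padding_idx).sum().item()
    return 100.0 * valid / total

# Example: 2 sequences padded to length 8 with padding_idx=0.
input_ids = torch.tensor([[5, 7, 2, 0, 0, 0, 0, 0],
                          [3, 1, 0, 0, 0, 0, 0, 0]])
print(padding_density(input_ids, padding_idx=0))  # 31.25 -> below 90%, enable
```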

Before this PR, we checked only the graph inputs and correlated one of them with the embedding node by iterating the graph backward from the embedding node to a graph input.
This approach is hard to generalize because the pattern between a graph input and the embedding node can be complicated.

This PR instead checks the padding density on the direct input of the embedding module, rather than on the graph input, during the first graph execution when exporting the ONNX graph.
If the density is below 90%, it inserts a flag PythonOp after the embedding node:

```
Embedding
    |
PythonOp (func_name: _FlagPaddingElimination)   (inserted if density < 90%)
    |
Following graph
```
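The flag node only needs to be an identity operation that survives export. A minimal sketch of what `_FlagPaddingElimination` might look like as a `torch.autograd.Function` (the body shown is an assumption; what matters is that the exported PythonOp records the function name in its `func_name` attribute):

```python
import torch

class _FlagPaddingElimination(torch.autograd.Function):
    """Identity op used only as a marker in the exported ONNX graph."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x  # pass the embedding output through unchanged

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output  # gradient also flows through unchanged
```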

When PaddingElimination is invoked, it checks whether the flag PythonOp (func_name: _FlagPaddingElimination) is present after the Embedding node; if it is, it removes the flag and performs the padding elimination optimization.
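On the optimizer side, a rough sketch of how the flag could be detected and stripped with the onnx Python API; the PythonOp input/output layout and the rewiring logic here are simplified assumptions, not the actual transform:

```python
import onnx

def pop_padding_elimination_flag(graph: onnx.GraphProto) -> bool:
    """Find the flag PythonOp, remove it by rewiring its output back to
    its input, and report whether padding elimination should run."""
    for node in list(graph.node):
        if node.op_type != "PythonOp":
            continue
        func_name = next(
            (onnx.helper.get_attribute_value(a).decode()
             for a in node.attribute if a.name == "func_name"), "")
        if func_name != "_FlagPaddingElimination":
            continue
        flag_in, flag_out = node.input[0], node.output[0]  # simplified layout
        # Rewire consumers of the flag's output to read its input instead.
        for consumer in graph.node:
            for i, name in enumerate(consumer.input):
                if name == flag_out:
                    consumer.input[i] = flag_in
        graph.node.remove(node)
        return True
    return False
```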

@guyang3532 requested a review from pengwa on March 7, 2024 12:53
@guyang3532 force-pushed the yangu/check_padding_input branch from fded099 to 372ab79 on March 7, 2024 13:12
@pengwa added the training label (issues related to ONNX Runtime training; typically submitted using template) on Mar 7, 2024
@guyang3532 force-pushed the yangu/check_padding_input branch 2 times, most recently from 6568886 to 3d8614a on March 11, 2024 06:19
@guyang3532 force-pushed the yangu/check_padding_input branch 2 times, most recently from f3eb16d to 6027f03 on March 11, 2024 08:58
@guyang3532 force-pushed the yangu/check_padding_input branch from 6027f03 to 280f127 on April 8, 2024 07:39
@guyang3532 force-pushed the yangu/check_padding_input branch from 280f127 to 0beaed7 on April 9, 2024 13:32
@pengwa previously approved these changes Apr 10, 2024
@pengwa merged commit 471e969 into main on Apr 10, 2024 (92 of 94 checks passed)
@pengwa deleted the yangu/check_padding_input branch April 10, 2024 10:45
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024