
Fix transformer layer detection for recompute #20106

Merged: 10 commits, Mar 29, 2024
Conversation

pengwa
Contributor

@pengwa pengwa commented Mar 27, 2024

Fix transformer layer detection for recompute

The original logic missed detecting the layer boundary node in the Mistral model. This PR simplifies the search by using a stronger pattern match, making it flexible enough to cover different transformer variants.

Also adds a unit test (UT).

Adds a warning when the user enables layerwise recompute but no layer boundary nodes are found.
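The idea the PR describes can be illustrated with a minimal sketch: instead of matching model-specific nodes, scan graph node names for a repeating layer-index pattern (e.g. `/layers.0/`, `/layers.1/`, ...) and warn when no boundaries are found. The function name, the regex, and the node-name format below are illustrative assumptions, not ONNX Runtime's actual implementation.

```python
import re
from typing import List

# Hypothetical layer-index pattern; real ONNX node names vary by exporter.
_LAYER_PATTERN = re.compile(r"/layers\.(\d+)/")

def find_layer_boundaries(node_names: List[str]) -> List[int]:
    """Return indices of nodes where a new transformer layer starts."""
    boundaries = []
    last_layer = None
    for idx, name in enumerate(node_names):
        m = _LAYER_PATTERN.search(name)
        if m is None:
            continue
        layer = int(m.group(1))
        if layer != last_layer:  # first node belonging to a new layer
            boundaries.append(idx)
            last_layer = layer
    if not boundaries:
        # Mirrors the warning the PR adds for layerwise recompute.
        print("WARNING: layerwise recompute enabled but no layer "
              "boundary nodes were found")
    return boundaries

nodes = [
    "/model/embed/Gather",
    "/model/layers.0/attn/MatMul",
    "/model/layers.0/mlp/Mul",
    "/model/layers.1/attn/MatMul",
]
print(find_layer_boundaries(nodes))  # [1, 3]
```

Because the match keys on the layer index rather than a particular operator sequence, the same scan covers different transformer variants (Llama, Mistral, etc.) as long as their exported node names encode the layer number.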

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Mar 27, 2024
@pengwa pengwa merged commit 2092beb into main Mar 29, 2024
95 checks passed
@pengwa pengwa deleted the pengwa/priority_tuning branch March 29, 2024 09:44
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024