Use tilized operators for Mistral AI speedup on inference #3812
Comments
The tensors in the Mistral model are now converted to TILE layout so that the tilized operators can be used, and the commit has been updated accordingly. |
Further performance improvements for tiles... Our goal is to only have tensors in tile layout from input to output. Starting in
Put Then in
Update how RMSNorm works so that it just applies Next up is
then let's skip all padding and reshaping until rotary embedding. Then after |
The commit has been revised based on the provided suggestions. In In Corresponding PR: #4029 |
The commit for the mistral model is updated by optimizing the rotary_embed method.
In the optimized version of the mistral model, the conversion of freqs_cis (PyTorch complex tensor) to tt_lib complex tensor will happen
Note: #transformer_block = 32 (the number of transformer blocks). Corresponding PR: #4029 |
@vigneshkeerthivasanx you can also try moving this code onto the device, since we have tt_lib.tensor.arange;
this should create the tensors on device and keep them there, saving the time spent moving data back and forth. |
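For readers following along, here is a minimal host-side sketch of what a `precompute_freqs_cis` routine typically computes for rotary embeddings. It is plain Python rather than tt_lib code; the parameter names `dim`, `end`, and `theta` and the default `theta=10000.0` are assumptions taken from the commonly published reference form of this function, not from the PR. The built-in `range` stands in for the `arange` call mentioned above.

```python
import cmath


def precompute_freqs_cis(dim, end, theta=10000.0):
    """Return an end x (dim // 2) table of unit complex numbers exp(i * t * f)."""
    # Inverse frequencies f = 1 / theta^(k / dim) for k = 0, 2, ..., dim - 2.
    inv_freqs = [1.0 / (theta ** (k / dim)) for k in range(0, dim, 2)]
    # Outer product of positions t with the frequencies, mapped onto the unit circle.
    return [[cmath.rect(1.0, t * f) for f in inv_freqs] for t in range(end)]


# Illustrative sizes only; real head dimensions and sequence lengths are larger.
table = precompute_freqs_cis(dim=8, end=4)
```

The only non-trivial arithmetic is the power `theta ** (k / dim)`, which is the step that needs an on-device power (or a substitute for it) when the table is built on device.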
@muthutt, the commit is updated by incorporating the above-mentioned comments. After modifying the method precompute_freqs_cis, there is a very slight degradation in the pcc (from 0.9928005641025528 to 0.9921967026489718). Even though the variation in the pcc looks insignificant, this modification has greatly affected the result of the gs-demo.
Note: I was not able to use tt_lib's power directly, so I have used the following logic: pow(exponent, base) = exp(exponent * log(base)). I have used tt_lib.tensor.exp, tt_lib.tensor.mul and tt_lib.tensor.log accordingly. Corresponding PR: #4029 |
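The exp/mul/log workaround for the missing power operator can be illustrated on the host as follows. This is a sketch in plain Python using only the standard `math` module, not tt_lib calls; the helper name `pow_via_exp_log` and the values `theta = 10000.0` and `dim = 128` are illustrative assumptions.

```python
import math


def pow_via_exp_log(base, exponent):
    """Compute base ** exponent as exp(exponent * log(base))."""
    # The identity only holds for a strictly positive base, because log is
    # undefined for zero and negative real inputs.
    if base <= 0:
        raise ValueError("log-based power requires a positive base")
    return math.exp(exponent * math.log(base))


# Assumed, illustrative rotary-embedding parameters.
theta, dim = 10000.0, 128

# Largest relative difference between the workaround and the built-in power
# over the exponents k / dim used for the rotary frequencies.
max_rel_err = max(
    abs(pow_via_exp_log(theta, k / dim) - theta ** (k / dim)) / theta ** (k / dim)
    for k in range(0, dim, 2)
)
```

In double precision the two forms agree almost exactly. On a device with lower-precision datatypes each of the three chained operations rounds separately, so the substitution may be one contributor to the pcc change reported here, though the thread does not confirm the cause.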
Let's keep changes that preserve a sensible output.
Thanks for trying out the proposal.
…On Mon, Dec 11, 2023 at 8:00 AM Sudharsan ***@***.***> wrote:
@muthutt <https://github.com/muthutt> , the commit is updated by
incorporating the above-mentioned comments.
After modifying the method precompute_freqs_cis, there is a very slight
degradation in the pcc.
Previously the pcc was 0.9928005641025528
After modifying the method precompute_freqs_cis, the pcc is
0.9921967026489718.
Even though the variation in the pcc looks insignificant, this
modification has greatly affected the result of the gs-demo.
Input Prompt: ['A man is sitting on a roof ']
Output Prompt before modifying the `precompute_freqs_cis`: 'A man is sitting on a roof 100 meters above the ground.\n\nA man is sitting on'
Output Prompt after modification: 'A man is sitting on a roof 100 feet of a a.10.10.1'
Note: I was not able to use tt_lib's power directly, so I have used the
following logic:
pow(exponent, base) = exp(exponent * log(base))
I have used tt_lib.tensor.exp, tt_lib.tensor.mul and tt_lib.tensor.log
accordingly
Corresponding PR: #4029
|
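As background for the pcc figures quoted in this thread, here is a minimal sketch of a Pearson correlation coefficient between two flattened outputs. It assumes that "pcc" here denotes the Pearson correlation coefficient; the function name `pcc` and its pure-Python form are illustrative and are not the project's own comparison helper.

```python
import math


def pcc(expected, actual):
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(expected) != len(actual) or not expected:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(expected)
    mean_e = sum(expected) / n
    mean_a = sum(actual) / n
    # Covariance and variances, left unnormalised because the 1/n factors cancel.
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    return cov / math.sqrt(var_e * var_a)


# A perfectly linear relationship gives 1, an inverted one gives -1.
same_direction = pcc([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
opposite_direction = pcc([1.0, 2.0, 3.0], [3.0, 2.0, 1.0])
```

The metric summarises linear agreement across the whole tensor, so a small drop such as the one reported above can still coincide with different generated tokens, which matches the gs-demo observation in the thread.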
Reference: https://github.com/tenstorrent-metal/tt-metal/tree/a3740def58c3b8672b7e3279261506ae70b97810/models/demos/resnet/tt/metalResnetBlock50.py