[ENH] TemporalFusionTransformer - allow mixed precision training #1518
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@ Coverage Diff @@
##           master    #1518   +/-   ##
=======================================
  Coverage   90.19%   90.19%
=======================================
  Files          30       30
  Lines        4724     4724
=======================================
  Hits         4261     4261
  Misses        463      463
=======================================
Thanks!

From the API perspective, we need to address the problem that this changes default behaviour, and thus may break or change user code downstream.

What we need to do is expose the constant used in `masked_fill` as a parameter of `ScaledDotProductAttention`, and pass it through up to `TemporalFusionTransformer`, which means also exposing it in `InterpretableMultiHeadAttention`.

The default must be left at 1e9 for now (to avoid breaking or changing other people's code), but the new parameter will allow you to set `float("inf")`.

The docstrings of the modules and components should also describe the new parameter. Where classes have no docstring, it would be great if you could add one, but that is not blocking.

I would also appreciate feedback on which value is better as a default - even if we cannot change it right now, we could after a couple of releases, giving users a forewarning that the default will change.
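To make the request concrete, here is a minimal sketch (not the library's actual code) of how the fill constant could be exposed as a constructor argument. The parameter name `masked_fill_value` is an assumption for illustration; in the real change it would also need to be threaded through `InterpretableMultiHeadAttention` and `TemporalFusionTransformer`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with a configurable mask fill constant (sketch)."""

    def __init__(self, dropout: float = 0.0, masked_fill_value: float = 1e9):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Default stays at 1e9 to preserve current behaviour;
        # pass float("inf") to make the module safe for mixed precision.
        self.masked_fill_value = masked_fill_value

    def forward(self, q, k, v, mask=None):
        # attention scores: (batch, len_q, len_k)
        attn = torch.bmm(q, k.transpose(1, 2)) / (k.size(-1) ** 0.5)
        if mask is not None:
            # masked positions are pushed to -masked_fill_value before softmax
            attn = attn.masked_fill(mask, -self.masked_fill_value)
        attn = self.dropout(F.softmax(attn, dim=-1))
        return torch.bmm(attn, v), attn
```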
TemporalFusionTransformer - allow mixed precision training
Description
This PR changes the attention mask fill value in the TFT model from 1e9 to float("inf") to allow PyTorch mixed precision training.
Closes #1325, closes #285
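For context, here is a minimal repro sketch (not part of the PR) of why the constant matters under mixed precision: a scalar of magnitude 1e9 is outside the fp16 range (which tops out around 6.5e4), so filling a half-precision tensor with it typically fails with an overflow error, while infinity is representable and softmax handles it cleanly.

```python
import torch

scores = torch.randn(2, 4, 4, dtype=torch.float16)  # fp16 attention scores
mask = torch.eye(4, dtype=torch.bool)                # mask the diagonal as an example

try:
    # On most PyTorch builds this raises:
    # "value cannot be converted to type at::Half without overflow"
    scores.masked_fill(mask, -1e9)
except RuntimeError as err:
    print("fp16 + 1e9 fill:", err)

# -inf is representable in fp16, so the fill succeeds
masked = scores.masked_fill(mask, -float("inf"))
# cast up only so softmax/print works on older CPU builds; the fill itself is fp16
print(masked.float().softmax(dim=-1))
```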