[ENH] TemporalFusionTransformer - allow mixed precision training #1518

Open
wants to merge 2 commits into main
Conversation

@Marcrb2 commented on Feb 19, 2024

Description

This PR changes the attention mask fill value in the TFT model from 1e-9 to float("inf") to allow PyTorch mixed precision training.

Closes #1325, closes #285
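
For illustration, here is a minimal sketch of the kind of change the description implies, assuming the mask is applied with masked_fill just before the softmax in ScaledDotProductAttention (the module named in the review below). The class is simplified and is not the actual diff; note that for masking ahead of a softmax the fill value carries a negative sign, and a large finite constant such as -1e9 cannot be represented in float16, which is presumably what breaks autocast in the linked issues.

```python
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Simplified scaled dot-product attention (illustrative sketch, not the real module)."""

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        # (batch, query_len, key_len) attention scores, scaled by sqrt(d_k)
        attn = torch.bmm(q, k.permute(0, 2, 1)) / (k.size(-1) ** 0.5)
        if mask is not None:
            # before: attn = attn.masked_fill(mask, -1e9)
            # -inf is representable in float16 and still gives the masked
            # positions zero weight after the softmax
            attn = attn.masked_fill(mask, -float("inf"))
        attn = self.dropout(self.softmax(attn))
        return torch.bmm(attn, v), attn
```

With an infinite fill value, training under Lightning's mixed precision setting (e.g. Trainer(precision="16-mixed") in recent versions) should no longer fail on the mask, which is what #1325 and #285 ask for.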

@codecov-commenter commented on Feb 19, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (b3fcf86) 90.19% compared to head (d93c1e0) 90.19%.
Report is 8 commits behind head on master.


Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1518   +/-   ##
=======================================
  Coverage   90.19%   90.19%           
=======================================
  Files          30       30           
  Lines        4724     4724           
=======================================
  Hits         4261     4261           
  Misses        463      463           
Flag      Coverage Δ
cpu       90.19% <100.00%> (ø)
pytest    90.19% <100.00%> (ø)

Flags with carried forward coverage won't be shown.


@fkiraly added the enhancement (New feature or request) label on Sep 16, 2024
@fkiraly (Collaborator) left a comment

Thanks!

From the API perspective, we need to address the problem that this changes default behaviour, and thus may break or change user code downstream.

What we need to do is expose the constant used in masked_fill as a parameter in ScaledDotProductAttention, all the way up to TemporalFusionTransformer, which means also in InterpretableMultiHeadAttention.

The default must be left at 1e9 for now (to avoid breaking or changing other people's code), but the new parameter will allow you to set float("inf").

The docstrings of the modules and components should also get a description of the new parameter. Where classes have no docstring, it would be great if you could add one, but that is not blocking.
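
A sketch of what exposing the constant could look like; the parameter name mask_fill_value is hypothetical, the modules are heavily abbreviated, and the constant is written with the explicit negative sign it takes before the softmax:

```python
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    # hypothetical keyword; the default preserves current behaviour
    def __init__(self, dropout: float = 0.0, mask_fill_value: float = -1e9):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=2)
        self.mask_fill_value = mask_fill_value

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.permute(0, 2, 1)) / (k.size(-1) ** 0.5)
        if mask is not None:
            attn = attn.masked_fill(mask, self.mask_fill_value)
        return torch.bmm(self.dropout(self.softmax(attn)), v), attn


class InterpretableMultiHeadAttention(nn.Module):
    # the wrapper only has to pass the value through to the inner attention
    def __init__(self, n_head: int, d_model: int, dropout: float = 0.0,
                 mask_fill_value: float = -1e9):
        super().__init__()
        # ... head projection layers omitted for brevity ...
        self.attention = ScaledDotProductAttention(dropout, mask_fill_value=mask_fill_value)


# TemporalFusionTransformer would accept the same keyword and forward it to
# InterpretableMultiHeadAttention, so users could opt in with something like:
#   TemporalFusionTransformer.from_dataset(training, ..., mask_fill_value=-float("inf"))
```

Keeping the default at the current finite constant means existing configurations behave exactly as before, while users who need mixed precision can pass -float("inf") explicitly.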

@fkiraly (Collaborator) commented on Sep 16, 2024

I would also appreciate feedback on which value is better as a default. Even if we cannot change it right now, we could after a couple of releases, giving users forewarning that the default will change.

@fkiraly changed the title from "Allow mixed precision TFT training" to "[ENH] TemporalFusionTransformer - allow mixed precision training" on Sep 16, 2024
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Training TFT with mixed precision
Support for 16bit precision training
4 participants