-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upcast Celu for it to support more dtypes | feat(torchlib) #1158
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1158 +/- ##
==========================================
- Coverage 78.52% 78.23% -0.29%
==========================================
Files 118 118
Lines 15231 15237 +6
Branches 1635 1635
==========================================
- Hits 11960 11921 -39
- Misses 2891 2934 +43
- Partials 380 382 +2 ☔ View full report in Codecov by Sentry. |
Test Results 18 files ± 0 18 suites ±0 1h 5m 14s ⏱️ - 4m 7s For more details on these failures, see this check. Results for commit 41f4d3e. ± Comparison against base commit e75da82. This pull request removes 263 and adds 270 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results. |
@@ -347,6 +350,14 @@ def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: | |||
return op.Celu(self, alpha=alpha) # op.Celu only support float32 | |||
|
|||
|
|||
@torch_op("aten::celu") | |||
def aten_celu_upcasted(self: TFloatUnlessFloat32, alpha: float = 1.0) -> TFloatUnlessFloat32: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: hmm we are also downcasting from double. I know I started with the term "upcast" in my original issue, but maybe we should name it something else to callout the onnx dtype limitation? aten_celu_dtype_hack
? Just an idea and we might need to come up with something more formal if there are more cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updating to type_promoted. wdyt?
Celu only supports float in ONNX. This change adds the variant to support more floating types by manually adding casts in the function.