Implement _log_softmax | feat(torchlib) #1079
Conversation
Could you also add test coverage?
On it
Codecov Report
|          | main   | #1079  | +/-    |
|----------|--------|--------|--------|
| Coverage | 77.92% | 77.91% | -0.02% |
| Files    | 115    | 115    |        |
| Lines    | 14684  | 14706  | +22    |
| Branches | 1558   | 1562   | +4     |
| Hits     | 11443  | 11458  | +15    |
| Misses   | 2872   | 2877   | +5     |
| Partials | 369    | 371    | +2     |
Test Results: 18 files ±0, 18 suites ±0, 1h 13m 42s ⏱️ -9m 42s. For more details on these failures, see this check. Results for commit c33a9c0; comparison against base commit b958002. This pull request removes 519 tests and adds 531 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results.
Added tests
@torch_op("aten::_log_softmax")
def aten__log_softmax(
    self: TFloatHighPrecision, dim: int, half_to_float: bool  # pylint: disable=unused-argument
nit: very minor, should we do the following to be consistent (as seen in aten__softmax_half below)?

del half_to_float  # Unused
This function is not trace-only, so we can't use `del` (it's not supported by onnxscript yet).
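For context, a minimal sketch of what the full function might look like, assuming the ONNX `LogSoftmax` op maps directly to this overload. The import paths and the function body below are assumptions based only on the signature shown in the diff, not the PR's actual implementation:

```python
# Sketch only: the import paths below are assumptions, not taken from the PR diff.
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatHighPrecision


@torch_op("aten::_log_softmax")
def aten__log_softmax(
    self: TFloatHighPrecision, dim: int, half_to_float: bool  # pylint: disable=unused-argument
) -> TFloatHighPrecision:
    """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
    # half_to_float only matters for the fp16 overload; this high-precision
    # overload keeps the argument to match the ATen signature. Because the body
    # is compiled to an ONNX function rather than traced, `del half_to_float`
    # is unavailable, hence the pylint disable on the signature line.
    return op.LogSoftmax(self, axis=dim)
```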
for (shape, dim), half_to_float in itertools.product(cases, (False,)):
    # NOTE: softmax with half to float conversion is not supported on CPU
Could we test bfloat16?
ORT doesn't run bfloat16 on CPU, so we can't test it here.
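As a rough illustration of the test-side change, here is a hedged sketch of a sample-generation loop in the spirit of the hunk above. The function name, helper, and case shapes are hypothetical and not copied from the PR:

```python
# Sketch only: names and shapes here are hypothetical illustrations.
import itertools

import torch


def sample_inputs__log_softmax(device, dtype, requires_grad=False):
    """Yield (input_tensor, args) pairs for aten::_log_softmax."""

    def make_arg(shape):
        return torch.randn(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    # (shape, (dim,)) pairs exercising different ranks and axes.
    cases = (
        ((3,), (0,)),
        ((3, 4), (1,)),
        ((2, 3, 4), (-1,)),
    )

    # NOTE: softmax with half-to-float conversion is not supported on CPU,
    # so only half_to_float=False is generated here.
    for (shape, dim), half_to_float in itertools.product(cases, (False,)):
        yield make_arg(shape), (*dim, half_to_float)
```

Each yielded `(input, args)` pair could then be fed to `torch.ops.aten._log_softmax(input, *args)` to compute the expected values for comparison.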
Fixes #1077