
Upcast Celu for it to support more dtypes | feat(torchlib) #1158

Merged — justinchuby merged 6 commits into main from justinchu/upcast-celu on Nov 21, 2023

Conversation

justinchuby
Collaborator

ONNX Celu only supports float32. This change adds a variant that supports additional floating-point types by manually inserting casts inside the function.


codecov bot commented Nov 16, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (eb73ec2) at 78.52% vs. head (39e7ef1) at 78.23%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1158      +/-   ##
==========================================
- Coverage   78.52%   78.23%   -0.29%     
==========================================
  Files         118      118              
  Lines       15231    15237       +6     
  Branches     1635     1635              
==========================================
- Hits        11960    11921      -39     
- Misses       2891     2934      +43     
- Partials      380      382       +2     


justinchuby added the topic: torch_lib (Related to the torch/aten function lib in development) and merge at lgtm (Reviewers can merge when they approve) labels Nov 16, 2023

github-actions bot commented Nov 16, 2023

Test Results

     18 files ±0     18 suites ±0    1h 5m 14s ⏱️ −4m 7s
 11,194 tests +7:    8,343 ✔️ +5      2,797 💤 +2       54 ±0
160,176 runs +63:   36,743 ✔️ +45   121,384 💤 +18   2,049 ±0

For more details on these failures, see this check.

Results for commit 41f4d3e. ± Comparison against base commit e75da82.

This pull request removes 263 tests and adds 270. Note that renamed tests count towards both.
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_163_aten_cross_entropy_loss
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_164_aten_dropout
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_165_aten_elu
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_166_aten_embedding_bag
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_167_aten_embedding_bag_padding_idx
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_168_aten_embedding_renorm
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_169_aten_embedding
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_170_aten_hardtanh
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_171_aten_leaky_relu
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_172_aten_log_sigmoid
…
onnxscript.function_libs.tools.torch_lib.deduce_type_constraints_test.TestDeduceTypeConstraints ‑ test_deduce_type_constraints_does_not_crash_for_onnx_function_aten_celu_type_promoted
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_163_aten_celu_type_promoted
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_164_aten_cross_entropy_loss
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_165_aten_dropout
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_166_aten_elu
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_167_aten_embedding_bag
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_168_aten_embedding_bag_padding_idx
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_169_aten_embedding_renorm
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_170_aten_embedding
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_171_aten_hardtanh
…

♻️ This comment has been updated with latest results.

@@ -347,6 +350,14 @@ def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
return op.Celu(self, alpha=alpha)  # op.Celu only supports float32


@torch_op("aten::celu")
def aten_celu_upcasted(self: TFloatUnlessFloat32, alpha: float = 1.0) -> TFloatUnlessFloat32:
Contributor

nit: hmm, we are also downcasting from double. I know I started with the term "upcast" in my original issue, but maybe we should name it something else to call out the ONNX dtype limitation? aten_celu_dtype_hack? Just an idea; we might need to come up with something more formal if there are more cases.

Collaborator Author

Updating to type_promoted. wdyt?
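
For reference, a minimal sketch of what the renamed function plausibly looks like. Only the decorator and signature are confirmed by the diff above; the body (upcast to float32, call op.Celu, cast back with op.CastLike) is an assumption inferred from the PR description, and the final type alias may differ from the TFloatUnlessFloat32 shown in the diff:

# Assumes torch_lib's usual context: op, torch_op, FLOAT, and the
# TFloatUnlessFloat32 type alias are imported/defined as in the file above.
@torch_op("aten::celu")
def aten_celu_type_promoted(
    self: TFloatUnlessFloat32, alpha: float = 1.0
) -> TFloatUnlessFloat32:
    # op.Celu only supports float32 in ONNX, so upcast first (assumed body).
    self_upcasted = op.Cast(self, to=FLOAT.dtype)
    # CastLike restores the input's original dtype (e.g. float16 or double).
    return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self)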

justinchuby added the hold on merging (Don't merge yet) label and removed the merge at lgtm (Reviewers can merge when they approve) label Nov 19, 2023
justinchuby removed the hold on merging (Don't merge yet) label Nov 20, 2023
justinchuby added the merge at lgtm (Reviewers can merge when they approve) label Nov 21, 2023
justinchuby merged commit 660b9f4 into main on Nov 21, 2023
26 of 29 checks passed
justinchuby deleted the justinchu/upcast-celu branch on November 21, 2023 at 00:19