
[torchlib] Add missing ops (im2col) #1757

Merged: 6 commits, Aug 5, 2024

Conversation

shubhambhokare1 (Contributor)

No description provided.

@shubhambhokare1 shubhambhokare1 self-assigned this Jul 25, 2024
codecov bot commented Jul 25, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 75.08%. Comparing base (14f88d3) to head (601f5d3).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1757      +/-   ##
==========================================
+ Coverage   75.04%   75.08%   +0.03%     
==========================================
  Files         245      245              
  Lines       26472    26516      +44     
  Branches     4829     4834       +5     
==========================================
+ Hits        19867    19910      +43     
- Misses       5674     5675       +1     
  Partials      931      931              


github-actions bot commented Jul 26, 2024

Test Results

24 files ±0 · 24 suites ±0 · 3h 38m 1s ⏱️ (+23m 7s)
12 110 tests (-2 180): 10 531 ✅ (-1 743), 1 549 💤 (-430), 30 ❌ (-6)
485 464 runs (+25 773): 100 250 ✅ (+2 754), 384 946 💤 (+23 026), 268 ❌ (-6)

For more details on these failures, see this check.

Results for commit 601f5d3. ± Comparison against base commit 14f88d3.

This pull request removes 2278 tests and adds 98. Note that renamed tests count towards both.
onnxscript._internal.analysis_test.TestAssignedVarAnalysis ‑ test_basic_defs
onnxscript._internal.analysis_test.TestAssignedVarAnalysis ‑ test_doc_string
onnxscript._internal.analysis_test.TestAssignedVarAnalysis ‑ test_if_defs
onnxscript._internal.analysis_test.TestAssignedVarAnalysis ‑ test_if_loop_defs
onnxscript._internal.analysis_test.TestAssignedVarAnalysis ‑ test_loop_defs
onnxscript._internal.analysis_test.TestExposedUses ‑ test_basic
onnxscript._internal.analysis_test.TestExposedUses ‑ test_called_function
onnxscript._internal.analysis_test.TestExposedUses ‑ test_doc_string
onnxscript._internal.analysis_test.TestExposedUses ‑ test_for_loop
onnxscript._internal.analysis_test.TestExposedUses ‑ test_if
…
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_317_aten_im2col
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_318_aten_linear
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_319_aten_linear_bias
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_320_aten_max_pool1d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_321_aten_max_pool1d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_322_aten_max_pool2d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_323_aten_max_pool2d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_324_aten_max_pool3d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_325_aten_max_pool3d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_326_aten_scaled_dot_product_attention
…
This pull request removes 475 skipped tests and adds 45 skipped tests. Note that renamed tests count towards both.
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0005_test_adagrad
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0006_test_adagrad_multiple
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0007_test_adam
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0008_test_adam_multiple
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0011_test_add_uint8
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0020_test_ai_onnx_ml_array_feature_extractor
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0021_test_ai_onnx_ml_binarizer
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0022_test_ai_onnx_ml_label_encoder_string_int
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0023_test_ai_onnx_ml_label_encoder_string_int_no_default
onnxscript.backend.onnx_export_test.TestOnnxBackEnd ‑ test_export2python_produces_correct_onnx_script_model_0024_test_ai_onnx_ml_label_encoder_tensor_mapping
…
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_317_aten_im2col
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_320_aten_max_pool1d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_321_aten_max_pool1d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_322_aten_max_pool2d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_323_aten_max_pool2d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_324_aten_max_pool3d
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_325_aten_max_pool3d_with_indices
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_326_aten_scaled_dot_product_attention
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_327_aten__scaled_dot_product_flash_attention
tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_script_function_passes_checker_328_aten__scaled_dot_product_efficient_attention
…


output = op.Gather(padded_input, blocks_row_indices, axis=2)
output = op.Gather(output, blocks_col_indices, axis=4)
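The two Gather calls above extract the row blocks and then the column blocks of each sliding patch. A minimal NumPy sketch of the same access pattern (the shapes, names, and index construction here are illustrative, not the PR's actual code):

```python
import numpy as np

# Illustrative shapes: a single 4x4 image, 2x2 kernel, stride 1.
N, C, H, W = 1, 1, 4, 4
kernel, stride = 2, 1
x = np.arange(N * C * H * W, dtype=np.float32).reshape(N, C, H, W)

n_rows = (H - kernel) // stride + 1  # block positions along H
n_cols = (W - kernel) // stride + 1  # block positions along W

# One row of indices per block position, e.g. [[0,1],[1,2],[2,3]].
blocks_row_indices = np.stack(
    [np.arange(i, i + kernel) for i in range(0, n_rows * stride, stride)]
)
blocks_col_indices = np.stack(
    [np.arange(j, j + kernel) for j in range(0, n_cols * stride, stride)]
)

# Gather(input, row_indices, axis=2) -> (N, C, n_rows, kernel, W)
out = np.take(x, blocks_row_indices, axis=2)
# Gather(out, col_indices, axis=4) -> (N, C, n_rows, kernel, n_cols, kernel)
out = np.take(out, blocks_col_indices, axis=4)
```

After the two gathers, `out[0, 0, i, :, j, :]` holds the patch starting at row `i`, column `j`.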
Collaborator

Is it possible to use Slice, which is faster?

Contributor Author

I think we can use Slice; however, the indices would need to be transformed into starts/ends format, adding extra Reshape and Split nodes.

Collaborator

I see. Then this lgtm. Thanks for explaining!

Collaborator

But the extra operations can be done at export time, is that correct? That is, they depend only on export-time values (torch parameters == onnx attributes), not on run-time values. If so, there is no need to encode them using ONNX ops, since they can be computed in Python. In other words, using Slice should be doable in trace mode without any extra cost?
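If that premise holds, the Slice parameters can be computed in ordinary Python during tracing and baked in as constants. A hedged sketch of the idea, assuming a hypothetical helper name and with NumPy slicing standing in for the traced ONNX Slice ops:

```python
import numpy as np

# Hypothetical helper: kernel_size and stride are export-time attributes,
# so the (start, end) pairs for each block along one spatial axis can be
# computed in plain Python, with no Reshape/Split nodes in the graph.
def im2col_slice_params(size, kernel, stride):
    """Return (start, end) pairs for each block along one axis."""
    return [(i, i + kernel) for i in range(0, size - kernel + 1, stride)]

H, W, kernel, stride = 4, 4, 2, 1
x = np.arange(16, dtype=np.float32).reshape(1, 1, H, W)

# Each patch becomes one constant-parameter Slice at export time.
patches = [
    x[:, :, rs:re, cs:ce]
    for rs, re in im2col_slice_params(H, kernel, stride)
    for cs, ce in im2col_slice_params(W, kernel, stride)
]
# 3 x 3 block positions, each a (1, 1, 2, 2) patch
```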

Collaborator

I think this is worth thinking through.

  • Slice has a very regular access pattern that is easy to optimize and parallelize, while Gather is irregular and random, and harder to optimize and parallelize.
  • The cost of operations on a large input tensor dominates overall cost, not the cost of constant-time operations like Reshape.
  • If we want to extract a million elements, creating the indices of those million elements seems potentially expensive, when the same extraction can be described by a slice pattern with a few elements.
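The last point can be made concrete: fancy indexing materializes one explicit index per selected row and column, while a slice is fully described by a handful of scalars. An illustrative NumPy comparison (sizes are arbitrary):

```python
import numpy as np

# Extracting a 1000x1000 block from a 2000x2000 tensor.
x = np.zeros((2000, 2000), dtype=np.float32)

# Gather-style: 1000 + 1000 = 2000 explicit index values must be built
# and consumed, with an irregular access pattern in general.
rows = np.arange(500, 1500)
cols = np.arange(500, 1500)
gathered = x[np.ix_(rows, cols)]

# Slice-style: the same extraction is described by four scalars
# (starts and ends per axis), with a regular, contiguous access pattern.
sliced = x[500:1500, 500:1500]
```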

Collaborator

But what if the entire model consists of a single op-function call? I wasn't necessarily looking for something visual; just knowing impl1 takes X time and impl2 takes Y time would be fine. The starting point would be a test case for an op like im2col: we run its onnxscript implementation exported to ORT as a model.

Collaborator

@justinchuby justinchuby Aug 2, 2024

I wonder how correlated a tiny benchmark is with the e2e performance? Hopefully closely?

Collaborator

Good question. We will need to avoid overheads (like copying tensors, e.g. due to conversion) and not count session creation (which should be easy). Maybe even warm up. Should be doable.
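A minimal timing harness along those lines might look like the sketch below. NumPy operations stand in for an exported ONNX model so the example stays self-contained; a real harness would time runs of an onnxruntime session instead, with session creation excluded as discussed:

```python
import time
import numpy as np

# Hedged micro-benchmark sketch: warm up first (untimed), keep setup
# outside the timed region, then average over repeated calls.
def bench(fn, *args, warmup=3, iters=10):
    for _ in range(warmup):  # warm-up runs are not timed
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters

x = np.random.rand(256, 256).astype(np.float32)
idx = np.arange(64, 192)  # explicit indices for the gather formulation

t_gather = bench(lambda a: a[np.ix_(idx, idx)], x)   # irregular access
t_slice = bench(lambda a: a[64:192, 64:192].copy(), x)  # regular access
```

The per-call averages `t_gather` and `t_slice` can then be compared directly, answering the "impl1 takes X, impl2 takes Y" question without any visualization.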

Collaborator

Hi @shubhambhokare1: I see this has been merged. I am concerned that the strategy used here might not be good, for the reasons discussed above. Any thoughts about that? Thanks!

Contributor Author

Hi @gramalingam,

Agreed with the point about the case with a large number of elements: creating those indices and using Gather might be inefficient. I think I missed this comment thread pre-merge. Slice might be a better option. I will open a follow-up PR replacing the Gather ops with Slice; in the meantime, models using im2col should be unblocked.

Regarding the second point, it might be a good idea to create a single-op evaluator for kernel performance. I will experiment and add that as part of the new PR.

Collaborator

@justinchuby justinchuby left a comment

nit: blocking on the pad change, otherwise the exporter will fail

Collaborator

@justinchuby justinchuby left a comment

Blocking for dynamic input handling

@shubhambhokare1 shubhambhokare1 merged commit 47ecc6c into main Aug 5, 2024
30 of 43 checks passed
@shubhambhokare1 shubhambhokare1 deleted the sbhokare/torchlib-op-3 branch August 5, 2024 20:14
@justinchuby justinchuby added the topic: torch_lib Related to the torch/aten function lib in development label Aug 5, 2024
Labels
topic: torch_lib Related to the torch/aten function lib in development
3 participants