Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpu: aarch64: injectors: Improve performance of tanh for block size 16 #2094

Merged

Conversation

vishwascm
Copy link
Contributor

Description

Performance Improvement: Eltwise Tanh performance improvement for block size 16

Major Code changes:
• Added a new function tanh_polynomial_approx_compute_vector_fwd(const TRegS &vmm_src) for
computing tanh.
• Added new tanh constants and polynomial constants table.

Checklist

General

All the tests are carried on A64FX machine which has block size 16:

  • [✓] Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  1. make test
95% tests passed, 11 tests failed out of 200
Total Test time (real) = 3790.05 sec
The following tests FAILED:
	 55 - test_convolution_backward_data_f32 (Subprocess aborted)
	123 - test_graph_c_api_compile_parametrized_usm_cpu (Failed)
	153 - test_graph_unit_dnnl_conv_usm_cpu (Failed)
	157 - test_graph_unit_dnnl_group_norm_usm_cpu (Failed)
	159 - test_graph_unit_dnnl_large_partition_usm_cpu (Failed)
	160 - test_graph_unit_dnnl_layer_norm_usm_cpu (Failed)
	161 - test_graph_unit_dnnl_matmul_usm_cpu (Failed)
	162 - test_graph_unit_dnnl_mqa_decomp_usm_cpu (Failed)
	163 - test_graph_unit_dnnl_pool_usm_cpu (Failed)
	168 - test_graph_unit_dnnl_sdp_decomp_usm_cpu (Failed)
	169 - test_graph_unit_dnnl_softmax_usm_cpu (Failed)
Errors while running CTest
Output from these tests are in: /home/vishwas/oneDNN/build/Testing/Temporary/LastTest.log
Use "--rerun-failed --output-on-failure" to re-run the failed cases verbosely.
make: *** [Makefile:71: test] Error 8
  1. ./benchdnn --eltwise --batch=inputs/eltwise/test_eltwise_all
tests:75713 passed:6771 skipped:68907 mistrusted:35 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 125.07s; fill: 29.31s (23%); compute_ref: 15.50s (12%); compare: 27.88s (22%);
  1. make test_benchdnn_*
make: *** No rule to make target 'test_benchdnn_*'.  Stop.

gtests

  1. ./test_eltwise
----------] Global test environment tear-down
[==========] 44 tests from 5 test suites ran. (2752 ms total)
[  PASSED  ] 24 tests.
[  SKIPPED ] 20 tests, listed below:
[  SKIPPED ] Test_Eltwise_EF/eltwise_test_t.TestsEltwise/4
[  SKIPPED ] Test_Eltwise_EF/eltwise_test_t.TestsEltwise/5
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/0
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/1
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/2
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/3
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/4
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/5
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/6
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/7
[  SKIPPED ] EltwiseSimpleBF16/eltwise_test_t.TestsEltwise/8
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/0
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/1
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/2
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/3
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/4
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/5
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/6
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/7
[  SKIPPED ] EltwiseSimpleF16/eltwise_test_t.TestsEltwise/8

Note: All above results are same with and without the code changes.

  • [✓] Have you formatted the code using clang-format? - Yes

Performance improvements

  • [✓ ] Have you submitted performance data that demonstrates performance improvements?
    image
    image

@vishwascm vishwascm requested a review from a team as a code owner September 12, 2024 08:36
@github-actions github-actions bot added the platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64 label Sep 12, 2024
@vpirogov vpirogov added this to the v3.7 milestone Sep 12, 2024
@theComputeKid
Copy link
Contributor

Thanks for this. Can you let me know what cmake command line options you use when building to get these perf results?

@vishwascm
Copy link
Contributor Author

vishwascm commented Sep 16, 2024

Thanks for this. Can you let me know what cmake command line options you use when building to get these perf results?

@theComputeKid following steps were used to get perf results (benchdnn was used):

cd build
cmake ..
make -j
cd tests/benchdnn
taskset -c 1,2,3,4 ./benchdnn --eltwise --mode=p  --alg=gelu_tanh 3x17x2x5x3
taskset -c 1,2,3,4 ./benchdnn --conv --mode=p --attr-post-ops=eltwise_tanh ic3oc64_ih224oh112kh7sh2dh0ph3_iw224ow112kw7sw2dw0pw3

Above test is for 4 Core.

@theComputeKid
Copy link
Contributor

Sorry, I should have clarified, I was interested in the CMake options during the configure phase. Particularly, I wanted to know whether you compile with -DONEDNN_WERROR=ON, as I have found that the JIT codebase for aarch64 produces a lot of warnings that prevent us from turning on the flag. Could you just confirm that no new warnings are added by your changes? Thanks.

@vishwascm
Copy link
Contributor Author

Sorry, I should have clarified, I was interested in the CMake options during the configure phase. Particularly, I wanted to know whether you compile with -DONEDNN_WERROR=ON, as I have found that the JIT codebase for aarch64 produces a lot of warnings that prevent us from turning on the flag. Could you just confirm that no new warnings are added by your changes? Thanks.

@theComputeKid I did not compile with -DONEDNN_WERROR=ON.

@theComputeKid
Copy link
Contributor

Can you still please confirm that no new warnings are added due to your changes? I.e. the number of warnings emitted before and after your changes (if any) are the same or less.

@jondea
Copy link
Contributor

jondea commented Sep 16, 2024

When you say block size 16, do you mean this should only affect machines with an SVE vector length of 512? Block size is a somewhat overloaded term, especially with different data types.

@vishwascm
Copy link
Contributor Author

Can you still please confirm that no new warnings are added due to your changes? I.e. the number of warnings emitted before and after your changes (if any) are the same or less.

@theComputeKid Yeah, there are no new warnings due to the changes.

@vishwascm
Copy link
Contributor Author

vishwascm commented Sep 17, 2024

When you say block size 16, do you mean this should only affect machines with an SVE vector length of 512? Block size is a somewhat overloaded term, especially with different data types.

@jondea Yes, it is for SVE_512, for fp32 datatype.

@abhijain1204fujitsu
Copy link

@jondea Kindly support to check the changes and let us know if some information or changes are required.

@theComputeKid
Copy link
Contributor

Can you please rebase and push to force pipelines to run again. Thanks.

@vishwascm vishwascm force-pushed the aarch64-sve-jit-eltwise_injector-tanh branch from e1e11fa to 9e31edf Compare October 4, 2024 06:07
@vishwascm
Copy link
Contributor Author

Can you please rebase and push to force pipelines to run again. Thanks.

@theComputeKid I have rebased and force pushed the code again. Please check.

Copy link
Contributor

@theComputeKid theComputeKid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AArch64 approved. But might need to check with repo admins about the other CI failures.

@mgouicem
Copy link
Contributor

mgouicem commented Oct 4, 2024

THe only remaining failures I see are:

  • commit message too long
  • clang-format check of src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp fails

@vishwascm vishwascm force-pushed the aarch64-sve-jit-eltwise_injector-tanh branch from 9e31edf to fda21a0 Compare October 4, 2024 09:27
@theComputeKid
Copy link
Contributor

@vishwascm Might also need to shorten the PR title:

Run python3 ./.github/automation/pr-title-check.py "src: cpu: aarch64: injectors: eltwise: Improve performance of tanh for block size 16"
msg: src: cpu: aarch64: injectors: eltwise: Improve performance of tanh for block size 16
Traceback (most recent call last):
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 73, in <module>
    main()
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 68, in main
    __numCharacterCheck(msg)
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 57, in __numCharacterCheck
    raise ValueError(
ValueError: Please see contribution guidelines. Message summary must be less than 72. Got: 84

https://github.com/oneapi-src/oneDNN/actions/runs/11177487422/job/31073064297?pr=2094

@vishwascm vishwascm changed the title src: cpu: aarch64: injectors: eltwise: Improve performance of tanh for block size 16 cpu: aarch64: injectors: Improve performance of tanh for block size 16 Oct 4, 2024
@vishwascm
Copy link
Contributor Author

@vishwascm Might also need to shorten the PR title:

Run python3 ./.github/automation/pr-title-check.py "src: cpu: aarch64: injectors: eltwise: Improve performance of tanh for block size 16"
msg: src: cpu: aarch64: injectors: eltwise: Improve performance of tanh for block size 16
Traceback (most recent call last):
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 73, in <module>
    main()
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 68, in main
    __numCharacterCheck(msg)
  File "/home/runner/work/oneDNN/oneDNN/./.github/automation/pr-title-check.py", line 57, in __numCharacterCheck
    raise ValueError(
ValueError: Please see contribution guidelines. Message summary must be less than 72. Got: 84

https://github.com/oneapi-src/oneDNN/actions/runs/11177487422/job/31073064297?pr=2094

Done

@spalicki spalicki merged commit 3d1e89a into oneapi-src:main Oct 4, 2024
23 of 25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants