Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement PReLU backward #3152

Merged
merged 79 commits into from
Aug 21, 2024
Merged

Implement PReLU backward #3152

merged 79 commits into from
Aug 21, 2024

Conversation

long10024070
Copy link
Collaborator

  • Added PReLU backward operation and kernels.
  • Added driver test and gtest for PReLU backward operation.
  • New API is guarded by MIOPEN_BETA_API macro.
  • Compared to ROCm pytorch:
float16
op_name dtype size num_param model direction ROCm pytorch MIOpen HIP Improvement
PReLU float16 [512 64 112 112] 1 arcface bwd 99312671 12277100 8.11
PReLU float16 [512 64 56 56] 1 arcface bwd 26715752 3073780 8.79
PReLU float16 [512 128 56 56] 1 arcface bwd 52103300 6140360 8.53
PReLU float16 [512 128 28 28] 1 arcface bwd 13279513 1545730 8.73
PReLU float16 [512 256 28 28] 1 arcface bwd 26771242 3072230 8.79
PReLU float16 [512 256 14 14] 1 arcface bwd 6605507 786651 8.68
PReLU float16 [512 512 14 14] 1 arcface bwd 13283229 1550010 8.71
PReLU float16 [512 512 7 7] 1 arcface bwd 3289272 403211 8.72
PReLU float16 [512 64 112 112] 64 arcface bwd 107275183 46587900 2.31
PReLU float16 [512 64 56 56] 64 arcface bwd 30255041 15273200 2.00
PReLU float16 [512 128 56 56] 128 arcface bwd 56547437 13889100 4.09
PReLU float16 [512 128 28 28] 128 arcface bwd 15626705 3868870 4.10
PReLU float16 [512 256 28 28] 256 arcface bwd 30232264 5028100 6.07
PReLU float16 [512 256 14 14] 256 arcface bwd 7826620 1531590 5.26
PReLU float16 [512 512 14 14] 512 arcface bwd 15628058 2114630 7.50
PReLU float16 [512 512 7 7] 512 arcface bwd 3877272 682737 6.01
float32
op_name dtype size num_param model direction ROCm pytorch MIOpen HIP Improvement
PReLU float32 [512 64 112 112] 1 arcface bwd 103155863 12389000 8.35
PReLU float32 [512 64 56 56] 1 arcface bwd 27476400 3102070 8.95
PReLU float32 [512 128 56 56] 1 arcface bwd 52891035 6194870 8.58
PReLU float32 [512 128 28 28] 1 arcface bwd 14441185 1560250 9.43
PReLU float32 [512 256 28 28] 1 arcface bwd 27334300 3102970 8.90
PReLU float32 [512 256 14 14] 1 arcface bwd 7261653 791718 9.51
PReLU float32 [512 512 14 14] 1 arcface bwd 14419330 1559770 9.42
PReLU float32 [512 512 7 7] 1 arcface bwd 3608816 406560 9.54
PReLU float32 [512 64 112 112] 64 arcface bwd 109970712 46718600 2.36
PReLU float32 [512 64 56 56] 64 arcface bwd 30409798 13333300 2.30
PReLU float32 [512 128 56 56] 128 arcface bwd 57683256 13940500 4.16
PReLU float32 [512 128 28 28] 128 arcface bwd 16256649 3878210 4.27
PReLU float32 [512 256 28 28] 256 arcface bwd 30886236 5124050 6.09
PReLU float32 [512 256 14 14] 256 arcface bwd 8320915 1522140 5.66
PReLU float32 [512 512 14 14] 512 arcface bwd 16164181 2137560 7.70
PReLU float32 [512 512 7 7] 512 arcface bwd 4095063 685281 6.41
bfloat16
op_name dtype size num_param model direction ROCm pytorch MIOpen HIP Improvement
PReLU bfloat16 [512 64 112 112] 1 arcface bwd 99413514 12540100 7.95
PReLU bfloat16 [512 64 56 56] 1 arcface bwd 26839289 3139650 8.63
PReLU bfloat16 [512 128 56 56] 1 arcface bwd 52355515 6275980 8.38
PReLU bfloat16 [512 128 28 28] 1 arcface bwd 13353346 1579090 8.60
PReLU bfloat16 [512 256 28 28] 1 arcface bwd 26856779 3140230 8.64
PReLU bfloat16 [512 256 14 14] 1 arcface bwd 6645672 801566 8.56
PReLU bfloat16 [512 512 14 14] 1 arcface bwd 13434983 1577450 8.66
PReLU bfloat16 [512 512 7 7] 1 arcface bwd 3310021 410969 8.57
PReLU bfloat16 [512 64 112 112] 64 arcface bwd 106599289 46867700 2.28
PReLU bfloat16 [512 64 56 56] 64 arcface bwd 30100583 11790800 2.57
PReLU bfloat16 [512 128 56 56] 128 arcface bwd 55829383 14050900 3.99
PReLU bfloat16 [512 128 28 28] 128 arcface bwd 15548941 3913070 4.03
PReLU bfloat16 [512 256 28 28] 256 arcface bwd 29955610 5126710 5.89
PReLU bfloat16 [512 256 14 14] 256 arcface bwd 7793750 1546550 5.17
PReLU bfloat16 [512 512 14 14] 512 arcface bwd 15556630 2171380 7.27
PReLU bfloat16 [512 512 7 7] 512 arcface bwd 3862986 692056 5.91
  • Average over all cases:
type average
float16 7.13
float32 7.83
bfloat16 7.32

@long10024070
Copy link
Collaborator Author

@junliume Github action has passed. @CAHEK7 Would you review this PR?

CAHEK7
CAHEK7 previously requested changes Aug 12, 2024
driver/prelu_driver.hpp Outdated Show resolved Hide resolved
driver/prelu_driver.hpp Outdated Show resolved Hide resolved
driver/prelu_driver.hpp Outdated Show resolved Hide resolved
driver/prelu_driver.hpp Outdated Show resolved Hide resolved
src/kernels/MIOpenPReLU.cpp Outdated Show resolved Hide resolved
src/include/miopen/prelu/problem_description.hpp Outdated Show resolved Hide resolved
src/solver/prelu/backward_prelu_single_weight.cpp Outdated Show resolved Hide resolved
src/solver/prelu/backward_prelu_multi_weights.cpp Outdated Show resolved Hide resolved
test/gtest/prelu.cpp Show resolved Hide resolved
src/kernels/MIOpenReduceSum.cpp Outdated Show resolved Hide resolved
src/kernels/MIOpenReduceSum.cpp Outdated Show resolved Hide resolved
@junliume junliume dismissed CAHEK7’s stale review August 16, 2024 17:31

Re-request review

@junliume junliume merged commit 6feaec9 into develop Aug 21, 2024
140 of 143 checks passed
@junliume junliume deleted the impl_PReLU branch August 21, 2024 06:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants