Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement PReLU backward #3152

Merged
merged 79 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
f77ec05
added PReLU api skeleton code
long10024070 Jun 4, 2024
101c4ad
impl PReLU with GTest
long10024070 Jun 6, 2024
120bacc
added Driver for PReLU
long10024070 Jun 7, 2024
b3d7d2d
make PReLU benchmark more stable
long10024070 Jun 7, 2024
c37b08a
relocated prelu.hpp
long10024070 Jun 12, 2024
4d4c8f6
removed uneccesary #include
long10024070 Jun 12, 2024
c4137da
added more requirement to IsApplicable and fixed miopen doc comment
long10024070 Jun 12, 2024
5a10bdc
improved perf by balancing computation to threads
long10024070 Jun 12, 2024
86c198b
used double instead of FLOAT_ACCUM for better precision
long10024070 Jun 12, 2024
21148b8
potential bug fixed: using float instead of T datatype for workspace …
long10024070 Jun 12, 2024
19e13f8
changed ulong to uint64_t
long10024070 Jun 13, 2024
f466fdb
changed \n to std::endl
long10024070 Jun 13, 2024
a699069
used static_cast instead of Type casting
long10024070 Jun 13, 2024
70716d2
Seperated backward_prelu for each solvers
long10024070 Jun 13, 2024
ff851a7
added contiguous check for weight and dweight tensor
long10024070 Jun 13, 2024
582c1a7
added static_cast<>
long10024070 Jun 13, 2024
5a5bead
merge brach into impl_PReLU
long10024070 Jul 23, 2024
4baade9
up-to-date with branch
long10024070 Jul 23, 2024
7a96642
update tensor_view
long10024070 Jul 23, 2024
fb7cbea
added more data to NetworkConfig str
long10024070 Jul 24, 2024
317e9e8
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Jul 24, 2024
eef821a
add MIOPEN_INTERNALS_EXPORT
long10024070 Jul 24, 2024
a5860f5
using int32_t and int64_t in kernels
long10024070 Jul 24, 2024
7da8549
use half_float::half instead of half in GTtest
long10024070 Jul 24, 2024
53cca97
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Jul 25, 2024
f4c783a
update GTest runner
long10024070 Jul 25, 2024
67cee9e
try remove unused code in Gtest PReLU
long10024070 Jul 25, 2024
61cdc9a
rollback state
long10024070 Jul 25, 2024
639af57
try remove unused code in Gtest PReLU
long10024070 Jul 25, 2024
32b5ce8
try remove unused code in Gtest PReLU
long10024070 Jul 25, 2024
01b4a61
rollback state
long10024070 Jul 25, 2024
0e0fab0
shorten check float argument
long10024070 Jul 25, 2024
5a58dcb
try remove unused code in Gtest PReLU
long10024070 Jul 25, 2024
7343d1f
try remove unused code in Gtest PReLU
long10024070 Jul 25, 2024
e6cc24e
update prelu_driver
long10024070 Jul 26, 2024
8f60362
update prelu test check, using ASSERT_EQ and EXPECT_LT instead of EXP…
long10024070 Jul 26, 2024
0b50c3c
update prelu test to new naming scheme
long10024070 Jul 26, 2024
e0a685b
use MultiBufferWorkspaceTraits to handle workspace memory
long10024070 Jul 30, 2024
07de0bb
use miopenFloat instead of float
long10024070 Jul 30, 2024
427125e
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Jul 31, 2024
59f19b7
fix error
long10024070 Jul 31, 2024
a8748db
remove MakeForwardNetworkConfig
long10024070 Jul 31, 2024
d5e2332
remove unused code
long10024070 Jul 31, 2024
24748ec
remove unused code
long10024070 Aug 1, 2024
c98530b
remove #define FLOAT_ACCUM
long10024070 Aug 1, 2024
4972e0d
reduce number and size of testcases
long10024070 Aug 1, 2024
6f6c4fd
add full test to FullTestConfig
long10024070 Aug 1, 2024
8595889
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Aug 1, 2024
bf7332c
fix error
long10024070 Aug 2, 2024
cf2007b
fix 'Window Build' error, previous commit is 'fix Jenkins - Fp32 Hip …
long10024070 Aug 2, 2024
e9e60e4
update get warpSize automatically from context
long10024070 Aug 5, 2024
3778448
Merge branch 'develop' into impl_PReLU
junliume Aug 5, 2024
3f06347
update doxygen
long10024070 Aug 5, 2024
417cdf3
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Aug 6, 2024
e588d2d
update doxygen
long10024070 Aug 6, 2024
1f2f385
remove unused code
long10024070 Aug 7, 2024
b7f4ed6
update doxygen
long10024070 Aug 7, 2024
d848252
change prelu/solvers header included location
long10024070 Aug 7, 2024
41469fe
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Aug 7, 2024
05791df
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Aug 9, 2024
86a1c94
Merge branch 'develop' into impl_PReLU
junliume Aug 9, 2024
f5f7cfd
fix clang format issue
junliume Aug 9, 2024
2dab3e6
Merge branch 'develop' into impl_PReLU
long10024070 Aug 11, 2024
d493d1e
Merge branch 'develop' into impl_PReLU
long10024070 Aug 12, 2024
6aa0432
return miopenStatusNotImplemented instead of miopenStatusSuccess in u…
long10024070 Aug 13, 2024
f24a302
correct Forward and Fwd to Backward and Bwd
long10024070 Aug 13, 2024
26b5d8f
fuse PReLUBackward kernels (Single weight and Multiple weight)
long10024070 Aug 13, 2024
fad7f3f
simplify kernels profiling
long10024070 Aug 13, 2024
6c8fbf7
remove excessive input tensor ndim check
long10024070 Aug 13, 2024
7789811
solve problem with non-packed weight and non-packed dweight
long10024070 Aug 13, 2024
8e611f5
Merge branch 'develop' into impl_PReLU
long10024070 Aug 13, 2024
1dfadfa
Merge branch 'develop' into impl_PReLU
long10024070 Aug 14, 2024
69d43db
diversify block_reduce and move block/warp reduce to seperated functi…
long10024070 Aug 14, 2024
fdd2933
make block/warp reduce can be used for more pupose, not just sum redu…
long10024070 Aug 15, 2024
ab4086b
Merge branch 'impl_PReLU' of https://github.com/ROCm/MIOpen into impl…
long10024070 Aug 15, 2024
dfe865b
bug fixxed: non-contiguous weight gradient tensor
long10024070 Aug 15, 2024
65e531b
add more smoke testcases
long10024070 Aug 16, 2024
d95d58a
Merge remote-tracking branch 'rocm/develop' into impl_PReLU
long10024070 Aug 16, 2024
5516f45
Merge branch 'develop' into impl_PReLU
long10024070 Aug 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ The MIOpen API library is structured as follows:
* :doc:`Getitem <../doxygen/html/group__getitem>` (experimental)
* :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
* :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental)
* :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental)
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ add_executable(MIOpenDriver
dm_layernorm.cpp
dm_lrn.cpp
dm_pool.cpp
dm_prelu.cpp
dm_reduce.cpp
dm_reduceextreme.cpp
dm_reducecalculation.cpp
Expand Down
40 changes: 40 additions & 0 deletions driver/dm_prelu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "registry_driver_maker.hpp"
#include "prelu_driver.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "prelu")
return new PReLUDriver<float, float>();
if(base_arg == "prelufp16")
return new PReLUDriver<float16, float>();
if(base_arg == "prelubfp16")
return new PReLUDriver<bfloat16, float>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
6 changes: 4 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], "
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16]\n");
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
"prelu[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -207,7 +208,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "getitem" &&
arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" &&
arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" &&
arg != "ropefp16" && arg != "ropebfp16" && arg != "--version")
arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" &&
arg != "prelubfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
104 changes: 104 additions & 0 deletions driver/mloPReLUHost.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#pragma once

#include <../test/ford.hpp>

#include <miopen/tensor.hpp>
#include <miopen/tensor_view_utils.hpp>
#include <miopen/prelu/utils.hpp>

template <typename Tgpu, typename Tcheck>
int32_t mloPReLUBackwardRunHost(const miopenTensorDescriptor_t inputDesc,
const miopenTensorDescriptor_t weightDesc,
const miopenTensorDescriptor_t doutputDesc,
const miopenTensorDescriptor_t dinputDesc,
const Tgpu* input,
const Tgpu* weight,
const Tgpu* doutput,
Tcheck* dinput_host,
Tcheck* dweight_host)
{
auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc));
auto doutput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(doutputDesc));
auto dinput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(dinputDesc));

auto input_sz = miopen::deref(inputDesc).GetElementSize();
auto weight_sz = miopen::deref(weightDesc).GetElementSize();
auto weight_grad_collector = std::vector<float>(input_sz);

par_ford(input_sz)([&](int gid) {
auto tensor_layout = tensor_layout_t<5>(input_tv, gid);
float input_v = static_cast<float>(input[input_tv.get_tensor_view_idx(tensor_layout)]);
float grad_v = static_cast<float>(doutput[doutput_tv.get_tensor_view_idx(tensor_layout)]);

if(dinput_host)
{
float weight_v;
if(weight_sz == 1)
weight_v = static_cast<float>(weight[0]);
else
weight_v = static_cast<float>(weight[tensor_layout.layout[1]]);
float input_grad_v = input_v > 0 ? grad_v : weight_v * grad_v;
dinput_host[dinput_tv.get_tensor_view_idx(tensor_layout)] =
static_cast<Tcheck>(input_grad_v);
}
if(dweight_host)
{
weight_grad_collector[gid] = input_v > 0 ? 0 : input_v * grad_v;
}
});

if(dweight_host)
{
if(weight_sz == 1)
{
double sum = 0;
for(int i = 0; i < input_sz; ++i)
sum += static_cast<double>(weight_grad_collector[i]);
dweight_host[0] = static_cast<Tcheck>(sum);
}
else
{
size_t inner_size = std::accumulate(
&input_tv.size[2], &input_tv.size[4], 1ul, std::multiplies<size_t>());
size_t outer_size = inner_size * input_tv.size[1];
par_ford(weight_sz)([&](int i) {
double sum = 0;
ford(input_tv.size[0])([&](int j) {
ford(inner_size)([&](int k) {
sum += static_cast<double>(
weight_grad_collector[j * outer_size + i * inner_size + k]);
});
});
dweight_host[i] = static_cast<Tcheck>(sum);
});
}
}

return miopenStatusSuccess;
}
Loading
Loading