diff --git a/docs/reference/index.rst b/docs/reference/index.rst
index c2b74eabee..831d41595a 100644
--- a/docs/reference/index.rst
+++ b/docs/reference/index.rst
@@ -33,6 +33,7 @@ The MIOpen API library is structured as follows:
    * :doc:`Cat <../doxygen/html/group__cat>` (experimental)
    * :doc:`SGD <../doxygen/html/group___s_g_d>` (experimental)
    * :doc:`ReduceExtreme <../doxygen/html/group__ReduceExtreme>` (experimental)
+   * :doc:`Fold <../doxygen/html/group___f_o_l_d>` (experimental)
    * :doc:`Getitem <../doxygen/html/group__getitem>` (experimental)
    * :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
    * :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental)
diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt
index 60d6fe6ce6..af58ebc4f3 100644
--- a/driver/CMakeLists.txt
+++ b/driver/CMakeLists.txt
@@ -62,6 +62,8 @@ add_executable(MIOpenDriver
     dm_t5layernorm.cpp
     dm_tensorop.cpp
     dm_transformers_adam_w.cpp
+    dm_fold.cpp
+    dm_unfold.cpp
     main.cpp
     registry_driver_maker.cpp
     rocrand_wrapper.cpp)
diff --git a/driver/dm_fold.cpp b/driver/dm_fold.cpp
new file mode 100644
index 0000000000..d7a8e2cb9a
--- /dev/null
+++ b/driver/dm_fold.cpp
@@ -0,0 +1,39 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include "registry_driver_maker.hpp"
+#include "fold_driver.hpp"
+static Driver* makeDriver(const std::string& base_arg)
+{
+    if(base_arg == "fold")
+        return new FoldDriver<float, float>();
+    if(base_arg == "foldfp16")
+        return new FoldDriver<float16, float>();
+    if(base_arg == "foldbfp16")
+        return new FoldDriver<bfloat16, float>();
+    return nullptr;
+}
+
+REGISTER_DRIVER_MAKER(makeDriver);
diff --git a/driver/dm_unfold.cpp b/driver/dm_unfold.cpp
new file mode 100644
index 0000000000..3d7ed56a91
--- /dev/null
+++ b/driver/dm_unfold.cpp
@@ -0,0 +1,39 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include "registry_driver_maker.hpp"
+#include "unfold_driver.hpp"
+static Driver* makeDriver(const std::string& base_arg)
+{
+    if(base_arg == "unfold")
+        return new UnfoldDriver<float, float>();
+    if(base_arg == "unfoldfp16")
+        return new UnfoldDriver<float16, float>();
+    if(base_arg == "unfoldbfp16")
+        return new UnfoldDriver<bfloat16, float>();
+    return nullptr;
+}
+
+REGISTER_DRIVER_MAKER(makeDriver);
diff --git a/driver/driver.hpp b/driver/driver.hpp
index d77d5d02d2..d5170b2c83 100644
--- a/driver/driver.hpp
+++ b/driver/driver.hpp
@@ -314,7 +314,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
         "adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
         "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
         "prelu[bfp16|fp16], kthvalue[bfp16|fp16], glu[bfp16|fp16], softmarginloss[bfp16|fp16], "
-        "multimarginloss[bfp16|fp16]\n");
+        "multimarginloss[bfp16|fp16], unfold[bfp16|fp16], "
+        "fold[bfp16|fp16]\n");
     exit(0); // NOLINT (concurrency-mt-unsafe)
 }
 
@@ -352,7 +353,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
         arg != "kthvaluebfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" &&
         arg != "softmarginloss" && arg != "softmarginlossfp16" && arg != "softmarginlossbfp16" &&
         arg != "multimarginloss" && arg != "multimarginlossfp16" && arg != "multimarginlossbfp16" &&
-        arg != "--version")
+        arg != "unfold" && arg != "unfoldfp16" && arg != "unfoldbfp16" && arg != "fold" &&
+        arg != "foldfp16" && arg != "foldbfp16" && arg != "--version")
     {
         printf("FAILED: Invalid Base Input Argument\n");
         Usage();
diff --git a/driver/fold_driver.hpp b/driver/fold_driver.hpp
new file mode 100644
index 0000000000..7648e26447
--- /dev/null
+++ b/driver/fold_driver.hpp
@@ -0,0 +1,412 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#pragma once
+
+#include "InputFlags.hpp"
+#include "driver.hpp"
+#include "mloUnfoldHost.hpp"
+#include "random.hpp"
+#include "tensor_driver.hpp"
+#include "timer.hpp"
+
+#include <../test/tensor_holder.hpp>
+#include <../test/verify.hpp>
+
+#include <cstdlib>
+#include <memory>
+#include <miopen/miopen.h>
+#include <miopen/tensor.hpp>
+#include <vector>
+
+template <typename Tgpu, typename Tref>
+class FoldDriver : public Driver
+{
+public:
+    FoldDriver() : Driver()
+    {
+        miopenCreateTensorDescriptor(&inputDesc);
+        miopenCreateTensorDescriptor(&outputDesc);
+        miopenCreateTensorDescriptor(&dinputDesc);
+        miopenCreateTensorDescriptor(&doutputDesc);
+
+        data_type = miopen_type<Tgpu>{};
+    }
+
+    int AddCmdLineArgs() override;
+    int ParseCmdLineArgs(int argc, char* argv[]) override;
+    InputFlags& GetInputFlags() override { return inflags; }
+
+    int GetandSetData() override;
+
+    int AllocateBuffersAndCopy() override;
+
+    int RunForwardGPU() override;
+    int RunForwardCPU();
+
+    int RunBackwardGPU() override;
+    int RunBackwardCPU();
+
+    Tref GetTolerance();
+    int VerifyBackward() override;
+    int VerifyForward() override;
+    ~FoldDriver() override
+    {
+        miopenDestroyTensorDescriptor(inputDesc);
+        miopenDestroyTensorDescriptor(outputDesc);
+        miopenDestroyTensorDescriptor(dinputDesc);
+        miopenDestroyTensorDescriptor(doutputDesc);
+    }
+
+private:
+    InputFlags inflags;
+
+    miopenTensorDescriptor_t inputDesc;
+    miopenTensorDescriptor_t outputDesc;
+
+    miopenTensorDescriptor_t doutputDesc;
+    miopenTensorDescriptor_t dinputDesc;
+
+    std::unique_ptr<GPUMem> input_dev;
+    std::unique_ptr<GPUMem> output_dev;
+
+    std::unique_ptr<GPUMem> doutput_dev;
+    std::unique_ptr<GPUMem> dinput_dev;
+
+    std::vector<Tgpu> input;
+    std::vector<Tgpu> output;
+
+    std::vector<Tgpu> doutput;
+    std::vector<Tgpu> dinput;
+
+    std::vector<Tref> output_host;
+    std::vector<Tref> dinput_host;
+
+    std::vector<uint64_t> output_size;
+    std::vector<uint64_t> kernel_size;
+    std::vector<uint64_t> stride;
+    std::vector<uint64_t> padding;
+    std::vector<uint64_t> dilation;
+};
+
+template <typename Tgpu, typename Tref>
+int FoldDriver<Tgpu, Tref>::ParseCmdLineArgs(int argc, char* argv[])
+{
+    inflags.Parse(argc, argv);
+
+    if(inflags.GetValueInt("time") == 1)
+    {
+        miopenEnableProfiling(GetHandle(), true);
+    }
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+int FoldDriver<Tgpu, Tref>::GetandSetData()
+{
+    std::vector<int> input_length    = inflags.GetValueTensor("DimLengths").lengths;
+    std::vector<int> output_size_int = inflags.GetValueTensor("outputSize").lengths;
+    output_size                      = {output_size_int.begin(), output_size_int.end()};
+    std::vector<int> kernel_size_int = inflags.GetValueTensor("kernelSize").lengths;
+    kernel_size                      = {kernel_size_int.begin(), kernel_size_int.end()};
+    std::vector<int> stride_int      = inflags.GetValueTensor("stride").lengths;
+    stride                           = {stride_int.begin(), stride_int.end()};
+    std::vector<int> padding_int     = inflags.GetValueTensor("padding").lengths;
+    padding                          = {padding_int.begin(), padding_int.end()};
+    std::vector<int> dilation_int    = inflags.GetValueTensor("dilation").lengths;
+    dilation                         = {dilation_int.begin(), dilation_int.end()};
+
+    uint64_t N = input_length[0];
+    uint64_t C = input_length[1];
+    for(uint64_t i : kernel_size)
+    {
+        C = C / i;
+    }
+
+    std::vector<int> output_length = {static_cast<int>(N),
+                                      static_cast<int>(C),
+                                      static_cast<int>(output_size[0]),
+                                      static_cast<int>(output_size[1])};
+    if(SetTensorNd(inputDesc, input_length, data_type) != miopenStatusSuccess)
+        MIOPEN_THROW("Error parsing input tensor: " + inflags.GetValueStr("input_dims") + ".");
+    if(SetTensorNd(outputDesc, output_length, data_type) != miopenStatusSuccess)
+        MIOPEN_THROW("Error parsing output tensor: " + inflags.GetValueStr("output_dims") + ".");
+    if(SetTensorNd(doutputDesc, output_length, data_type) != miopenStatusSuccess)
+        MIOPEN_THROW("Error parsing output grad tensor: " + inflags.GetValueStr("output_dims") +
+                     ".");
+    if(SetTensorNd(dinputDesc, input_length, data_type) != miopenStatusSuccess)
+        MIOPEN_THROW("Error parsing input grad tensor: " + inflags.GetValueStr("input_dims") + ".");
+
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+int FoldDriver<Tgpu, Tref>::AddCmdLineArgs()
+{
+    inflags.AddInputFlag(
+        "forw", 'F', "1", "Run Fold Forward (Default=1) or both Forward and Backward (0)", "int");
+    inflags.AddTensorFlag("DimLengths",
+                          'D',
+                          "3x12x12",
+                          "The dimensional lengths of the input tensor (Default=3x12x12)");
+    inflags.AddTensorFlag("outputSize", 'o', "4x5", "Output Size (Default=4x5)");
+    inflags.AddTensorFlag("kernelSize", 'k', "2x2", "Kernel Size (Default=2x2)");
+    inflags.AddTensorFlag("stride", 's', "1x1", "Stride (Default=1x1)");
+    inflags.AddTensorFlag("padding", 'p', "0x0", "Padding (Default=0x0)");
+    inflags.AddTensorFlag("dilation", 'd', "1x1", "Dilation (Default=1x1)");
+    inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int");
+    inflags.AddInputFlag("verify", 'V', "0", "Verify Each Layer (Default=0)", "int");
+    inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int");
+    inflags.AddInputFlag(
+        "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int");
+
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+int FoldDriver<Tgpu, Tref>::AllocateBuffersAndCopy()
+{
+    size_t input_sz  = GetTensorSize(inputDesc);
+    size_t output_sz = GetTensorSize(outputDesc);
+
+    size_t doutput_sz = GetTensorSize(doutputDesc);
+    size_t dinput_sz  = GetTensorSize(dinputDesc);
+
+    uint32_t ctx = 0;
+
+    input_dev  = std::unique_ptr<GPUMem>(new GPUMem(ctx, input_sz, sizeof(Tgpu)));
+    output_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, output_sz, sizeof(Tgpu)));
+
+    doutput_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, doutput_sz, sizeof(Tgpu)));
+    dinput_dev  = std::unique_ptr<GPUMem>(new GPUMem(ctx, dinput_sz, sizeof(Tgpu)));
+
+    input  = std::vector<Tgpu>(input_sz, static_cast<Tgpu>(0.0f));
+    output = std::vector<Tgpu>(output_sz, static_cast<Tgpu>(0.0f));
+
+    doutput = std::vector<Tgpu>(doutput_sz, static_cast<Tgpu>(1.0f));
+    dinput  = std::vector<Tgpu>(dinput_sz, static_cast<Tgpu>(0.0f));
+
+    output_host = std::vector<Tref>(output_sz, static_cast<Tref>(0.0f));
+
+    dinput_host = std::vector<Tref>(dinput_sz, static_cast<Tref>(0.0f));
+
+    int status;
+
+    for(int i = 0; i < input_sz; i++)
+        input[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
+    status = input_dev->ToGPU(GetStream(),
input.data()); + + for(int i = 0; i < doutput_sz; i++) + { + doutput[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + status |= doutput_dev->ToGPU(GetStream(), doutput.data()); + status |= dinput_dev->ToGPU(GetStream(), dinput.data()); + + if(status != 0) + { + std::cout << "Error copying data to GPU\n" << std::endl; + return miopenStatusAllocFailed; + } + return miopenStatusSuccess; +} + +template +int FoldDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenFoldForward(GetHandle(), + inputDesc, + input_dev->GetMem(), + outputDesc, + output_dev->GetMem(), + kernel_size.data(), + kernel_size.size(), + stride.data(), + stride.size(), + padding.data(), + padding.size(), + dilation.data(), + dilation.size()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Fold Forward Elapsed: " << t.gettime_ms() / iter << " ms" + << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Fold Forward Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(output_dev->FromGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << output_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int FoldDriver::RunForwardCPU() +{ + mloUnFoldBwd4DRunHost(output_host.data(), + outputDesc, + input.data(), + inputDesc, + kernel_size, + stride, + padding, + dilation); + return miopenStatusSuccess; +} + +template +int FoldDriver::RunBackwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenFoldBackward(GetHandle(), + dinputDesc, + dinput_dev->GetMem(), + doutputDesc, + doutput_dev->GetMem(), + kernel_size.data(), + kernel_size.size(), + stride.data(), + stride.size(), + padding.data(), + padding.size(), + dilation.data(), + dilation.size()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Fold Backward Elapsed: " << t.gettime_ms() / iter << " ms" + << std::endl; + + float kernel_average_time = + iter > 1 ? 
(kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Fold Backward Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(dinput_dev->FromGPU(GetStream(), dinput.data()) != 0) + std::cerr << "Error copying (dinput_dev) from GPU, size: " << dinput_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int FoldDriver::RunBackwardCPU() +{ + mloUnFoldFwd4DRunHost(doutput.data(), + doutputDesc, + dinput_host.data(), + dinputDesc, + kernel_size, + stride, + padding, + dilation); + return miopenStatusSuccess; +} + +template +Tref FoldDriver::GetTolerance() +{ + Tref tolerance = std::numeric_limits::epsilon() * 10; + return tolerance; +} + +template +int FoldDriver::VerifyForward() +{ + RunForwardCPU(); + const Tref tolerance = GetTolerance(); + auto error_output = miopen::rms_range(output_host, output); + + if(!std::isfinite(error_output) || error_output > tolerance) + { + std::cout << "Forward Fold FAILED: {" << error_output << "} > " << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward Fold Verifies OK on CPU reference ({" << error_output << "} < " + << tolerance << ')' << std::endl; + } + return miopenStatusSuccess; +} + +template +int FoldDriver::VerifyBackward() +{ + RunBackwardCPU(); + const Tref tolerance = GetTolerance(); + auto error_dinput = miopen::rms_range(dinput_host, dinput); + + if(!std::isfinite(error_dinput) || error_dinput > tolerance) + { + std::cout << "Backward Fold FAILED: {" << error_dinput << "} > " << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Backward Fold Verifies OK on CPU reference ({" << error_dinput << "} < " + << tolerance << ')' << std::endl; + } + return miopenStatusSuccess; +} diff --git a/driver/mloUnfoldHost.hpp b/driver/mloUnfoldHost.hpp new file mode 100644 index 0000000000..6a2ba7e715 --- /dev/null +++ b/driver/mloUnfoldHost.hpp @@ -0,0 +1,195 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include <../test/ford.hpp> +#include "tensor_view.hpp" +#include +#include +#include + +template +int32_t mloUnFoldFwd4DRunHost(Tgpu* input, + const miopenTensorDescriptor_t inputDesc, + Tcheck* ref_output, + const miopenTensorDescriptor_t ref_outputDesc, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, + const std::vector dilation) +{ + auto input_tv = miopen::get_inner_expanded_tv<4>(miopen::deref(inputDesc)); + auto output_tv = miopen::get_inner_expanded_tv<3>(miopen::deref(ref_outputDesc)); + auto input_dims = miopen::deref(inputDesc).GetLengths(); + auto input_size = miopen::deref(inputDesc).GetNumDims(); + + const int LOCAL_SIZE = 256; + int spatial_dim_size = input_size - 2; + const uint64_t N = static_cast(input_dims[0]); + const uint64_t C = static_cast(input_dims[1]); + uint64_t P = 1, L = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + uint64_t kernel_size_w = kernel_size[1]; + uint64_t stride_h = stride[0]; + uint64_t stride_w = stride[1]; + uint64_t padding_h = padding[0]; + uint64_t padding_w = padding[1]; + uint64_t dilation_h = dilation[0]; + uint64_t dilation_w = dilation[1]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(input_dims[2]); + uint64_t W = static_cast(input_dims[3]); + uint64_t work_size = (((N * C * P * L) + LOCAL_SIZE - 1) / LOCAL_SIZE) * LOCAL_SIZE; + par_ford(work_size)([&](uint64_t gid) { + uint64_t ncp = gid / L, l = gid % L; + uint64_t nc = ncp / P, p = ncp % P; + uint64_t n = nc / C, c = nc % C; + if(n >= N) + return; + + uint64_t lh = l / LW, lw = l % LW; // sliding window position + uint64_t ph = p / kernel_size_w, pw = p % kernel_size_w; // position inside kernel + + Tgpu x = static_cast(0.0f); + if(lh * stride_h >= padding_h + ph * dilation_h && + lw * stride_w >= padding_w + pw * dilation_w) + { + uint64_t h = lh * stride_h - padding_h + ph * dilation_h; + uint64_t w = lw * stride_w - padding_w + pw * dilation_w; + if(h < H && w < W) + { + long input_idx = input_tv.stride[3] * w + input_tv.stride[2] * h + + input_tv.stride[1] * c + input_tv.stride[0] * n; + x = input[input_idx]; + } + } + + long output_idx = + output_tv.stride[2] * l + output_tv.stride[1] * (c * P + p) + output_tv.stride[0] * n; + ref_output[output_idx] = static_cast(x); + }); + + return miopenStatusSuccess; +} + +template +int32_t mloUnFoldBwd4DRunHost(Tcheck* ref_dinput, + const miopenTensorDescriptor_t dinputDesc, + Tgpu* doutput, + const miopenTensorDescriptor_t doutputDesc, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, + const std::vector dilation) +{ + auto input_grad_tv = miopen::get_inner_expanded_tv<4>(miopen::deref(dinputDesc)); + auto output_grad_tv = miopen::get_inner_expanded_tv<3>(miopen::deref(doutputDesc)); + auto input_grad_dims = miopen::deref(dinputDesc).GetLengths(); + auto input_size = miopen::deref(dinputDesc).GetNumDims(); + + const int LOCAL_SIZE = 256; + int spatial_dim_size = input_size - 2; + const uint64_t N = static_cast(input_grad_dims[0]); + const uint64_t C = static_cast(input_grad_dims[1]); + uint64_t P = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_grad_dims[i + 2]) + 2 * 
padding[i] -
+                      dilation[i] * (kernel_size[i] - 1) - 1) /
+                         stride[i] +
+                     1;
+        ls.push_back(l);
+    }
+    uint64_t kernel_size_h = kernel_size[0];
+    uint64_t kernel_size_w = kernel_size[1];
+    uint64_t stride_h      = stride[0];
+    uint64_t stride_w      = stride[1];
+    uint64_t padding_h     = padding[0];
+    uint64_t padding_w     = padding[1];
+    uint64_t dilation_h    = dilation[0];
+    uint64_t dilation_w    = dilation[1];
+    uint64_t LH            = ls[0];
+    uint64_t LW            = ls[1];
+    uint64_t H             = static_cast<uint64_t>(input_grad_dims[2]);
+    uint64_t W             = static_cast<uint64_t>(input_grad_dims[3]);
+    uint64_t work_size     = (((N * C * H * W) + LOCAL_SIZE - 1) / LOCAL_SIZE) * LOCAL_SIZE;
+    par_ford(work_size)([&](uint64_t gid) {
+        uint64_t nch = gid / W, w = gid % W;
+        uint64_t nc = nch / H, h = nch % H;
+        uint64_t n = nc / C, c = nc % C;
+        if(n >= N)
+            return;
+
+        float sum = 0.0f;
+
+        for(uint64_t ph = 0; ph < kernel_size_h; ++ph)
+        {
+            for(uint64_t pw = 0; pw < kernel_size_w; ++pw)
+            {
+                if(h < ph * dilation_h + padding_h)
+                    continue;
+                if(w < pw * dilation_w + padding_w)
+                    continue;
+                uint64_t lhsh = h - ph * dilation_h + padding_h;
+                uint64_t lwsw = w - pw * dilation_w + padding_w;
+                if(lhsh % stride_h != 0)
+                    continue;
+                if(lwsw % stride_w != 0)
+                    continue;
+                uint64_t lh = lhsh / stride_h;
+                uint64_t lw = lwsw / stride_w;
+                if(LH <= lh)
+                    continue;
+                if(LW <= lw)
+                    continue;
+                long output_grad_idx =
+                    output_grad_tv.stride[2] * (lh * LW + lw) +
+                    output_grad_tv.stride[1] * (c * P + (ph * kernel_size_w + pw)) +
+                    output_grad_tv.stride[0] * n;
+                sum += static_cast<float>(doutput[output_grad_idx]);
+            }
+        }
+
+        long input_grad_idx = input_grad_tv.stride[3] * w + input_grad_tv.stride[2] * h +
+                              input_grad_tv.stride[1] * c + input_grad_tv.stride[0] * n;
+        ref_dinput[input_grad_idx] = static_cast<Tcheck>(sum);
+    });
+
+    return miopenStatusSuccess;
+}
diff --git a/driver/unfold_driver.hpp b/driver/unfold_driver.hpp
new file mode 100644
index 0000000000..188a01a968
--- /dev/null
+++ b/driver/unfold_driver.hpp
@@ -0,0 +1,419 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * + *******************************************************************************/ +#pragma once + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "mloUnfoldHost.hpp" +#include "random.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" + +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> + +#include +#include +#include +#include +#include + +template +class UnfoldDriver : public Driver +{ +public: + UnfoldDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&outputDesc); + miopenCreateTensorDescriptor(&dinputDesc); + miopenCreateTensorDescriptor(&doutputDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + int RunBackwardCPU(); + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~UnfoldDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(outputDesc); + miopenDestroyTensorDescriptor(dinputDesc); + miopenDestroyTensorDescriptor(doutputDesc); + } + +private: + InputFlags inflags; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t outputDesc; + + miopenTensorDescriptor_t doutputDesc; + miopenTensorDescriptor_t dinputDesc; + + std::unique_ptr input_dev; + std::unique_ptr output_dev; + + std::unique_ptr doutput_dev; + std::unique_ptr dinput_dev; + + std::vector input; + std::vector output; + + std::vector doutput; + std::vector dinput; + + std::vector output_host; + std::vector dinput_host; + + std::vector input_length; + std::vector kernel_size; + std::vector stride; + std::vector padding; + std::vector dilation; +}; + +template +int UnfoldDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int UnfoldDriver::GetandSetData() +{ + std::vector input_dims = inflags.GetValueTensor("DimLengths").lengths; + input_length = {input_dims.begin(), input_dims.end()}; + std::vector kernel_size_int = inflags.GetValueTensor("kernelSize").lengths; + kernel_size = {kernel_size_int.begin(), kernel_size_int.end()}; + std::vector stride_int = inflags.GetValueTensor("stride").lengths; + stride = {stride_int.begin(), stride_int.end()}; + std::vector padding_int = inflags.GetValueTensor("padding").lengths; + padding = {padding_int.begin(), padding_int.end()}; + std::vector dilation_int = inflags.GetValueTensor("dilation").lengths; + dilation = {dilation_int.begin(), dilation_int.end()}; + uint64_t spatial_dim_size = input_length.size() - 2; + uint64_t N = input_length[0]; + uint64_t C = input_length[1]; + + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = + (input_length[i + 2] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + std::vector output_length = {N, (C * P), L}; + if(SetTensorNd(inputDesc, input_length, data_type) != miopenStatusSuccess) + MIOPEN_THROW("Error parsing input tensor: " + inflags.GetValueStr("input_dims") + "."); + if(SetTensorNd(outputDesc, output_length, data_type) != miopenStatusSuccess) + 
MIOPEN_THROW("Error parsing output tensor: " + inflags.GetValueStr("output_dims") + "."); + if(SetTensorNd(doutputDesc, output_length, data_type) != miopenStatusSuccess) + MIOPEN_THROW("Error parsing output grad tensor: " + inflags.GetValueStr("output_dims") + + "."); + if(SetTensorNd(dinputDesc, input_length, data_type) != miopenStatusSuccess) + MIOPEN_THROW("Error parsing input grad tensor: " + inflags.GetValueStr("input_dims") + "."); + + return miopenStatusSuccess; +} + +template +int UnfoldDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag( + "forw", 'F', "1", "Run Unfold Forward (Default=1) or both Forward and Backward (0)", "int"); + inflags.AddTensorFlag( + "DimLengths", 'D', "2x5x3x4", "The dimensional lengths of the input tensor"); + inflags.AddTensorFlag("kernelSize", 'k', "2x2", "Kernel Size (Default=2x3)"); + inflags.AddTensorFlag("stride", 's', "1x1", "Stride (Default=1x1)"); + inflags.AddTensorFlag("padding", 'p', "0x0", "Padding (Default=0x0)"); + inflags.AddTensorFlag("dilation", 'd', "1x1", "Dilation (Default=1x1)"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "0", "Verify Each Layer (Default=0)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int UnfoldDriver::AllocateBuffersAndCopy() +{ + size_t input_sz = GetTensorSize(inputDesc); + size_t output_sz = GetTensorSize(outputDesc); + + size_t doutput_sz = GetTensorSize(doutputDesc); + size_t dinput_sz = GetTensorSize(dinputDesc); + + uint32_t ctx = 0; + + input_dev = std::unique_ptr(new GPUMem(ctx, input_sz, sizeof(Tgpu))); + output_dev = std::unique_ptr(new GPUMem(ctx, output_sz, sizeof(Tgpu))); + + doutput_dev = std::unique_ptr(new GPUMem(ctx, doutput_sz, sizeof(Tgpu))); + dinput_dev = std::unique_ptr(new GPUMem(ctx, dinput_sz, sizeof(Tgpu))); + + input = std::vector(input_sz, static_cast(0.0f)); + output = std::vector(output_sz, static_cast(0.0f)); + + doutput = std::vector(doutput_sz, static_cast(1.0f)); + dinput = std::vector(dinput_sz, static_cast(0.0f)); + + output_host = std::vector(output_sz, static_cast(0.0f)); + + dinput_host = std::vector(dinput_sz, static_cast(0.0f)); + + int status; + + for(int i = 0; i < input_sz; i++) + input[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + status = input_dev->ToGPU(GetStream(), input.data()); + + for(int i = 0; i < doutput_sz; i++) + { + doutput[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + status |= doutput_dev->ToGPU(GetStream(), doutput.data()); + status |= dinput_dev->ToGPU(GetStream(), dinput.data()); + + if(status != 0) + { + std::cout << "Error copying data to GPU\n" << std::endl; + return miopenStatusAllocFailed; + } + + return miopenStatusSuccess; +} + +template +int UnfoldDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenUnfoldForward(GetHandle(), + inputDesc, + input_dev->GetMem(), + outputDesc, + output_dev->GetMem(), + kernel_size.data(), + kernel_size.size(), + stride.data(), + stride.size(), + padding.data(), + padding.size(), + dilation.data(), + dilation.size()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + 
if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Unfold Forward Elapsed: " << t.gettime_ms() / iter + << " ms" << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Unfold Forward Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(output_dev->FromGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << output_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int UnfoldDriver::RunForwardCPU() +{ + mloUnFoldFwd4DRunHost(input.data(), + inputDesc, + output_host.data(), + outputDesc, + kernel_size, + stride, + padding, + dilation); + return miopenStatusSuccess; +} + +template +int UnfoldDriver::RunBackwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenUnfoldBackward(GetHandle(), + dinputDesc, + dinput_dev->GetMem(), + doutputDesc, + doutput_dev->GetMem(), + kernel_size.data(), + kernel_size.size(), + stride.data(), + stride.size(), + padding.data(), + padding.size(), + dilation.data(), + dilation.size()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Unfold Backward Elapsed: " << t.gettime_ms() / iter + << " ms" << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Unfold Backward Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(dinput_dev->FromGPU(GetStream(), dinput.data()) != 0) + std::cerr << "Error copying (dinput_dev) from GPU, size: " << dinput_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int UnfoldDriver::RunBackwardCPU() +{ + mloUnFoldBwd4DRunHost(dinput_host.data(), + inputDesc, + doutput.data(), + doutputDesc, + kernel_size, + stride, + padding, + dilation); + return miopenStatusSuccess; +} + +template +Tref UnfoldDriver::GetTolerance() +{ + Tref tolerance = std::numeric_limits::epsilon() * 10; + return tolerance; +} + +template +int UnfoldDriver::VerifyForward() +{ + RunForwardCPU(); + const Tref tolerance = GetTolerance(); + auto error_output = miopen::rms_range(output_host, output); + + if(!std::isfinite(error_output) || error_output > tolerance) + { + std::cout << "Forward Unfold FAILED: {" << error_output << "} > " << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward Unfold Verifies OK on CPU reference ({" << error_output << "} < " + << tolerance << ')' << std::endl; + } + return miopenStatusSuccess; +} + +template +int UnfoldDriver::VerifyBackward() +{ + RunBackwardCPU(); + const Tref tolerance = GetTolerance(); + auto error_dinput = miopen::rms_range(dinput_host, dinput); + + if(!std::isfinite(error_dinput) || error_dinput > tolerance) + { + std::cout << "Backward Unfold FAILED: {" << error_dinput << "} > " << tolerance + << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Backward Unfold Verifies OK on CPU reference ({" << error_dinput << "} < " + << tolerance << ')' << std::endl; + } + return miopenStatusSuccess; +} 
diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h
index 67652ab832..7e5438ceb7 100644
--- a/include/miopen/miopen.h
+++ b/include/miopen/miopen.h
@@ -8176,6 +8176,144 @@ MIOPEN_EXPORT miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle,
 // CLOSEOUT LossFunction DOXYGEN GROUP
 #endif // MIOPEN_BETA_API
 
+#ifdef MIOPEN_BETA_API
+// Fold APIs
+/** @addtogroup FOLD
+ *
+ * @{
+ */
+/*! @brief Execute a fold forward layer
+ *
+ * @param handle           MIOpen handle (input)
+ * @param inputDesc        Tensor descriptor for the data input tensor (input)
+ * @param input            Data tensor input (input)
+ * @param outputDesc       Tensor descriptor for the data output tensor (input)
+ * @param output           Data tensor output (output)
+ * @param kernel_size      Array of sliding box sizes (input)
+ * @param kernel_size_size Size of the kernel_size array (input)
+ * @param stride           Stride array of the sliding box (input)
+ * @param stride_size      Size of the stride array (input)
+ * @param padding          Padding array to be added on input (input)
+ * @param padding_size     Size of the padding array (input)
+ * @param dilation         Dilation array controlling the stride of the elements within the
+ *                         neighborhood (input)
+ * @param dilation_size    Size of the dilation array (input)
+ * @return                 miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenFoldForward(miopenHandle_t handle,
+                                               const miopenTensorDescriptor_t inputDesc,
+                                               const void* input,
+                                               const miopenTensorDescriptor_t outputDesc,
+                                               void* output,
+                                               const uint64_t* kernel_size,
+                                               const uint64_t kernel_size_size,
+                                               const uint64_t* stride,
+                                               const uint64_t stride_size,
+                                               const uint64_t* padding,
+                                               const uint64_t padding_size,
+                                               const uint64_t* dilation,
+                                               const uint64_t dilation_size);
+
+/*! @brief Execute a fold backward layer
+ *
+ * @param handle           MIOpen handle (input)
+ * @param dinputDesc       Tensor descriptor for the data input grad tensor (input)
+ * @param dinput           Data tensor input grad (output)
+ * @param doutputDesc      Tensor descriptor for the data output grad tensor (input)
+ * @param doutput          Data tensor output grad (input)
+ * @param kernel_size      Array of sliding box sizes (input)
+ * @param kernel_size_size Size of the kernel_size array (input)
+ * @param stride           Stride array of the sliding box (input)
+ * @param stride_size      Size of the stride array (input)
+ * @param padding          Padding array to be added on input (input)
+ * @param padding_size     Size of the padding array (input)
+ * @param dilation         Dilation array controlling the stride of the elements within the
+ *                         neighborhood (input)
+ * @param dilation_size    Size of the dilation array (input)
+ * @return                 miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenFoldBackward(miopenHandle_t handle,
+                                                const miopenTensorDescriptor_t dinputDesc,
+                                                void* dinput,
+                                                const miopenTensorDescriptor_t doutputDesc,
+                                                const void* doutput,
+                                                const uint64_t* kernel_size,
+                                                const uint64_t kernel_size_size,
+                                                const uint64_t* stride,
+                                                const uint64_t stride_size,
+                                                const uint64_t* padding,
+                                                const uint64_t padding_size,
+                                                const uint64_t* dilation,
+                                                const uint64_t dilation_size);
+
+/*!
@brief Execute an unfold forward layer + * + * @param handle MIOpen handle (input) + * @param inputDesc Tensor descriptor for data input tensor input (input) + * @param input Data tensor input (input) + * @param outputDesc Tensor descriptor for data output tensor output (input) + * @param output Data tensor output (output) + * @param kernel_size Size of the sliding box array (input) + * @param kernel_size_size Size of the kernel_size array (input) + * @param stride Stride array of the sliding box (input) + * @param stride_size Size of the stride array (input) + * @param padding Padding array to be added on input (input) + * @param padding_size Size of the padding array (input) + * @param dilation Dilation array control the stride of the elements within the + * neighborhood (input) + * @param dilation_size Size of the dilation array (input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenUnfoldForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t outputDesc, + void* output, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size); + +/*! @brief Execute an unfold backward layer + * + * @param handle MIOpen handle (input) + * @param dinputDesc Tensor descriptor for data input grad tensor (input) + * @param dinput Data tensor input grad (output) + * @param doutputDesc Tensor descriptor for data output grad tensor (input) + * @param doutput Data tensor output grad (input) + * @param kernel_size Size of the sliding box array (input) + * @param kernel_size_size Size of the kernel_size array (input) + * @param stride Stride array of the sliding box (input) + * @param stride_size Size of the stride array (input) + * @param padding Padding array to be added on input (input) + * @param padding_size Size of the padding array (input) + * @param dilation Dilation array control the stride of the elements within the + neighborhood (input) + * @param dilation_size Size of the dilation array (input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenUnfoldBackward(miopenHandle_t handle, + const miopenTensorDescriptor_t dinputDesc, + void* dinput, + const miopenTensorDescriptor_t doutputDesc, + const void* doutput, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size); + +/** @} */ +// CLOSEOUT FOLD DOXYGEN GROUP +#endif + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92e4f4264a..799a14517b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -122,6 +122,8 @@ set( MIOpen_Source expanduser.cpp find_controls.cpp find_db.cpp + fold_api.cpp + fold/problem_description.cpp fused_api.cpp fusion.cpp fusion/problem_description.cpp @@ -304,6 +306,10 @@ set( MIOpen_Source solver/conv_winoRxS_fused.cpp solver/glu/backward_glu.cpp solver/glu/forward_glu.cpp + solver/fold/fold_forward.cpp + solver/fold/fold_backward.cpp + solver/fold/unfold_forward.cpp + solver/fold/unfold_backward.cpp solver/groupnorm/forward_groupnorm.cpp solver/getitem/backward_getitem.cpp solver/kthvalue/forward_kthvalue.cpp @@ -587,6 +593,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN 
kernels/conv7x7c3h224w224k64u2v2p3q3f1.s kernels/xform_out.s kernels/gcnAsmBNBwdTrainSpatial.s + kernels/MIOpenUnfold.cpp kernels/MIOpenTensorKernels.cl kernels/MIOpenTensorKernelsHip.cpp kernels/MIOpenSubTensorOpWithScalarKernel.cl @@ -670,6 +677,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN addlayernorm.cpp cat.cpp exec_utils.cpp + fold.cpp groupnorm.cpp getitem.cpp glu.cpp diff --git a/src/fold.cpp b/src/fold.cpp new file mode 100644 index 0000000000..e9ec800d5f --- /dev/null +++ b/src/fold.cpp @@ -0,0 +1,244 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +namespace fold { + +miopenStatus_t UnfoldForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + const auto problem = fold::UnfoldFwdProblemDescription{inputDesc, + outputDesc, + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size}; + + const auto invoke_params = [&]() { + auto tmp = fold::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + tmp.input = input; + tmp.output = output; + tmp.kernel_size = kernel_size; + tmp.stride = stride; + tmp.padding = padding; + tmp.dilation = dilation; + tmp.kernel_size_size = kernel_size_size; + tmp.stride_size = stride_size; + tmp.padding_size = padding_size; + tmp.dilation_size = dilation_size; + return tmp; + }(); + + const auto algo = AlgorithmName{"UnfoldFwd"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +miopenStatus_t UnfoldBackward(Handle& handle, + const TensorDescriptor& dinputDesc, + Data_t dinput, + const TensorDescriptor& doutputDesc, + ConstData_t doutput, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + const auto problem = fold::UnfoldBwdProblemDescription{dinputDesc, + doutputDesc, + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size}; + + const auto invoke_params = [&]() { + auto tmp = fold::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.dinputDesc = &dinputDesc; + tmp.doutputDesc = &doutputDesc; + tmp.dinput = dinput; + tmp.doutput = doutput; + tmp.kernel_size = kernel_size; + tmp.stride = stride; + tmp.padding = padding; + tmp.dilation = dilation; + tmp.kernel_size_size = kernel_size_size; + tmp.stride_size = stride_size; + tmp.padding_size = padding_size; + tmp.dilation_size = dilation_size; + return tmp; + }(); + + const auto algo = AlgorithmName{"UnfoldBwd"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +miopenStatus_t FoldForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + const auto problem = fold::FoldFwdProblemDescription{inputDesc, + outputDesc, + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size}; + + const auto invoke_params = [&]() { + auto tmp = fold::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.inputDesc = 
&inputDesc; + tmp.outputDesc = &outputDesc; + tmp.input = input; + tmp.output = output; + tmp.kernel_size = kernel_size; + tmp.stride = stride; + tmp.padding = padding; + tmp.dilation = dilation; + tmp.kernel_size_size = kernel_size_size; + tmp.stride_size = stride_size; + tmp.padding_size = padding_size; + tmp.dilation_size = dilation_size; + return tmp; + }(); + + const auto algo = AlgorithmName{"FoldFwd"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +miopenStatus_t FoldBackward(Handle& handle, + const TensorDescriptor& dinputDesc, + Data_t dinput, + const TensorDescriptor& doutputDesc, + ConstData_t doutput, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + const auto problem = fold::FoldBwdProblemDescription{dinputDesc, + doutputDesc, + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size}; + + const auto invoke_params = [&]() { + auto tmp = fold::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.dinputDesc = &dinputDesc; + tmp.doutputDesc = &doutputDesc; + tmp.dinput = dinput; + tmp.doutput = doutput; + tmp.kernel_size = kernel_size; + tmp.stride = stride; + tmp.padding = padding; + tmp.dilation = dilation; + tmp.kernel_size_size = kernel_size_size; + tmp.stride_size = stride_size; + tmp.padding_size = padding_size; + tmp.dilation_size = dilation_size; + return tmp; + }(); + + const auto algo = AlgorithmName{"FoldBwd"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace fold + +} // namespace miopen diff --git a/src/fold/problem_description.cpp b/src/fold/problem_description.cpp new file mode 100644 index 0000000000..6f7270f5ff --- /dev/null +++ b/src/fold/problem_description.cpp @@ -0,0 +1,145 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include + +namespace miopen { + +namespace fold { + +NetworkConfig UnfoldFwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto in_dims = inputDesc.GetLengths(); + + std::ostringstream ss; + + ss << "Unfold_fwd"; + ss << "i_dtype" << input_dtype; + ss << "size" << size; + ss << "in_dims"; + for(auto val : in_dims) + { + ss << "_" << val; + } + ss << "kernel_size_" << kernel_size[0] << "_" << kernel_size[1]; + ss << "stride_" << stride[0] << "_" << stride[1]; + ss << "padding_" << padding[0] << "_" << padding[1]; + ss << "dilation_" << dilation[0] << "_" << dilation[1]; + + return NetworkConfig{ss.str()}; +} + +NetworkConfig UnfoldBwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = dinputDesc.GetType(); + auto size = dinputDesc.GetElementSize(); + auto in_dims = dinputDesc.GetLengths(); + + std::ostringstream ss; + + ss << "Unfold_bwd"; + ss << "i_dtype" << input_dtype; + ss << "size" << size; + ss << "in_grad_dims"; + for(auto val : in_dims) + { + ss << "_" << val; + } + ss << "kernel_size_" << kernel_size[0] << "_" << kernel_size[1]; + ss << "stride_" << stride[0] << "_" << stride[1]; + ss << "padding_" << padding[0] << "_" << padding[1]; + ss << "dilation_" << dilation[0] << "_" << dilation[1]; + + return NetworkConfig{ss.str()}; +} + +NetworkConfig FoldFwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto in_dims = inputDesc.GetLengths(); + auto out_dims = outputDesc.GetLengths(); + + std::ostringstream ss; + + ss << "Fold_fwd"; + ss << "i_dtype" << input_dtype; + ss << "size" << size; + ss << "in_dims"; + for(auto val : in_dims) + { + ss << "_" << val; + } + ss << "out_dims"; + for(auto val : out_dims) + { + ss << "_" << val; + } + ss << "kernel_size_" << kernel_size[0] << "_" << kernel_size[1]; + ss << "stride_" << stride[0] << "_" << stride[1]; + ss << "padding_" << padding[0] << "_" << padding[1]; + ss << "dilation_" << dilation[0] << "_" << dilation[1]; + + return NetworkConfig{ss.str()}; +} + +NetworkConfig FoldBwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = dinputDesc.GetType(); + auto size = dinputDesc.GetElementSize(); + auto in_dims = dinputDesc.GetLengths(); + auto out_dims = doutputDesc.GetLengths(); + + std::ostringstream ss; + + ss << "Fold_bwd"; + ss << "i_dtype" << input_dtype; + ss << "size" << size; + ss << "in_grad_dims"; + for(auto val : in_dims) + { + ss << "_" << val; + } + ss << "out_grad_dims"; + for(auto val : out_dims) + { + ss << "_" << val; + } + ss << "kernel_size_" << kernel_size[0] << "_" << kernel_size[1]; + ss << "stride_" << stride[0] << "_" << stride[1]; + ss << "padding_" << padding[0] << "_" << padding[1]; + ss << "dilation_" << dilation[0] << "_" << dilation[1]; + + return NetworkConfig{ss.str()}; +} + +} // namespace fold + +} // namespace miopen diff --git a/src/fold_api.cpp b/src/fold_api.cpp new file mode 100644 index 0000000000..f1cc61813c --- /dev/null +++ b/src/fold_api.cpp @@ -0,0 +1,156 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +extern "C" miopenStatus_t miopenUnfoldForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t outputDesc, + void* output, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + return miopen::try_([&] { + miopen::fold::UnfoldForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(outputDesc), + DataCast(output), + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size); + }); +} + +extern "C" miopenStatus_t miopenUnfoldBackward(miopenHandle_t handle, + const miopenTensorDescriptor_t dinputDesc, + void* dinput, + const miopenTensorDescriptor_t doutputDesc, + const void* doutput, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + return miopen::try_([&] { + miopen::fold::UnfoldBackward(miopen::deref(handle), + miopen::deref(dinputDesc), + DataCast(dinput), + miopen::deref(doutputDesc), + DataCast(doutput), + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size); + }); +} + +extern "C" miopenStatus_t miopenFoldForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t outputDesc, + void* output, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + return miopen::try_([&] { + miopen::fold::FoldForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(outputDesc), + DataCast(output), + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size); + }); +} + +extern "C" miopenStatus_t miopenFoldBackward(miopenHandle_t handle, + 
const miopenTensorDescriptor_t dinputDesc, + void* dinput, + const miopenTensorDescriptor_t doutputDesc, + const void* doutput, + const uint64_t* kernel_size, + const uint64_t kernel_size_size, + const uint64_t* stride, + const uint64_t stride_size, + const uint64_t* padding, + const uint64_t padding_size, + const uint64_t* dilation, + const uint64_t dilation_size) +{ + return miopen::try_([&] { + miopen::fold::FoldBackward(miopen::deref(handle), + miopen::deref(dinputDesc), + DataCast(dinput), + miopen::deref(doutputDesc), + DataCast(doutput), + kernel_size, + kernel_size_size, + stride, + stride_size, + padding, + padding_size, + dilation, + dilation_size); + }); +} diff --git a/src/include/miopen/fold.hpp b/src/include/miopen/fold.hpp new file mode 100644 index 0000000000..8c7da4b2f6 --- /dev/null +++ b/src/include/miopen/fold.hpp @@ -0,0 +1,94 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +namespace fold { + +MIOPEN_INTERNALS_EXPORT miopenStatus_t UnfoldForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const uint64_t* kernel_size, + uint64_t kernel_size_size, + const uint64_t* stride, + uint64_t stride_size, + const uint64_t* padding, + uint64_t padding_size, + const uint64_t* dilation, + uint64_t dilation_size); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t UnfoldBackward(Handle& handle, + const TensorDescriptor& dinputDesc, + Data_t dinput, + const TensorDescriptor& doutputDesc, + ConstData_t doutput, + const uint64_t* kernel_size, + uint64_t kernel_size_size, + const uint64_t* stride, + uint64_t stride_size, + const uint64_t* padding, + uint64_t padding_size, + const uint64_t* dilation, + uint64_t dilation_size); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t FoldForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const uint64_t* kernel_size, + uint64_t kernel_size_size, + const uint64_t* stride, + uint64_t stride_size, + const uint64_t* padding, + uint64_t padding_size, + const uint64_t* dilation, + uint64_t dilation_size); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t FoldBackward(Handle& handle, + const TensorDescriptor& dinputDesc, + Data_t dinput, + const TensorDescriptor& doutputDesc, + ConstData_t doutput, + const uint64_t* kernel_size, + uint64_t kernel_size_size, + const uint64_t* stride, + uint64_t stride_size, + const uint64_t* padding, + uint64_t padding_size, + const uint64_t* dilation, + uint64_t dilation_size); + +} // namespace fold + +} // namespace miopen diff --git a/src/include/miopen/fold/invoke_params.hpp b/src/include/miopen/fold/invoke_params.hpp new file mode 100644 index 0000000000..60d84e96f4 --- /dev/null +++ b/src/include/miopen/fold/invoke_params.hpp @@ -0,0 +1,67 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace miopen { + +namespace fold { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* inputDesc = nullptr; + const TensorDescriptor* outputDesc = nullptr; + ConstData_t input = nullptr; + Data_t output = nullptr; + + const TensorDescriptor* dinputDesc = nullptr; + const TensorDescriptor* doutputDesc = nullptr; + Data_t dinput = nullptr; + ConstData_t doutput = nullptr; + + const uint64_t* kernel_size = nullptr; + const uint64_t* stride = nullptr; + const uint64_t* padding = nullptr; + const uint64_t* dilation = nullptr; + uint64_t kernel_size_size = 0; + uint64_t stride_size = 0; + uint64_t padding_size = 0; + uint64_t dilation_size = 0; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace fold + +} // namespace miopen diff --git a/src/include/miopen/fold/problem_description.hpp b/src/include/miopen/fold/problem_description.hpp new file mode 100644 index 0000000000..fb8baeae89 --- /dev/null +++ b/src/include/miopen/fold/problem_description.hpp @@ -0,0 +1,418 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace fold { + +bool checkSameLength(const TensorDescriptor& x, const TensorDescriptor& y); + +struct UnfoldFwdProblemDescription : ProblemDescriptionBase +{ + UnfoldFwdProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& outputDesc_, + const uint64_t* kernel_size_, + const uint64_t kernel_size_size_, + const uint64_t* stride_, + const uint64_t stride_size_, + const uint64_t* padding_, + const uint64_t padding_size_, + const uint64_t* dilation_, + const uint64_t dilation_size_) + : inputDesc(inputDesc_), + outputDesc(outputDesc_), + kernel_size(kernel_size_), + kernel_size_size(kernel_size_size_), + stride(stride_), + stride_size(stride_size_), + padding(padding_), + padding_size(padding_size_), + dilation(dilation_), + dilation_size(dilation_size_) + { + IsValidSize(); + IsValidType(); + } + + bool IsValidSize() const + { + if(inputDesc.GetNumDims() != 4) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "Unfold: The input tensor should be 4D."); +#else + return false; +#endif + } + uint64_t spatial_dim_size = inputDesc.GetNumDims() - 2; + if(kernel_size_size != spatial_dim_size || stride_size != spatial_dim_size || + padding_size != spatial_dim_size || dilation_size != spatial_dim_size) + { + MIOPEN_THROW(miopenStatusBadParm, "Unfold: Argument length should be 2D"); + } + auto input_dims = inputDesc.GetLengths(); + const uint64_t N = static_cast(input_dims[0]); + const uint64_t C = static_cast(input_dims[1]); + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + std::vector output_dims_desired{ + static_cast(N), static_cast(C * P), static_cast(L)}; + auto output_dims = outputDesc.GetLengths(); + if(output_dims != output_dims_desired) + { + MIOPEN_THROW(miopenStatusBadParm, "Unfold: Invalid output dimension"); + } + return true; + } + + bool IsValidType() const + { + if(inputDesc.GetType() != outputDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Unfold: The input tensor and output tensor has mismatch type."); + } + return true; + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor inputDesc; + TensorDescriptor outputDesc; + const uint64_t* kernel_size; + const uint64_t kernel_size_size; + const uint64_t* stride; + const uint64_t stride_size; + const uint64_t* padding; + const uint64_t padding_size; + const uint64_t* dilation; + const uint64_t dilation_size; +}; + +struct UnfoldBwdProblemDescription : ProblemDescriptionBase +{ + UnfoldBwdProblemDescription(const TensorDescriptor& dinputDesc_, + const TensorDescriptor& doutputDesc_, + const uint64_t* kernel_size_, + const uint64_t kernel_size_size_, + const uint64_t* stride_, + const uint64_t stride_size_, + const uint64_t* padding_, + const uint64_t padding_size_, + const uint64_t* dilation_, + const uint64_t dilation_size_) + : dinputDesc(dinputDesc_), + doutputDesc(doutputDesc_), + kernel_size(kernel_size_), + kernel_size_size(kernel_size_size_), + stride(stride_), + 
stride_size(stride_size_), + padding(padding_), + padding_size(padding_size_), + dilation(dilation_), + dilation_size(dilation_size_) + { + IsValidSize(); + IsValidType(); + } + + bool IsValidSize() const + { + if(dinputDesc.GetNumDims() != 4) + { + MIOPEN_THROW(miopenStatusBadParm, "Unfold: The input gradient tensor should be 4D."); + } + uint64_t spatial_dim_size = dinputDesc.GetNumDims() - 2; + if(kernel_size_size != spatial_dim_size || stride_size != spatial_dim_size || + padding_size != spatial_dim_size || dilation_size != spatial_dim_size) + { + MIOPEN_THROW(miopenStatusBadParm, "Unfold: Argument length should be 2D"); + } + auto input_dims = dinputDesc.GetLengths(); + const uint64_t N = static_cast(input_dims[0]); + const uint64_t C = static_cast(input_dims[1]); + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + std::vector output_dims_desired{ + static_cast(N), static_cast(C * P), static_cast(L)}; + auto output_dims = doutputDesc.GetLengths(); + if(output_dims != output_dims_desired) + { + MIOPEN_THROW(miopenStatusBadParm, "Unfold: Invalid output gradient dimension"); + } + return true; + } + + bool IsValidType() const + { + if(dinputDesc.GetType() != doutputDesc.GetType()) + { + MIOPEN_THROW( + miopenStatusBadParm, + "Unfold: The input gradient tensor and output gradient tensor has mismatch type."); + } + return true; + } + + const TensorDescriptor& GetDinputDesc() const { return dinputDesc; } + const TensorDescriptor& GetDoutputDesc() const { return doutputDesc; } + + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor dinputDesc; + TensorDescriptor doutputDesc; + const uint64_t* kernel_size; + const uint64_t kernel_size_size; + const uint64_t* stride; + const uint64_t stride_size; + const uint64_t* padding; + const uint64_t padding_size; + const uint64_t* dilation; + const uint64_t dilation_size; +}; + +struct FoldFwdProblemDescription : ProblemDescriptionBase +{ + FoldFwdProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& outputDesc_, + const uint64_t* kernel_size_, + const uint64_t kernel_size_size_, + const uint64_t* stride_, + const uint64_t stride_size_, + const uint64_t* padding_, + const uint64_t padding_size_, + const uint64_t* dilation_, + const uint64_t dilation_size_) + : inputDesc(inputDesc_), + outputDesc(outputDesc_), + kernel_size(kernel_size_), + kernel_size_size(kernel_size_size_), + stride(stride_), + stride_size(stride_size_), + padding(padding_), + padding_size(padding_size_), + dilation(dilation_), + dilation_size(dilation_size_) + { + IsValidSize(); + IsValidType(); + } + + bool IsValidSize() const + { + if(outputDesc.GetNumDims() != 4) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: The output tensor should be 4D."); + } + uint64_t spatial_dim_size = outputDesc.GetNumDims() - 2; + if(kernel_size_size != spatial_dim_size || stride_size != spatial_dim_size || + padding_size != spatial_dim_size || dilation_size != spatial_dim_size) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: Argument length should be 2D"); + } + auto input_dims = inputDesc.GetLengths(); + auto output_dims = outputDesc.GetLengths(); + const uint64_t N = static_cast(output_dims[0]); + const uint64_t C = static_cast(output_dims[1]); + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < 
spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(output_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + std::vector input_dims_desired{ + static_cast(N), static_cast(C * P), static_cast(L)}; + if(input_dims != input_dims_desired) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: Invalid input dimension"); + } + return true; + } + + bool IsValidType() const + { + if(inputDesc.GetType() != outputDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Fold: The input tensor and output tensor has mismatch type."); + } + return true; + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor inputDesc; + TensorDescriptor outputDesc; + const uint64_t* kernel_size; + const uint64_t kernel_size_size; + const uint64_t* stride; + const uint64_t stride_size; + const uint64_t* padding; + const uint64_t padding_size; + const uint64_t* dilation; + const uint64_t dilation_size; +}; + +struct FoldBwdProblemDescription : ProblemDescriptionBase +{ + FoldBwdProblemDescription(const TensorDescriptor& dinputDesc_, + const TensorDescriptor& doutputDesc_, + const uint64_t* kernel_size_, + const uint64_t kernel_size_size_, + const uint64_t* stride_, + const uint64_t stride_size_, + const uint64_t* padding_, + const uint64_t padding_size_, + const uint64_t* dilation_, + const uint64_t dilation_size_) + : dinputDesc(dinputDesc_), + doutputDesc(doutputDesc_), + kernel_size(kernel_size_), + kernel_size_size(kernel_size_size_), + stride(stride_), + stride_size(stride_size_), + padding(padding_), + padding_size(padding_size_), + dilation(dilation_), + dilation_size(dilation_size_) + { + IsValidSize(); + IsValidType(); + } + + bool IsValidSize() const + { + if(doutputDesc.GetNumDims() != 4) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: The output gradient tensor should be 4D."); + } + uint64_t spatial_dim_size = doutputDesc.GetNumDims() - 2; + if(kernel_size_size != spatial_dim_size || stride_size != spatial_dim_size || + padding_size != spatial_dim_size || dilation_size != spatial_dim_size) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: Argument length should be 2D"); + } + auto input_dims = dinputDesc.GetLengths(); + auto output_dims = doutputDesc.GetLengths(); + const uint64_t N = static_cast(output_dims[0]); + const uint64_t C = static_cast(output_dims[1]); + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(output_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + std::vector input_dims_desired{ + static_cast(N), static_cast(C * P), static_cast(L)}; + if(input_dims != input_dims_desired) + { + MIOPEN_THROW(miopenStatusBadParm, "Fold: Invalid input gradient dimension"); + } + return true; + } + + bool IsValidType() const + { + if(dinputDesc.GetType() != doutputDesc.GetType()) + { + MIOPEN_THROW( + miopenStatusBadParm, + "Fold: The input gradient tensor and output gradient tensor has mismatch type."); + } + return true; + } + + const TensorDescriptor& GetDinputDesc() const { return dinputDesc; } + const TensorDescriptor& GetDoutputDesc() const { return doutputDesc; } + + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor dinputDesc; + 
TensorDescriptor doutputDesc; + const uint64_t* kernel_size; + const uint64_t kernel_size_size; + const uint64_t* stride; + const uint64_t stride_size; + const uint64_t* padding; + const uint64_t padding_size; + const uint64_t* dilation; + const uint64_t dilation_size; +}; + +} // namespace fold + +} // namespace miopen diff --git a/src/include/miopen/fold/solvers.hpp b/src/include/miopen/fold/solvers.hpp new file mode 100644 index 0000000000..1ff3ef7566 --- /dev/null +++ b/src/include/miopen/fold/solvers.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#pragma once
+
+#include <miopen/fold/problem_description.hpp>
+#include <miopen/solver.hpp>
+
+#include <utility>
+
+namespace miopen {
+
+namespace solver {
+
+namespace fold {
+
+using UnfoldFwdSolverBase =
+    NonTunableSolverBase<ExecutionContext, miopen::fold::UnfoldFwdProblemDescription>;
+
+struct UnfoldFwd final : UnfoldFwdSolverBase
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId<UnfoldFwd>(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::fold::UnfoldFwdProblemDescription& problem) const override;
+
+    ConvSolution
+    GetSolution(const ExecutionContext& context,
+                const miopen::fold::UnfoldFwdProblemDescription& problem) const override;
+};
+
+using UnfoldBwdSolverBase =
+    NonTunableSolverBase<ExecutionContext, miopen::fold::UnfoldBwdProblemDescription>;
+
+struct UnfoldBwd final : UnfoldBwdSolverBase
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId<UnfoldBwd>(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::fold::UnfoldBwdProblemDescription& problem) const override;
+
+    ConvSolution
+    GetSolution(const ExecutionContext& context,
+                const miopen::fold::UnfoldBwdProblemDescription& problem) const override;
+};
+
+using FoldFwdSolverBase =
+    NonTunableSolverBase<ExecutionContext, miopen::fold::FoldFwdProblemDescription>;
+
+struct FoldFwd final : FoldFwdSolverBase
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId<FoldFwd>(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::fold::FoldFwdProblemDescription& problem) const override;
+
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::fold::FoldFwdProblemDescription& problem) const override;
+};
+
+using FoldBwdSolverBase =
+    NonTunableSolverBase<ExecutionContext, miopen::fold::FoldBwdProblemDescription>;
+
+struct FoldBwd final : FoldBwdSolverBase
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId<FoldBwd>(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::fold::FoldBwdProblemDescription& problem) const override;
+
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::fold::FoldBwdProblemDescription& problem) const override;
+};
+
+} // namespace fold
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp
index f79a5f5a54..fb8f8fa25c 100644
--- a/src/include/miopen/solver_id.hpp
+++ b/src/include/miopen/solver_id.hpp
@@ -63,7 +63,9 @@ enum class Primitive
     ReLU,
     Kthvalue,
     SoftMarginLoss,
-    MultiMarginLoss
+    MultiMarginLoss,
+    Fold,
+    Unfold,
 };
 
 struct MIOPEN_INTERNALS_EXPORT Id
diff --git a/src/kernels/MIOpenUnfold.cpp b/src/kernels/MIOpenUnfold.cpp
new file mode 100644
index 0000000000..bb962e10a1
--- /dev/null
+++ b/src/kernels/MIOpenUnfold.cpp
@@ -0,0 +1,231 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" +#include "tensor_view.hpp" + +template +__device__ void unfoldForward4D(const DTYPE* __restrict__ input, + DTYPE* __restrict__ output, + uint64_t N, + uint64_t C, + uint64_t H, + uint64_t W, + uint64_t P, + uint64_t L, + uint64_t LW, + uint64_t kernel_size_w, + uint64_t stride_h, + uint64_t stride_w, + uint64_t padding_h, + uint64_t padding_w, + uint64_t dilation_h, + uint64_t dilation_w, + tensor_view_t<4> input_tv, + tensor_view_t<3> output_tv) +{ + /* + * input = {N, C, H, W}, output = {N, C * P, L} + * where P = kernel_size_h * kernel_size_w, L = # of blocks (see host code for + * formula) + * => gws = {ceil(N * C * P * L, LOCAL_SIZE)}, lws = {LOCAL_SIZE} + */ + + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + uint64_t ncp = gid / L, l = gid % L; + uint64_t nc = ncp / P, p = ncp % P; + uint64_t n = nc / C, c = nc % C; + if(n >= N) + return; + + uint64_t lh = l / LW, lw = l % LW; // sliding window position + uint64_t ph = p / kernel_size_w, pw = p % kernel_size_w; // position inside kernel + + DTYPE x = 0; + if(lh * stride_h >= padding_h + ph * dilation_h && lw * stride_w >= padding_w + pw * dilation_w) + { + uint64_t h = lh * stride_h - padding_h + ph * dilation_h; + uint64_t w = lw * stride_w - padding_w + pw * dilation_w; + if(h < H && w < W) + { + tensor_layout_t<4> input_layout({n, c, h, w}); + x = input[input_tv.get_tensor_view_idx(input_layout)]; + } + } + tensor_layout_t<3> output_layout({n, c * P + p, l}); + output[output_tv.get_tensor_view_idx(output_layout)] = x; +} + +extern "C" __global__ void UnfoldForward4D(const FLOAT* __restrict__ input, + FLOAT* __restrict__ output, + uint64_t N, + uint64_t C, + uint64_t H, + uint64_t W, + uint64_t P, + uint64_t L, + uint64_t LW, + uint64_t kernel_size_w, + uint64_t stride_h, + uint64_t stride_w, + uint64_t padding_h, + uint64_t padding_w, + uint64_t dilation_h, + uint64_t dilation_w, + tensor_view_t<4> input_tv, + tensor_view_t<3> output_tv) +{ + unfoldForward4D(input, + output, + N, + C, + H, + W, + P, + L, + LW, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + input_tv, + output_tv); +} + +template +__device__ void unfoldBackward4D(const DTYPE* __restrict__ output_grad, + DTYPE* __restrict__ input_grad, + uint64_t N, + uint64_t C, + uint64_t H, + uint64_t W, + uint64_t P, + uint64_t LH, + uint64_t LW, + uint64_t kernel_size_h, + uint64_t kernel_size_w, + uint64_t stride_h, + uint64_t stride_w, + uint64_t padding_h, + uint64_t padding_w, + uint64_t dilation_h, + uint64_t dilation_w, + tensor_view_t<3> output_grad_tv, + tensor_view_t<4> input_grad_tv) +{ + /* + * output_grad = {N, C * P, L}, input_grad = {N, C, H, W} + * where P = kernel_size_h * kernel_size_w, L = # of blocks (see host code for + * formula) + * => gws = {ceil(N * C * H * W, LOCAL_SIZE)}, lws = {LOCAL_SIZE} + */ + + const uint64_t gid 
= threadIdx.x + blockIdx.x * blockDim.x; + uint64_t nch = gid / W, w = gid % W; + uint64_t nc = nch / H, h = nch % H; + uint64_t n = nc / C, c = nc % C; + if(n >= N) + return; + + FLOAT_ACCUM sum = 0.0f; + for(uint64_t ph = 0; ph < kernel_size_h; ++ph) + { + for(uint64_t pw = 0; pw < kernel_size_w; ++pw) + { + if(h < ph * dilation_h + padding_h) + continue; + if(w < pw * dilation_w + padding_w) + continue; + uint64_t lhsh = h - ph * dilation_h + padding_h; + uint64_t lwsw = w - pw * dilation_w + padding_w; + if(lhsh % stride_h != 0) + continue; + if(lwsw % stride_w != 0) + continue; + uint64_t lh = lhsh / stride_h; + uint64_t lw = lwsw / stride_w; + if(LH <= lh) + continue; + if(LW <= lw) + continue; + tensor_layout_t<3> output_grad_layout( + {n, c * P + (ph * kernel_size_w + pw), lh * LW + lw}); + sum += CVT_FLOAT2ACCUM( + output_grad[output_grad_tv.get_tensor_view_idx(output_grad_layout)]); + } + } + tensor_layout_t<4> input_grad_layout({n, c, h, w}); + input_grad[input_grad_tv.get_tensor_view_idx(input_grad_layout)] = CVT_ACCUM2FLOAT(sum); +} + +extern "C" __global__ void UnfoldBackward4D(const FLOAT* __restrict__ output_grad, + FLOAT* __restrict__ input_grad, + uint64_t N, + uint64_t C, + uint64_t H, + uint64_t W, + uint64_t P, + uint64_t LH, + uint64_t LW, + uint64_t kernel_size_h, + uint64_t kernel_size_w, + uint64_t stride_h, + uint64_t stride_w, + uint64_t padding_h, + uint64_t padding_w, + uint64_t dilation_h, + uint64_t dilation_w, + tensor_view_t<3> output_grad_tv, + tensor_view_t<4> input_grad_tv) +{ + unfoldBackward4D(output_grad, + input_grad, + N, + C, + H, + W, + P, + LH, + LW, + kernel_size_h, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + output_grad_tv, + input_grad_tv); +} diff --git a/src/solver.cpp b/src/solver.cpp index 1f6873d5f7..17bcbe83be 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -676,7 +677,6 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Cat, cat::CatForward{}.SolverDbId()); Register(registry, ++id, Primitive::Adam, adam::Adam{}.SolverDbId()); - Register(registry, ++id, Primitive::Item, getitem::GetitemBackward{}.SolverDbId()); Register(registry, ++id, Primitive::Adam, adam::TransformersAdamW{}.SolverDbId()); @@ -709,6 +709,10 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) multimarginloss::MultiMarginLossForward{}.SolverDbId()); Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); + Register(registry, ++id, Primitive::Unfold, fold::UnfoldFwd{}.SolverDbId()); + Register(registry, ++id, Primitive::Unfold, fold::UnfoldBwd{}.SolverDbId()); + Register(registry, ++id, Primitive::Fold, fold::FoldFwd{}.SolverDbId()); + Register(registry, ++id, Primitive::Fold, fold::FoldBwd{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function, and don't leave a white // space between this comment and the newly registered solver(s)! } diff --git a/src/solver/fold/fold_backward.cpp b/src/solver/fold/fold_backward.cpp new file mode 100644 index 0000000000..ef7a1e184c --- /dev/null +++ b/src/solver/fold/fold_backward.cpp @@ -0,0 +1,157 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace fold { + +bool FoldBwd::IsApplicable( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::fold::FoldBwdProblemDescription& problem) const +{ + return true; +} + +ConvSolution FoldBwd::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::fold::FoldBwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto in_dtype = miopen::GetDataType(problem.GetDinputDesc().GetType()); + auto dtype = problem.GetDoutputDesc().GetType(); + auto input_grad_dims = problem.GetDinputDesc().GetLengths(); + auto output_grad_dims = problem.GetDoutputDesc().GetLengths(); + + const uint64_t N = static_cast(output_grad_dims[0]); + const uint64_t C = static_cast(output_grad_dims[1]); + uint64_t spatial_dim_size = output_grad_dims.size() - 2; + uint64_t P = 1, L = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= problem.kernel_size[i]; + uint64_t l = (static_cast(output_grad_dims[i + 2]) + 2 * problem.padding[i] - + problem.dilation[i] * (problem.kernel_size[i] - 1) - 1) / + problem.stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenUnfold.cpp"; + kernel.kernel_name = "UnfoldForward4D"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(static_cast(N * C * P * L), LOCAL_SIZE); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + + result.invoker_factory = [N, C, P, L, ls](const std::vector& kernels) { + 
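+        // Note on kernel reuse: FoldBwd launches UnfoldForward4D (see kernel_file /
+        // kernel_name above) because the gradient of fold w.r.t. its input is exactly
+        // an unfold of the output gradient. N, C, P, L and ls were derived from the
+        // output-gradient shape above; e.g. (illustrative values only): for
+        // doutput = {2, 3, 4, 5} with kernel {2, 2}, stride {1, 1}, padding {0, 0},
+        // dilation {1, 1}: P = 2 * 2 = 4,
+        // ls = {(4 - 1 * (2 - 1) - 1) / 1 + 1, (5 - 1 * (2 - 1) - 1) / 1 + 1} = {3, 4},
+        // and L = 3 * 4 = 12.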
return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto input_grad_tv = get_inner_expanded_tv<3>(deref(params.dinputDesc)); + auto output_grad_tv = get_inner_expanded_tv<4>(deref(params.doutputDesc)); + auto input_grad_dims = deref(params.dinputDesc).GetLengths(); + auto output_grad_dims = deref(params.doutputDesc).GetLengths(); + + uint64_t kernel_size_w = params.kernel_size[1]; + uint64_t stride_h = params.stride[0]; + uint64_t stride_w = params.stride[1]; + uint64_t padding_h = params.padding[0]; + uint64_t padding_w = params.padding[1]; + uint64_t dilation_h = params.dilation[0]; + uint64_t dilation_w = params.dilation[1]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(output_grad_dims[2]); + uint64_t W = static_cast(output_grad_dims[3]); + + kernel(params.doutput, + params.dinput, + N, + C, + H, + W, + P, + L, + LW, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + output_grad_tv, + input_grad_tv); + }; + }; + + return result; +} + +} // namespace fold + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/fold/fold_forward.cpp b/src/solver/fold/fold_forward.cpp new file mode 100644 index 0000000000..3f96eff0f0 --- /dev/null +++ b/src/solver/fold/fold_forward.cpp @@ -0,0 +1,165 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace fold { + +bool FoldFwd::IsApplicable( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::fold::FoldFwdProblemDescription& problem) const +{ + return true; +} + +ConvSolution FoldFwd::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::fold::FoldFwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto in_dtype = miopen::GetDataType(problem.GetInputDesc().GetType()); + auto dtype = problem.GetOutputDesc().GetType(); + auto input_dims = problem.GetInputDesc().GetLengths(); + + auto output_dims = problem.GetOutputDesc().GetLengths(); + const uint64_t N = static_cast(output_dims[0]); + const uint64_t C = static_cast(output_dims[1]); + uint64_t H = static_cast(output_dims[2]); + uint64_t W = static_cast(output_dims[3]); + + { + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenUnfold.cpp"; + kernel.kernel_name = "UnfoldBackward4D"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(static_cast(N * C * H * W), LOCAL_SIZE); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [N, C](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto input_tv = get_inner_expanded_tv<3>(deref(params.inputDesc)); + auto output_tv = get_inner_expanded_tv<4>(deref(params.outputDesc)); + auto input_dims = deref(params.inputDesc).GetLengths(); + auto output_dims = deref(params.outputDesc).GetLengths(); + + uint64_t spatial_dim_size = output_dims.size() - 2; + uint64_t P = 1, L = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= params.kernel_size[i]; + uint64_t l = (static_cast(output_dims[i + 2]) + 2 * params.padding[i] - + params.dilation[i] * (params.kernel_size[i] - 1) - 1) / + params.stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + uint64_t kernel_size_h = params.kernel_size[0]; + uint64_t kernel_size_w = params.kernel_size[1]; + uint64_t stride_h = params.stride[0]; + uint64_t stride_w = params.stride[1]; + uint64_t padding_h = params.padding[0]; + uint64_t padding_w = params.padding[1]; + uint64_t dilation_h = params.dilation[0]; + uint64_t dilation_w = params.dilation[1]; + uint64_t LH = ls[0]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(output_dims[2]); + uint64_t W = static_cast(output_dims[3]); + + kernel(params.input, + params.output, + N, + C, + H, + W, + P, + LH, + LW, + kernel_size_h, + kernel_size_w, + 
stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + input_tv, + output_tv); + }; + }; + + return result; +} + +} // namespace fold + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/fold/unfold_backward.cpp b/src/solver/fold/unfold_backward.cpp new file mode 100644 index 0000000000..8051066dda --- /dev/null +++ b/src/solver/fold/unfold_backward.cpp @@ -0,0 +1,164 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace fold { + +bool UnfoldBwd::IsApplicable( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::fold::UnfoldBwdProblemDescription& problem) const +{ + return true; +} + +ConvSolution UnfoldBwd::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::fold::UnfoldBwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto in_dtype = miopen::GetDataType(problem.GetDinputDesc().GetType()); + auto dtype = problem.GetDoutputDesc().GetType(); + auto input_grad_dims = problem.GetDinputDesc().GetLengths(); + auto output_grad_dims = problem.GetDoutputDesc().GetLengths(); + + const uint64_t N = static_cast(input_grad_dims[0]); + const uint64_t C = static_cast(input_grad_dims[1]); + uint64_t H = static_cast(input_grad_dims[2]); + uint64_t W = static_cast(input_grad_dims[3]); + + { + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenUnfold.cpp"; + kernel.kernel_name = "UnfoldBackward4D"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(static_cast(N * C * H * W), LOCAL_SIZE); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + kernel.l_wk.push_back(xlocalsize); + 
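+        // Launch-size sketch (numbers are illustrative, not from this change): the
+        // grid is effectively 1-D with one work-item per input-gradient element. For
+        // N=2, C=3, H=10, W=12, N*C*H*W = 720 and AlignUp(720, 256) = 768, so 48
+        // excess work-items are launched; the kernel's `if(n >= N) return;` guard
+        // makes them exit immediately.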
kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [N, C, H, W](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto input_grad_tv = get_inner_expanded_tv<4>(deref(params.dinputDesc)); + auto output_grad_tv = get_inner_expanded_tv<3>(deref(params.doutputDesc)); + auto input_grad_dims = deref(params.dinputDesc).GetLengths(); + auto output_grad_dims = deref(params.doutputDesc).GetLengths(); + + int spatial_dim_size = input_grad_dims.size() - 2; + uint64_t P = 1, L = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= params.kernel_size[i]; + uint64_t l = + (static_cast(input_grad_dims[i + 2]) + 2 * params.padding[i] - + params.dilation[i] * (params.kernel_size[i] - 1) - 1) / + params.stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + uint64_t kernel_size_h = params.kernel_size[0]; + uint64_t kernel_size_w = params.kernel_size[1]; + uint64_t stride_h = params.stride[0]; + uint64_t stride_w = params.stride[1]; + uint64_t padding_h = params.padding[0]; + uint64_t padding_w = params.padding[1]; + uint64_t dilation_h = params.dilation[0]; + uint64_t dilation_w = params.dilation[1]; + uint64_t LH = ls[0]; + uint64_t LW = ls[1]; + + kernel(params.doutput, + params.dinput, + N, + C, + H, + W, + P, + LH, + LW, + kernel_size_h, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + output_grad_tv, + input_grad_tv); + }; + }; + + return result; +} + +} // namespace fold + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/fold/unfold_forward.cpp b/src/solver/fold/unfold_forward.cpp new file mode 100644 index 0000000000..27cda27136 --- /dev/null +++ b/src/solver/fold/unfold_forward.cpp @@ -0,0 +1,157 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace fold { + +bool UnfoldFwd::IsApplicable( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::fold::UnfoldFwdProblemDescription& problem) const +{ + return true; +} + +ConvSolution UnfoldFwd::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::fold::UnfoldFwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto in_dtype = miopen::GetDataType(problem.GetInputDesc().GetType()); + auto dtype = problem.GetOutputDesc().GetType(); + auto input_dims = problem.GetInputDesc().GetLengths(); + auto output_dims = problem.GetOutputDesc().GetLengths(); + + const uint64_t N = static_cast(input_dims[0]); + const uint64_t C = static_cast(input_dims[1]); + int spatial_dim_size = input_dims.size() - 2; + uint64_t P = 1, L = 1; + std::vector ls; + for(int i = 0; i < spatial_dim_size; ++i) + { + P *= problem.kernel_size[i]; + uint64_t l = (static_cast(input_dims[i + 2]) + 2 * problem.padding[i] - + problem.dilation[i] * (problem.kernel_size[i] - 1) - 1) / + problem.stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenUnfold.cpp"; + kernel.kernel_name = "UnfoldForward4D"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(static_cast(N * C * P * L), LOCAL_SIZE); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + + result.invoker_factory = [N, C, P, L, ls](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto input_tv = get_inner_expanded_tv<4>(deref(params.inputDesc)); + auto output_tv = get_inner_expanded_tv<3>(deref(params.outputDesc)); + auto input_dims = deref(params.inputDesc).GetLengths(); + auto output_dims = deref(params.outputDesc).GetLengths(); + + uint64_t kernel_size_w = params.kernel_size[1]; + uint64_t stride_h = params.stride[0]; + uint64_t stride_w = params.stride[1]; + uint64_t padding_h = params.padding[0]; + uint64_t padding_w = params.padding[1]; + uint64_t dilation_h = params.dilation[0]; + uint64_t dilation_w = params.dilation[1]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(input_dims[2]); + uint64_t W = static_cast(input_dims[3]); + + kernel(params.input, + params.output, + N, + C, + H, + W, + P, + L, + LW, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + input_tv, + output_tv); + }; + }; + + return result; +} + +} // namespace fold + +} // namespace solver + 
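+// Shape relationship handled by UnfoldFwd, with a worked example (values are
+// illustrative only): input {N, C, H, W} maps to output {N, C * P, L}, where
+// P = kernel_h * kernel_w and L = LH * LW with each
+//   l = (size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1.
+// E.g. input {2, 3, 10, 12}, kernel {2, 2}, stride {1, 1}, padding {0, 0},
+// dilation {1, 1} gives P = 4, LH = 9, LW = 11, so output = {2, 12, 99}.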
+} // namespace miopen diff --git a/test/cpu_unfold.hpp b/test/cpu_unfold.hpp new file mode 100644 index 0000000000..0348a3be4e --- /dev/null +++ b/test/cpu_unfold.hpp @@ -0,0 +1,196 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include +#include "tensor_holder.hpp" +#include + +template +void cpu_unfold_fwd_4d(tensor input_tensor, + tensor& ref_output_tensor, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, + const std::vector dilation) +{ + auto input_tv = miopen::get_inner_expanded_tv<4>(input_tensor.desc); + auto output_tv = miopen::get_inner_expanded_tv<3>(ref_output_tensor.desc); + auto input_size = input_tensor.desc.GetNumDims(); + auto input_dims = input_tensor.desc.GetLengths(); + + auto input = input_tensor.data.data(); + auto output = ref_output_tensor.data.data(); + + const uint64_t LOCAL_SIZE = 256; + uint64_t spatial_dim_size = input_size - 2; + + const uint64_t N = static_cast(input_dims[0]); + const uint64_t C = static_cast(input_dims[1]); + + uint64_t P = 1, L = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + L *= l; + ls.push_back(l); + } + + uint64_t kernel_size_w = kernel_size[1]; + uint64_t stride_h = stride[0]; + uint64_t stride_w = stride[1]; + uint64_t padding_h = padding[0]; + uint64_t padding_w = padding[1]; + uint64_t dilation_h = dilation[0]; + uint64_t dilation_w = dilation[1]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(input_dims[2]); + uint64_t W = static_cast(input_dims[3]); + uint64_t work_size = (((N * C * P * L) + LOCAL_SIZE - 1) / LOCAL_SIZE) * LOCAL_SIZE; + par_ford(work_size)([&](uint64_t gid) { + uint64_t ncp = gid / L, l = gid % L; + uint64_t nc = ncp / P, p = ncp % P; + uint64_t n = nc / C, c = nc % C; + if(n >= N) + return; + + uint64_t lh = l / LW, lw = l % LW; // sliding window position + uint64_t ph = p / kernel_size_w, pw = p % kernel_size_w; // position inside kernel + + T x = static_cast(0.0f); + if(lh * stride_h >= padding_h + ph * dilation_h && + lw * stride_w >= padding_w + pw * dilation_w) + { + uint64_t h = lh * stride_h - 
padding_h + ph * dilation_h; + uint64_t w = lw * stride_w - padding_w + pw * dilation_w; + if(h < H && w < W) + { + long input_idx = input_tv.stride[3] * w + input_tv.stride[2] * h + + input_tv.stride[1] * c + input_tv.stride[0] * n; + x = input[input_idx]; + } + } + + long output_idx = + output_tv.stride[2] * l + output_tv.stride[1] * (c * P + p) + output_tv.stride[0] * n; + output[output_idx] = x; + }); +} + +template +void cpu_unfold_bwd_4d(tensor& ref_dinput_tensor, + tensor doutput_tensor, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, + const std::vector dilation) +{ + auto input_grad_tv = miopen::get_inner_expanded_tv<4>(ref_dinput_tensor.desc); + auto output_grad_tv = miopen::get_inner_expanded_tv<3>(doutput_tensor.desc); + auto input_size = ref_dinput_tensor.desc.GetNumDims(); + auto input_grad_dims = ref_dinput_tensor.desc.GetLengths(); + + auto input_grad = ref_dinput_tensor.data.data(); + auto output_grad = doutput_tensor.data.data(); + + const uint64_t LOCAL_SIZE = 256; + uint64_t spatial_dim_size = input_size - 2; + + const uint64_t N = static_cast(input_grad_dims[0]); + const uint64_t C = static_cast(input_grad_dims[1]); + + uint64_t P = 1; + std::vector ls; + for(uint64_t i = 0; i < spatial_dim_size; ++i) + { + P *= kernel_size[i]; + uint64_t l = (static_cast(input_grad_dims[i + 2]) + 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - 1) / + stride[i] + + 1; + ls.push_back(l); + } + + uint64_t kernel_size_h = kernel_size[0]; + uint64_t kernel_size_w = kernel_size[1]; + uint64_t stride_h = stride[0]; + uint64_t stride_w = stride[1]; + uint64_t padding_h = padding[0]; + uint64_t padding_w = padding[1]; + uint64_t dilation_h = dilation[0]; + uint64_t dilation_w = dilation[1]; + uint64_t LH = ls[0]; + uint64_t LW = ls[1]; + uint64_t H = static_cast(input_grad_dims[2]); + uint64_t W = static_cast(input_grad_dims[3]); + uint64_t work_size = (((N * C * H * W) + LOCAL_SIZE - 1) / LOCAL_SIZE) * LOCAL_SIZE; + par_ford(work_size)([&](uint64_t gid) { + uint64_t nch = gid / W, w = gid % W; + uint64_t nc = nch / H, h = nch % H; + uint64_t n = nc / C, c = nc % C; + if(n >= N) + return; + + float sum = 0.0f; + + for(uint64_t ph = 0; ph < kernel_size_h; ++ph) + { + for(uint64_t pw = 0; pw < kernel_size_w; ++pw) + { + if(h < ph * dilation_h + padding_h) + continue; + if(w < pw * dilation_w + padding_w) + continue; + uint64_t lhsh = h - ph * dilation_h + padding_h; + uint64_t lwsw = w - pw * dilation_w + padding_w; + if(lhsh % stride_h != 0) + continue; + if(lwsw % stride_w != 0) + continue; + uint64_t lh = lhsh / stride_h; + uint64_t lw = lwsw / stride_w; + if(LH <= lh) + continue; + if(LW <= lw) + continue; + long output_grad_idx = + output_grad_tv.stride[2] * (lh * LW + lw) + + output_grad_tv.stride[1] * (c * P + (ph * kernel_size_w + pw)) + + output_grad_tv.stride[0] * n; + sum += static_cast(output_grad[output_grad_idx]); + } + } + + long input_grad_idx = input_grad_tv.stride[3] * w + input_grad_tv.stride[2] * h + + input_grad_tv.stride[1] * c + input_grad_tv.stride[0] * n; + input_grad[input_grad_idx] = static_cast(sum); + }); +} diff --git a/test/gtest/fold.cpp b/test/gtest/fold.cpp new file mode 100644 index 0000000000..8478ec84ae --- /dev/null +++ b/test/gtest/fold.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
diff --git a/test/gtest/fold.cpp b/test/gtest/fold.cpp
new file mode 100644
index 0000000000..8478ec84ae
--- /dev/null
+++ b/test/gtest/fold.cpp
@@ -0,0 +1,104 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include "fold.hpp"
+
+namespace fold {
+
+struct GPU_Fold_fwd_FP32 : FoldFwdTest<float>
+{
+};
+
+struct GPU_Fold_fwd_FP16 : FoldFwdTest<half_float::half>
+{
+};
+
+struct GPU_Fold_fwd_BFP16 : FoldFwdTest<bfloat16>
+{
+};
+
+struct GPU_Fold_bwd_FP32 : FoldBwdTest<float>
+{
+};
+
+struct GPU_Fold_bwd_FP16 : FoldBwdTest<half_float::half>
+{
+};
+
+struct GPU_Fold_bwd_BFP16 : FoldBwdTest<bfloat16>
+{
+};
+}; // namespace fold
+
+using namespace fold;
+
+TEST_P(GPU_Fold_fwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_fwd_FP32, testing::ValuesIn(FoldTestConfigs()));
+
+TEST_P(GPU_Fold_fwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_fwd_FP16, testing::ValuesIn(FoldTestConfigs()));
+
+TEST_P(GPU_Fold_fwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_fwd_BFP16, testing::ValuesIn(FoldTestConfigs()));
+
+TEST_P(GPU_Fold_bwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_bwd_FP32, testing::ValuesIn(FoldTestConfigs()));
+
+TEST_P(GPU_Fold_bwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_bwd_FP16, testing::ValuesIn(FoldTestConfigs()));
+
+TEST_P(GPU_Fold_bwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Fold_bwd_BFP16, testing::ValuesIn(FoldTestConfigs()));
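These fold tests reuse the CPU unfold reference in the transposed direction: `FoldForward` is checked against `cpu_unfold_bwd_4d` and `FoldBackward` against `cpu_unfold_fwd_4d`, since fold (col2im) is the adjoint of unfold (im2col). A minimal sketch of the shape bookkeeping the fixtures below perform; the helper name is illustrative, not part of this PR:

```cpp
#include <cstdint>
#include <vector>

// How FoldFwdTest::SetUp derives the folded output shape from a test config:
// the (N, C*P, L) columns collapse back to (N, C, H_out, W_out).
std::vector<uint64_t> fold_output_dims(uint64_t n,
                                       uint64_t c_times_p,
                                       const std::vector<uint64_t>& kernel,
                                       const std::vector<uint64_t>& output_size)
{
    uint64_t c = c_times_p;
    for(uint64_t k : kernel)
        c /= k; // C = (C * P) / prod(kernel)
    return {n, c, output_size[0], output_size[1]};
}
```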
diff --git a/test/gtest/fold.hpp b/test/gtest/fold.hpp
new file mode 100644
index 0000000000..b6d9e71156
--- /dev/null
+++ b/test/gtest/fold.hpp
@@ -0,0 +1,288 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include "cpu_unfold.hpp"
+#include "get_handle.hpp"
+#include <gtest/gtest.h>
+#include "random.hpp"
+#include "tensor_holder.hpp"
+#include "verify.hpp"
+#include <miopen/allocator.hpp>
+#include <miopen/fold.hpp>
+#include <miopen/miopen.h>
+
+struct FoldTestCase
+{
+    uint64_t N;
+    uint64_t C;
+    uint64_t D;
+    uint64_t H;
+    uint64_t W;
+    std::vector<uint64_t> outputSize;
+    std::vector<uint64_t> kernelSize;
+    std::vector<uint64_t> stride;
+    std::vector<uint64_t> padding;
+    std::vector<uint64_t> dilation;
+    bool isContiguous = true;
+    friend std::ostream& operator<<(std::ostream& os, const FoldTestCase& tc)
+    {
+        os << "N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H << " W:" << tc.W;
+        os << " output_size:";
+        for(const auto& outs : tc.outputSize)
+            os << outs << " ";
+        os << " kernel_size:";
+        for(const auto& ks : tc.kernelSize)
+            os << ks << " ";
+        os << "stride:";
+        for(const auto& s : tc.stride)
+            os << s << " ";
+        os << "padding:";
+        for(const auto& p : tc.padding)
+            os << p << " ";
+        os << "dilation:";
+        for(const auto& d : tc.dilation)
+            os << d << " ";
+        os << "isContiguous:" << std::boolalpha << tc.isContiguous;
+        return os;
+    }
+
+    std::vector<uint64_t> GetInput()
+    {
+        if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, D, H, W});
+        }
+        else if((N != 0) && (C != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, H, W});
+        }
+        else if((N != 0) && (C != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, W});
+        }
+        else if((N != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, W});
+        }
+        else if((N != 0))
+        {
+            return std::vector<uint64_t>({N});
+        }
+        else
+        {
+            std::cout << "Error: invalid input tensor lengths" << std::endl;
+            return std::vector<uint64_t>({0});
+        }
+    }
+
+    std::vector<uint64_t> ComputeStrides(std::vector<uint64_t> inputDim) const
+    {
+        if(!isContiguous)
+            std::swap(inputDim.front(), inputDim.back());
+        std::vector<uint64_t> strides(inputDim.size());
+        strides.back() = 1;
+        for(int i = inputDim.size() - 2; i >= 0; --i)
+            strides[i] = strides[i + 1] * inputDim[i + 1];
+        if(!isContiguous)
+            std::swap(strides.front(), strides.back());
+        return strides;
+    }
+};
+
+inline std::vector<FoldTestCase> FoldTestConfigs()
+{
+    // clang-format off
+    return {
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {4, 5}, {2, 2}, {1, 1}, {0, 0}, {1, 1}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {6, 11}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {7, 12}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {7, 13}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, true},
+        {3, 3 * 3 * 4, 0, 0, 3 * 4, {5, 7}, {3, 4}, {1, 1}, {0, 0}, {1, 1}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {2, 3}, {2, 2}, {1, 1}, {1, 1}, {1, 1}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {5, 7}, {2, 2}, {1, 1}, {0, 0}, {2, 3}, true},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {4, 5}, {2, 2}, {1, 1}, {0, 0}, {1, 1}, false},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {6, 11}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, false},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {7, 12}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, false},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {7, 13}, {2, 2}, {2, 3}, {0, 0}, {1, 1}, false},
+        {3, 3 * 3 * 4, 0, 0, 3 * 4, {5, 7}, {3, 4}, {1, 1}, {0, 0}, {1, 1}, false},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {2, 3}, {2, 2}, {1, 1}, {1, 1}, {1, 1}, false},
+        {3, 3 * 2 * 2, 0, 0, 3 * 4, {5, 7}, {2, 2}, {1, 1}, {0, 0}, {2, 3}, false},
+    };
+    // clang-format on
+}
+
+template <class T>
+struct FoldFwdTest : public ::testing::TestWithParam<FoldTestCase>
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        std::vector<uint64_t> in_dims    = config.GetInput();
+        std::vector<uint64_t> in_strides = config.ComputeStrides(in_dims);
+
+        auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign<T>(1e-2, 100); };
+        auto gen_zero  = [&](auto...) { return 0; };
+        input          = tensor<T>{in_dims, in_strides}.generate(gen_value);
+        const uint64_t N = static_cast<uint64_t>(in_dims[0]);
+        uint64_t C       = static_cast<uint64_t>(in_dims[1]);
+        for(uint64_t i : config.kernelSize)
+        {
+            C = C / i;
+        }
+
+        std::vector<uint64_t> out_dims{N, C, config.outputSize[0], config.outputSize[1]};
+
+        output     = tensor<T>{out_dims}.generate(gen_zero);
+        outputHost = tensor<T>{out_dims}.generate(gen_zero);
+
+        input_dev  = handle.Write(input.data);
+        output_dev = handle.Write(output.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        status = miopen::fold::FoldForward(handle,
+                                           input.desc,
+                                           input_dev.get(),
+                                           output.desc,
+                                           output_dev.get(),
+                                           config.kernelSize.data(),
+                                           static_cast<uint64_t>(config.kernelSize.size()),
+                                           config.stride.data(),
+                                           static_cast<uint64_t>(config.stride.size()),
+                                           config.padding.data(),
+                                           static_cast<uint64_t>(config.padding.size()),
+                                           config.dilation.data(),
+                                           static_cast<uint64_t>(config.dilation.size()));
+
+        cpu_unfold_bwd_4d<T>(
+            outputHost, input, config.kernelSize, config.stride, config.padding, config.dilation);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+        output.data = handle.Read<T>(output_dev, output.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = std::numeric_limits<T>::epsilon();
+
+        auto error = miopen::rms_range(outputHost, output);
+
+        ASSERT_EQ(miopen::range_distance(outputHost), miopen::range_distance(output));
+        EXPECT_LT(error, threshold * 10) << "Error forward output beyond tolerance Error: {"
+                                         << error << "}, Tolerance: " << threshold * 10;
+    }
+    FoldTestCase config;
+
+    tensor<T> input;
+    tensor<T> output;
+    tensor<T> outputHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+};
+
+template <class T>
+struct FoldBwdTest : public ::testing::TestWithParam<FoldTestCase>
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        std::vector<uint64_t> in_dims    = config.GetInput();
+        std::vector<uint64_t> in_strides = config.ComputeStrides(in_dims);
+
+        auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign<T>(1e-2, 100); };
+        auto gen_zero  = [&](auto...) { return 0; };
+        dinput         = tensor<T>{in_dims, in_strides}.generate(gen_zero);
+        dinputHost     = tensor<T>{in_dims, in_strides}.generate(gen_zero);
+
+        const uint64_t N = static_cast<uint64_t>(in_dims[0]);
+        uint64_t C       = static_cast<uint64_t>(in_dims[1]);
+        for(uint64_t i : config.kernelSize)
+        {
+            C = C / i;
+        }
+
+        std::vector<uint64_t> out_dims{N, C, config.outputSize[0], config.outputSize[1]};
+
+        doutput = tensor<T>{out_dims}.generate(gen_value);
+
+        dinput_dev  = handle.Write(dinput.data);
+        doutput_dev = handle.Write(doutput.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        status = miopen::fold::FoldBackward(handle,
+                                            dinput.desc,
+                                            dinput_dev.get(),
+                                            doutput.desc,
+                                            doutput_dev.get(),
+                                            config.kernelSize.data(),
+                                            static_cast<uint64_t>(config.kernelSize.size()),
+                                            config.stride.data(),
+                                            static_cast<uint64_t>(config.stride.size()),
+                                            config.padding.data(),
+                                            static_cast<uint64_t>(config.padding.size()),
+                                            config.dilation.data(),
+                                            static_cast<uint64_t>(config.dilation.size()));
+
+        cpu_unfold_fwd_4d<T>(
+            doutput, dinputHost, config.kernelSize, config.stride, config.padding, config.dilation);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+        dinput.data = handle.Read<T>(dinput_dev, dinput.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = std::numeric_limits<T>::epsilon();
+        auto error       = miopen::rms_range(dinputHost, dinput);
+        ASSERT_EQ(miopen::range_distance(dinputHost), miopen::range_distance(dinput));
+        EXPECT_LT(error, threshold * 10) << "Error backward input_grad beyond tolerance Error: {"
+                                         << error << "}, Tolerance: " << threshold * 10;
+    }
+
+    FoldTestCase config;
+
+    tensor<T> dinput;
+    tensor<T> doutput;
+    tensor<T> dinputHost;
+
+    miopen::Allocator::ManageDataPtr dinput_dev;
+    miopen::Allocator::ManageDataPtr doutput_dev;
+};
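`ComputeStrides` fabricates the non-contiguous layouts exercised by the `isContiguous = false` configs: it swaps the outermost and innermost lengths before deriving packed strides, then swaps the corresponding strides back. A standalone copy of that logic (illustrative only) showing what the resulting layouts look like:

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Same logic as FoldTestCase::ComputeStrides / UnfoldTestCase::ComputeStrides.
std::vector<uint64_t> compute_strides(std::vector<uint64_t> dims, bool contiguous)
{
    if(!contiguous)
        std::swap(dims.front(), dims.back());
    std::vector<uint64_t> strides(dims.size());
    strides.back() = 1;
    for(int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
    if(!contiguous)
        std::swap(strides.front(), strides.back());
    return strides;
}

int main()
{
    // N=2, C=5, H=3, W=4
    for(auto s : compute_strides({2, 5, 3, 4}, true))
        std::cout << s << ' '; // 60 12 4 1  (packed NCHW)
    std::cout << '\n';
    for(auto s : compute_strides({2, 5, 3, 4}, false))
        std::cout << s << ' '; // 1 6 2 30   (N- and W-strides exchanged)
    std::cout << '\n';
}
```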
diff --git a/test/gtest/unfold.cpp b/test/gtest/unfold.cpp
new file mode 100644
index 0000000000..ba638c8851
--- /dev/null
+++ b/test/gtest/unfold.cpp
@@ -0,0 +1,104 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include "unfold.hpp"
+
+namespace unfold {
+
+struct GPU_Unfold_fwd_FP32 : UnfoldFwdTest<float>
+{
+};
+
+struct GPU_Unfold_fwd_FP16 : UnfoldFwdTest<half_float::half>
+{
+};
+
+struct GPU_Unfold_fwd_BFP16 : UnfoldFwdTest<bfloat16>
+{
+};
+
+struct GPU_Unfold_bwd_FP32 : UnfoldBwdTest<float>
+{
+};
+
+struct GPU_Unfold_bwd_FP16 : UnfoldBwdTest<half_float::half>
+{
+};
+
+struct GPU_Unfold_bwd_BFP16 : UnfoldBwdTest<bfloat16>
+{
+};
+}; // namespace unfold
+
+using namespace unfold;
+
+TEST_P(GPU_Unfold_fwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_fwd_FP32, testing::ValuesIn(UnfoldTestConfigs()));
+
+TEST_P(GPU_Unfold_fwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_fwd_FP16, testing::ValuesIn(UnfoldTestConfigs()));
+
+TEST_P(GPU_Unfold_fwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_fwd_BFP16, testing::ValuesIn(UnfoldTestConfigs()));
+
+TEST_P(GPU_Unfold_bwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_bwd_FP32, testing::ValuesIn(UnfoldTestConfigs()));
+
+TEST_P(GPU_Unfold_bwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_bwd_FP16, testing::ValuesIn(UnfoldTestConfigs()));
+
+TEST_P(GPU_Unfold_bwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full, GPU_Unfold_bwd_BFP16, testing::ValuesIn(UnfoldTestConfigs()));
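The unfolded output is laid out as (N, C·P, L): column `l` enumerates sliding-window positions and row `c * P + p` enumerates the P = kernel_h · kernel_w taps of each channel. A small illustrative helper (not part of this PR) mirroring the `output_idx` computation in `cpu_unfold_fwd_4d`:

```cpp
#include <cstdint>

// Row/column addressing of the unfolded (N, C*P, L) tensor.
struct UnfoldCoord
{
    uint64_t row; // c * P + p, with p = ph * kernel_w + pw
    uint64_t col; // lh * LW + lw
};

UnfoldCoord unfold_coord(uint64_t c, uint64_t ph, uint64_t pw,
                         uint64_t lh, uint64_t lw,
                         uint64_t kernel_h, uint64_t kernel_w, uint64_t LW)
{
    const uint64_t P = kernel_h * kernel_w;
    return {c * P + (ph * kernel_w + pw), lh * LW + lw};
}
```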
diff --git a/test/gtest/unfold.hpp b/test/gtest/unfold.hpp
new file mode 100644
index 0000000000..21d4d50294
--- /dev/null
+++ b/test/gtest/unfold.hpp
@@ -0,0 +1,301 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include "cpu_unfold.hpp"
+#include "get_handle.hpp"
+#include <gtest/gtest.h>
+#include "random.hpp"
+#include "tensor_holder.hpp"
+#include "verify.hpp"
+#include <miopen/allocator.hpp>
+#include <miopen/fold.hpp>
+#include <miopen/miopen.h>
+
+struct UnfoldTestCase
+{
+    uint64_t N;
+    uint64_t C;
+    uint64_t D;
+    uint64_t H;
+    uint64_t W;
+    std::vector<uint64_t> kernelSize;
+    std::vector<uint64_t> stride;
+    std::vector<uint64_t> padding;
+    std::vector<uint64_t> dilation;
+    bool isContiguous = true;
+    friend std::ostream& operator<<(std::ostream& os, const UnfoldTestCase& tc)
+    {
+        os << "N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H << " W:" << tc.W
+           << " kernel_size:";
+        for(const auto& ks : tc.kernelSize)
+            os << ks << " ";
+        os << "stride:";
+        for(const auto& s : tc.stride)
+            os << s << " ";
+        os << "padding:";
+        for(const auto& p : tc.padding)
+            os << p << " ";
+        os << "dilation:";
+        for(const auto& d : tc.dilation)
+            os << d << " ";
+        os << "isContiguous:" << std::boolalpha << tc.isContiguous;
+        return os;
+    }
+
+    std::vector<uint64_t> GetInput()
+    {
+        if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, D, H, W});
+        }
+        else if((N != 0) && (C != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, H, W});
+        }
+        else if((N != 0) && (C != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, C, W});
+        }
+        else if((N != 0) && (W != 0))
+        {
+            return std::vector<uint64_t>({N, W});
+        }
+        else if((N != 0))
+        {
+            return std::vector<uint64_t>({N});
+        }
+        else
+        {
+            std::cout << "Error: invalid input tensor lengths" << std::endl;
+            return std::vector<uint64_t>({0});
+        }
+    }
+
+    std::vector<uint64_t> ComputeStrides(std::vector<uint64_t> inputDim) const
+    {
+        if(!isContiguous)
+            std::swap(inputDim.front(), inputDim.back());
+        std::vector<uint64_t> strides(inputDim.size());
+        strides.back() = 1;
+        for(int i = inputDim.size() - 2; i >= 0; --i)
+            strides[i] = strides[i + 1] * inputDim[i + 1];
+        if(!isContiguous)
+            std::swap(strides.front(), strides.back());
+        return strides;
+    }
+};
+
+inline std::vector<UnfoldTestCase> UnfoldTestConfigs()
+{
+    // clang-format off
+    return {
+        {2, 5, 0, 3, 4, {2, 3}, {1, 1}, {0, 0}, {1, 1}, true},
+        {1, 3, 0, 10, 12, {4, 5}, {1, 1}, {0, 0}, {1, 1}, true},
+        {11, 13, 0, 17, 19, {3, 3}, {3, 2}, {0, 0}, {1, 1}, true},
+        {11, 13, 0, 17, 19, {3, 3}, {1, 1}, {3, 2}, {1, 1}, true},
+        {11, 13, 0, 17, 19, {3, 3}, {1, 1}, {0, 0}, {3, 2}, true},
+        {11, 13, 0, 33, 37, {4, 3}, {2, 3}, {5, 2}, {3, 5}, true},
+        {2, 5, 0, 3, 4, {2, 3}, {1, 1}, {0, 0}, {1, 1}, false},
+        {1, 3, 0, 10, 12, {4, 5}, {1, 1}, {0, 0}, {1, 1}, false},
+        {11, 13, 0, 17, 19, {3, 3}, {3, 2}, {0, 0}, {1, 1}, false},
+        {11, 13, 0, 17, 19, {3, 3}, {1, 1}, {3, 2}, {1, 1}, false},
+        {11, 13, 0, 17, 19, {3, 3}, {1, 1}, {0, 0}, {3, 2}, false},
+        {11, 13, 0, 33, 37, {4, 3}, {2, 3}, {5, 2}, {3, 5}, false},
+    };
+    // clang-format on
+}
+
+template <class T>
+struct UnfoldFwdTest : public ::testing::TestWithParam<UnfoldTestCase>
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        std::vector<uint64_t> in_dims    = config.GetInput();
+        std::vector<uint64_t> in_strides = config.ComputeStrides(in_dims);
+
+        auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign<T>(1e-2, 100); };
+        auto gen_zero  = [&](auto...) { return 0; };
+        input          = tensor<T>{in_dims, in_strides}.generate(gen_value);
+
+        int spatial_dim_size = in_dims.size() - 2;
+        const uint64_t N     = static_cast<uint64_t>(in_dims[0]);
+        const uint64_t C     = static_cast<uint64_t>(in_dims[1]);
+        uint64_t P = 1, L = 1;
+        std::vector<uint64_t> ls;
+        for(int i = 0; i < spatial_dim_size; ++i)
+        {
+            P *= config.kernelSize[i];
+            uint64_t l = (in_dims[i + 2] + 2 * config.padding[i] -
+                          config.dilation[i] * (config.kernelSize[i] - 1) - 1) /
+                             config.stride[i] +
+                         1;
+            L *= l;
+            ls.push_back(l);
+        }
+
+        std::vector<uint64_t> out_dims{N, C * P, L};
+
+        output     = tensor<T>{out_dims}.generate(gen_zero);
+        outputHost = tensor<T>{out_dims}.generate(gen_zero);
+
+        input_dev  = handle.Write(input.data);
+        output_dev = handle.Write(output.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        status = miopen::fold::UnfoldForward(handle,
+                                             input.desc,
+                                             input_dev.get(),
+                                             output.desc,
+                                             output_dev.get(),
+                                             config.kernelSize.data(),
+                                             static_cast<uint64_t>(config.kernelSize.size()),
+                                             config.stride.data(),
+                                             static_cast<uint64_t>(config.stride.size()),
+                                             config.padding.data(),
+                                             static_cast<uint64_t>(config.padding.size()),
+                                             config.dilation.data(),
+                                             static_cast<uint64_t>(config.dilation.size()));
+
+        cpu_unfold_fwd_4d<T>(
+            input, outputHost, config.kernelSize, config.stride, config.padding, config.dilation);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+        output.data = handle.Read<T>(output_dev, output.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = std::numeric_limits<T>::epsilon();
+
+        auto error = miopen::rms_range(outputHost, output);
+
+        ASSERT_EQ(miopen::range_distance(outputHost), miopen::range_distance(output));
+        EXPECT_LT(error, threshold * 10) << "Error forward output beyond tolerance Error: {"
+                                         << error << "}, Tolerance: " << threshold * 10;
+    }
+
+    UnfoldTestCase config;
+
+    tensor<T> input;
+    tensor<T> output;
+    tensor<T> outputHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+};
+
+template <class T>
+struct UnfoldBwdTest : public ::testing::TestWithParam<UnfoldTestCase>
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        std::vector<uint64_t> in_dims    = config.GetInput();
+        std::vector<uint64_t> in_strides = config.ComputeStrides(in_dims);
+
+        auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign<T>(1e-2, 100); };
+        auto gen_zero  = [&](auto...) { return 0; };
+        dinput         = tensor<T>{in_dims, in_strides}.generate(gen_zero);
+        dinputHost     = tensor<T>{in_dims, in_strides}.generate(gen_zero);
+
+        int spatial_dim_size = in_dims.size() - 2;
+        const uint64_t N     = static_cast<uint64_t>(in_dims[0]);
+        const uint64_t C     = static_cast<uint64_t>(in_dims[1]);
+        uint64_t P = 1, L = 1;
+        std::vector<uint64_t> ls;
+        for(int i = 0; i < spatial_dim_size; ++i)
+        {
+            P *= config.kernelSize[i];
+            uint64_t l = (static_cast<uint64_t>(in_dims[i + 2]) + 2 * config.padding[i] -
+                          config.dilation[i] * (config.kernelSize[i] - 1) - 1) /
+                             config.stride[i] +
+                         1;
+            L *= l;
+            ls.push_back(l);
+        }
+
+        std::vector<uint64_t> out_dims{N, C * P, L};
+
+        doutput = tensor<T>{out_dims}.generate(gen_value);
+
+        dinput_dev  = handle.Write(dinput.data);
+        doutput_dev = handle.Write(doutput.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        status = miopen::fold::UnfoldBackward(handle,
+                                              dinput.desc,
+                                              dinput_dev.get(),
+                                              doutput.desc,
+                                              doutput_dev.get(),
+                                              config.kernelSize.data(),
+                                              static_cast<uint64_t>(config.kernelSize.size()),
+                                              config.stride.data(),
+                                              static_cast<uint64_t>(config.stride.size()),
+                                              config.padding.data(),
+                                              static_cast<uint64_t>(config.padding.size()),
+                                              config.dilation.data(),
+                                              static_cast<uint64_t>(config.dilation.size()));
+
+        cpu_unfold_bwd_4d<T>(
+            dinputHost, doutput, config.kernelSize, config.stride, config.padding, config.dilation);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+        dinput.data = handle.Read<T>(dinput_dev, dinput.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = std::numeric_limits<T>::epsilon();
+        auto error       = miopen::rms_range(dinputHost, dinput);
+        ASSERT_EQ(miopen::range_distance(dinputHost), miopen::range_distance(dinput));
+        EXPECT_LT(error, threshold * 10) << "Error backward input_grad beyond tolerance Error: {"
+                                         << error << "}, Tolerance: " << threshold * 10;
+    }
+    UnfoldTestCase config;
+
+    tensor<T> dinput;
+    tensor<T> doutput;
+    tensor<T> dinputHost;
+
+    miopen::Allocator::ManageDataPtr dinput_dev;
+    miopen::Allocator::ManageDataPtr doutput_dev;
+};
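`Verify()` scales its pass threshold by the machine epsilon of the element type, so the same test body covers FP32, FP16, and BFP16. Assuming the usual epsilon values for these formats (2^-23, 2^-10, and 2^-7 respectively), the effective tolerances work out roughly as follows:

```cpp
#include <cmath>
#include <iostream>

int main()
{
    // Tolerance used by Verify(): 10 * machine epsilon of the element type.
    std::cout << 10 * std::pow(2.0, -23) << '\n'; // float    ~1.19e-6
    std::cout << 10 * std::pow(2.0, -10) << '\n'; // half     ~9.77e-3
    std::cout << 10 * std::pow(2.0, -7) << '\n';  // bfloat16 ~7.81e-2
}
```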