From 505e672b848cc65080bc36b18ed96b94009b3198 Mon Sep 17 00:00:00 2001 From: cxm <781326019@qq.com> Date: Fri, 27 Sep 2024 17:31:02 +0800 Subject: [PATCH] Add pixel_unshuffle opencl support test=develop --- .../cl_kernel/image/pixel_unshuffle_kernel.cl | 108 ++++++++++ lite/kernels/opencl/CMakeLists.txt | 1 + .../opencl/pixel_unshuffle_image_compute.cc | 192 ++++++++++++++++++ .../pixel_unshuffle_image_compute_test.cc | 158 ++++++++++++++ 4 files changed, 459 insertions(+) create mode 100644 lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute.cc create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl new file mode 100644 index 00000000000..9c4b227bdca --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl @@ -0,0 +1,108 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +__kernel void pixel_unshuffle(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int in_N, + __private const int in_C, + __private const int in_H, + __private const int in_W, + __private const int out_N, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int downscale_factor) { + const int in_c4 = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + + int in_h = in_nh % in_H; + int in_n = in_nh / in_H; + + int out_h = in_h * downscale_factor; + int out_w = in_w * downscale_factor; + int out_nh = in_n * out_H + out_h; + + CL_DTYPE4 res; + int in_c; + int out_c; + CL_DTYPE4 in; + int2 out_pos; + + in_c = in_c4 * 4 + 0; + out_c = in_c / (downscale_factor * downscale_factor); + out_pos.x = (out_c / 4) * out_W + out_w + (in_c % downscale_factor); + out_pos.y = out_nh + (in_c / (out_C * 4)) * out_H; + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (in_c % 4 == 0) { + res.x = in.x; + } else if (in_c % 4 == 1) { + res.x = in.y; + } else if (in_c % 4 == 2) { + res.x = in.z; + } else if (in_c % 4 == 3) { + res.x = in.w; + } + + in_c = in_c4 * 4 + 1; + out_c = in_c / (downscale_factor * downscale_factor); + out_pos.x = (out_c / 4) * out_W + out_w + (in_c % downscale_factor); + out_pos.y = out_nh + (in_c / (out_C * 4)) * out_H; + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (in_c % 4 == 0) { + res.y = in.x; + } else if (in_c % 4 == 1) { + res.y = in.y; + } else if (in_c % 4 == 2) { + res.y = in.z; + } else if (in_c % 4 == 3) { + res.y = in.w; + } + + in_c = in_c4 * 4 + 2; + out_c = in_c / (downscale_factor * downscale_factor); + out_pos.x = (out_c / 4) * out_W + out_w + (in_c % downscale_factor); + out_pos.y = out_nh + (in_c / (out_C * 4)) * out_H; + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (in_c % 4 == 0) { + res.z = in.x; + } else if (in_c % 4 == 1) { + res.z = in.y; + } else if (in_c % 4 == 2) { + res.z = in.z; + } else if (in_c % 4 == 3) { + res.z = in.w; + } + + in_c = in_c4 * 4 + 3; + out_c = in_c / (downscale_factor * downscale_factor); + out_pos.x = (out_c / 4) * out_W + out_w + (in_c % downscale_factor); + out_pos.y = out_nh + (in_c / (out_C * 4)) * out_H; + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (in_c % 4 == 0) { + res.w = in.x; + } else if (in_c % 4 == 1) { + res.w = in.y; + } else if (in_c % 4 == 2) { + res.w = in.z; + } else if (in_c % 4 == 3) { + res.w = in.w; + } + + int2 in_pos; + in_pos.x = in_c4 * in_W + in_w; + in_pos.y = in_nh; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, in_pos, res); +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 857e08ca6d7..c752b5ad060 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -42,6 +42,7 @@ add_kernel(dropout_opencl_image OPENCL basic SRCS dropout_image_compute.cc) add_kernel(pad2d_opencl_image OPENCL basic SRCS pad2d_image_compute.cc) add_kernel(box_coder_opencl_image OPENCL basic SRCS box_coder_image_compute.cc) add_kernel(pixel_shuffle_opencl_image OPENCL basic SRCS pixel_shuffle_image_compute.cc) +add_kernel(pixel_unshuffle_opencl_image OPENCL basic SRCS pixel_unshuffle_image_compute.cc) add_kernel(expand_opencl_image OPENCL basic SRCS expand_image_compute.cc) add_kernel(shuffle_channel_opencl_image OPENCL basic SRCS shuffle_channel_image_compute.cc) add_kernel(trigonometric_opencl_image OPENCL basic SRCS trigonometric_image_compute.cc) diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc new file mode 100644 index 00000000000..60cc0226d74 --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc @@ -0,0 +1,192 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif +#include "lite/backends/opencl/cl_utility.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class PixelUnShuffleComputeImage2D + : public KernelLite { + public: + using param_t = operators::PixelUnShuffleParam; + + std::string doc() const override { + return "PixelUnShuffle using cl::Image2D, kFP16"; + } + + void PrepareForRun() override { + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/pixel_unshuffle_kernel.cl", + build_options_, + time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_; + pixel_unshuffle_param_ = param_.get_mutable(); + auto x_dims = pixel_unshuffle_param_->x->dims(); + auto out_dims = pixel_unshuffle_param_->output->dims(); + VLOG(1) << "x_dims: " << x_dims; + VLOG(1) << "out_dims: " << out_dims; + VLOG(1) << "downscale_factor: " << pixel_unshuffle_param_->downscale_factor; + + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + out_img_shape_ = default_convertor.InitImageDimInfoWith( + pixel_unshuffle_param_->output->dims()); + VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " " + << out_img_shape_[1]; + + // compute global work size + auto image_width = out_dims[3] * ((out_dims[1] + 3) / 4); + size_t work_size_0 = image_width / out_dims[3]; + size_t work_size_1 = out_dims[3]; + size_t work_size_2 = out_dims[0] * out_dims[2]; + global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2}; + VLOG(1) << "global_work_size_: " << global_work_size_[0] << " " + << global_work_size_[1] << " " << global_work_size_[2]; + } + } + + void Run() override { + auto* x_img = GET_DATA_GPU(pixel_unshuffle_param_->x); + auto* out_img = MUTABLE_DATA_GPU(pixel_unshuffle_param_->output, + out_img_shape_[0], + out_img_shape_[1], + nullptr); + + auto x_dims = pixel_unshuffle_param_->x->dims(); + + int in_n = x_dims[0]; + int in_c = x_dims[1]; + int in_h = x_dims[2]; + int in_w = x_dims[3]; + + auto out_dims = pixel_unshuffle_param_->output->dims(); + + int out_n = out_dims[0]; + int out_c = out_dims[1]; + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + const int downscale_factor = pixel_unshuffle_param_->downscale_factor; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto kernel = kernel_; + cl_int status; + status = kernel.setArg(0, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(1, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(2, in_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(3, in_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(4, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(5, in_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(6, out_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(7, out_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(8, out_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(9, out_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(10, downscale_factor); + CL_CHECK_FATAL(status); + + status = EnqueueNDRangeKernel(context, + kernel, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status); + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + private: + std::string kernel_func_name_{"pixel_unshuffle"}; + std::string build_options_{""}; + std::string time_stamp_{GetTimeStamp()}; + + param_t* pixel_unshuffle_param_{nullptr}; + cl::Kernel kernel_; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pixel_unshuffle, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::PixelUnShuffleComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); \ No newline at end of file diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc new file mode 100644 index 00000000000..e340d94a722 --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { + +TEST(pixel_unshuffle_image2d, compute) { + LOG(INFO) << "create kernel ..."; + auto kernels = KernelRegistry::Global().Create("pixel_unshuffle", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + const int INPUT_N = 1; + const int INPUT_C = 4; + const int INPUT_H = 2; + const int INPUT_W = 2; + const int DOWNSCALE_FACTOR = 2; + + auto kernel = std::move(kernels.front()); + + LOG(INFO) << "prepare to test kernel ====> " << kernel->doc(); + + lite::Tensor x, out; + operators::PixelUnShuffleParam param; + param.x = &x; + param.output = &out; + param.downscale_factor = DOWNSCALE_FACTOR; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr pixel_unshuffle_context(new KernelContext); + context->As().CopySharedTo( + &(pixel_unshuffle_context->As())); + + kernel->SetContext(std::move(pixel_unshuffle_context)); + + const DDim in_dim = + DDim(std::vector{INPUT_N, INPUT_C, INPUT_H, INPUT_W}); + const DDim out_dim = DDim( + std::vector{INPUT_N, + INPUT_C * DOWNSCALE_FACTOR * DOWNSCALE_FACTOR, + INPUT_H / DOWNSCALE_FACTOR, + INPUT_W / DOWNSCALE_FACTOR}); + LOG(INFO) << "in_dim: " << in_dim; + LOG(INFO) << "DOWNSCALE_FACTOR: " << DOWNSCALE_FACTOR; + LOG(INFO) << "out_dim: " << out_dim; + + x.Resize(in_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-2, 2); + std::vector input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W); + + int index = 0; + for (auto& i : input_v) { + i = index++; + } + VLOG(1) << "input_v ..... "; + for (size_t i = 0; i < input_v.size(); i++) { + VLOG(10) << input_v[i]; + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * 4); // 4 : RGBA + default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + VLOG(1) << "x_image_data ..... "; + for (size_t i = 0; i < x_image_data.size(); i++) { + VLOG(10) << Half2Float(x_image_data[i]); + } + DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = out.mutable_data(out_image_shape[0], + out_image_shape[1]); + kernel->Launch(); + CLRuntime::Global()->command_queue().finish(); + + std::vector out_data_v{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = new half_t[out_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + VLOG(1) << "out_image_data ..... "; + for (size_t i = 0; i < out_image_shape.production() * 4; i++) { + VLOG(10) << Half2Float(out_image_data[i]); + } + float* out_data = new float[out_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, out_image_shape, out_dim); + + VLOG(1) << "out_data ..... "; + for (int i = 0; i < out_dim.production(); i++) { + VLOG(10) << out_data[i]; + } + + for (int i = 0; i < out_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_data_v[i]); + auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]); + EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << " out_data[" << i + << "]:" << out_data[i] << " " + "out_ref[" + << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; + } + } + + delete[] out_image_data; + delete[] out_data; + delete default_converter; +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(pixel_unshuffle, kOpenCL, kFP16, kImageDefault, image2d); \ No newline at end of file