From ab1e53fb5e47dcf41df0a06a6b104337d1a7a2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 3 Sep 2024 11:01:53 +0200 Subject: [PATCH 1/2] gpu: generic: deconvolution: move implementation from ocl folder --- src/gpu/generic/convolution_deconvolution.hpp | 374 ++++++++++++++++++ src/gpu/gpu_deconvolution_list.cpp | 6 +- .../intel/ocl/convolution_deconvolution.hpp | 291 -------------- 3 files changed, 378 insertions(+), 293 deletions(-) create mode 100644 src/gpu/generic/convolution_deconvolution.hpp diff --git a/src/gpu/generic/convolution_deconvolution.hpp b/src/gpu/generic/convolution_deconvolution.hpp new file mode 100644 index 00000000000..283f1d4fa2d --- /dev/null +++ b/src/gpu/generic/convolution_deconvolution.hpp @@ -0,0 +1,374 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP +#define GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP + +#include "common/c_types_map.hpp" +#include "common/primitive.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include "gpu/gpu_deconvolution_pd.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { + +static status_t weights_axes_permutation( + memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { + int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation + for (int d = 0; d < DNNL_MAX_NDIMS; ++d) + perm[d] = d; + nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); + + return memory_desc_permute_axes(*o_md, *i_md, perm); +} + +static status_t conv_descr_create( + const deconvolution_desc_t *dd, convolution_desc_t *cd) { + using namespace prop_kind; + alg_kind_t alg_kind = alg_kind::convolution_direct; + + const memory_desc_t *src_md, *dst_md, *d_weights_d; + prop_kind_t prop_kind; + + switch (dd->prop_kind) { + case forward: + case forward_inference: + prop_kind = backward_data; + src_md = &dd->dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->weights_desc; + break; + case backward_data: + prop_kind = forward_training; + src_md = &dd->diff_dst_desc; + dst_md = &dd->diff_src_desc; + d_weights_d = &dd->weights_desc; + break; + case backward_weights: + prop_kind = dd->prop_kind; + src_md = &dd->diff_dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->diff_weights_desc; + break; + default: assert(!"unknown prop kind"); return status::invalid_arguments; + } + + // Create weights desc for convolution + memory_desc_t c_weights_d; + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; + CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups)); + + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, + prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md, + dd->strides, dd->dilates, dd->padding[0], dd->padding[1]); +} + +struct convolution_deconvolution_fwd_t : public gpu::primitive_t { + using gpu::primitive_t::primitive_t; + struct pd_t : public gpu_deconvolution_fwd_pd_t { + pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + + pd_t(const pd_t &other) = default; + + ~pd_t() = default; + + DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t); + status_t init_convolution(impl::engine_t *engine) { + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + primitive_attr_t conv_attr(*attr()); + if (!conv_attr.is_initialized()) return status::out_of_memory; + primitive_desc_iterator_t it( + engine, (op_desc_t *)&cd, &conv_attr, nullptr); + if (!it.is_initialized()) return status::out_of_memory; + conv_pd_ = *(++it); + + return (conv_pd_) ? status::success : status::unimplemented; + } + + status_t init(impl::engine_t *engine) { + using namespace format_tag; + using sm = primitive_attr_t::skip_mask_t; + + const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime + | sm::scales_runtime; + + VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_DECONVOLUTION( + desc()->alg_kind == alg_kind::deconvolution_direct, + VERBOSE_BAD_ALGORITHM); + VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_DECONVOLUTION( + (utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->weights_desc.data_type, + desc()->dst_desc.data_type) + || (utils::everyone_is(data_type::f64, + desc()->src_desc.data_type, + desc()->weights_desc.data_type, + desc()->dst_desc.data_type)) + || ((utils::everyone_is(data_type::f16, + desc()->src_desc.data_type, + desc()->weights_desc.data_type) + || utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->weights_desc.data_type) + || utils::everyone_is(data_type::bf16, + desc()->src_desc.data_type, + desc()->weights_desc.data_type)) + && utils::one_of(desc()->dst_desc.data_type, + data_type::f16, data_type::u8, + data_type::s8)) + || (utils::everyone_is(data_type::bf16, + desc()->src_desc.data_type, + desc()->weights_desc.data_type) + && utils::one_of(desc()->dst_desc.data_type, + data_type::f32, data_type::bf16)) + || (utils::everyone_is(data_type::f16, + desc()->src_desc.data_type, + desc()->weights_desc.data_type) + && utils::one_of(desc()->dst_desc.data_type, + data_type::f32, data_type::f16)) + || (desc()->weights_desc.data_type == data_type::s8 + && utils::one_of(desc()->src_desc.data_type, + data_type::u8, data_type::s8) + && desc()->dst_desc.data_type + != data_type::f64)), + VERBOSE_UNSUPPORTED_DT); + + VDISPATCH_DECONVOLUTION_SC( + init_convolution(engine), "init_convolution()"); + if (weights_md_.format_kind == format_kind::any) { + VDISPATCH_DECONVOLUTION_SC( + weights_axes_permutation(&weights_md_, + conv_pd_->weights_md(), with_groups()), + "weights_axes_permutation()"); + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (dst_md_.format_kind == format_kind::any) + dst_md_ = *conv_pd_->diff_src_md(); + if (bias_md_.format_kind == format_kind::any) { + VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x), + VERBOSE_UNSUPPORTED_TAG); + } + init_name(); + init_scratchpad(); + VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)), + VERBOSE_UNSUPPORTED_ATTR); + + return status::success; + } + + std::shared_ptr conv_pd_; + + private: + std::string name_ = "conv:any"; + + void init_name() { + name_.append("+"); + name_.append(conv_pd_->name()); + } + + void init_scratchpad() { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_nested, + conv_pd_->scratchpad_registry()); + } + }; + + status_t init(impl::engine_t *engine) override { + return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); + } + + status_t execute(const exec_ctx_t &ctx) const override { + using namespace memory_tracking::names; + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC); + conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); + conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST); + if (pd()->with_bias()) + conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS); + + for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) { + if (pd()->attr()->post_ops_.entry_[idx].is_binary()) { + conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1] + = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) + | DNNL_ARG_SRC_1); + } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) { + conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) + | DNNL_ARG_WEIGHTS] + = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) + | DNNL_ARG_WEIGHTS); + } + } + const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC; + const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST; + if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src); + if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst); + + for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { + int key = DNNL_ARG_ATTR_SCALES | arg; + if (args.find(key) != args.end()) conv_args[key] = args.at(key); + } + + exec_ctx_t conv_ctx(ctx, std::move(conv_args)); + + nested_scratchpad_t ns(ctx, key_nested, conv_p_); + conv_ctx.set_scratchpad_grantor(ns.grantor()); + // Executing the convolution kernel + return conv_p_->execute(conv_ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::shared_ptr conv_p_; +}; + +struct convolution_deconvolution_bwd_data_t : public gpu::primitive_t { + using gpu::primitive_t::primitive_t; + struct pd_t : public gpu_deconvolution_bwd_data_pd_t { + pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) {} + + pd_t(const pd_t &other) = default; + + ~pd_t() = default; + + DECLARE_COMMON_PD_T( + name_.c_str(), convolution_deconvolution_bwd_data_t); + + status_t init_convolution(impl::engine_t *engine) { + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + primitive_attr_t conv_attr(*attr()); + if (!conv_attr.is_initialized()) return status::out_of_memory; + primitive_desc_iterator_t it( + engine, (op_desc_t *)&cd, &conv_attr, nullptr); + if (!it.is_initialized()) return status::out_of_memory; + conv_pd_ = *(++it); + return (conv_pd_) ? status::success : status::unimplemented; + } + + status_t init(impl::engine_t *engine) { + VDISPATCH_DECONVOLUTION( + desc()->prop_kind == prop_kind::backward_data, + VERBOSE_BAD_PROPKIND); + + VDISPATCH_DECONVOLUTION( + (utils::everyone_is(data_type::f32, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type) + || (utils::everyone_is(data_type::f64, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type)) + || utils::everyone_is(data_type::f16, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type) + || utils::everyone_is(data_type::bf16, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type)), + VERBOSE_UNSUPPORTED_DT); + + VDISPATCH_DECONVOLUTION( + utils::one_of(desc()->diff_src_desc.data_type, + data_type::bf16, data_type::f16, data_type::f32, + data_type::f64), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_DECONVOLUTION( + desc()->alg_kind == alg_kind::deconvolution_direct, + VERBOSE_BAD_ALGORITHM); + VDISPATCH_DECONVOLUTION( + attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + + VDISPATCH_DECONVOLUTION_SC( + init_convolution(engine), "init_convolution()"); + if (weights_md_.format_kind == format_kind::any) + VDISPATCH_DECONVOLUTION_SC( + weights_axes_permutation(&weights_md_, + conv_pd_->weights_md(), with_groups()), + "weights_axes_permutation()"); + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *conv_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + + init_name(); + init_scratchpad(); + + return status::success; + } + + std::shared_ptr conv_pd_; + + private: + std::string name_ = "conv:any"; + + void init_name() { + name_.append("+"); + name_.append(conv_pd_->name()); + } + + void init_scratchpad() { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_nested, + conv_pd_->scratchpad_registry()); + } + }; + + status_t init(impl::engine_t *engine) override { + return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); + } + + status_t execute(const exec_ctx_t &ctx) const override { + using namespace memory_tracking::names; + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); + conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); + conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD); + exec_ctx_t conv_ctx(ctx, std::move(conv_args)); + + nested_scratchpad_t ns(ctx, key_nested, conv_p_); + conv_ctx.set_scratchpad_grantor(ns.grantor()); + // Executing the convolution kernel + return conv_p_->execute(conv_ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::shared_ptr conv_p_; +}; + +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/gpu_deconvolution_list.cpp b/src/gpu/gpu_deconvolution_list.cpp index 55aeaff4b23..20ebdd37c3c 100644 --- a/src/gpu/gpu_deconvolution_list.cpp +++ b/src/gpu/gpu_deconvolution_list.cpp @@ -34,6 +34,8 @@ #include "gpu/generic/sycl/ref_deconvolution.hpp" #endif +#include "gpu/generic/convolution_deconvolution.hpp" + namespace dnnl { namespace impl { namespace gpu { @@ -45,18 +47,18 @@ using namespace dnnl::impl::prop_kind; const std::map> impl_list_map REG_DECONV_P({ {{forward}, { - GPU_INSTANCE_INTEL(intel::ocl::convolution_deconvolution_fwd_t) GPU_INSTANCE_NVIDIA(nvidia::cudnn_deconvolution_fwd_t) GPU_INSTANCE_AMD(amd::miopen_deconvolution_fwd_t) + GPU_INSTANCE_GENERIC(generic::convolution_deconvolution_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ - GPU_INSTANCE_INTEL(intel::ocl::convolution_deconvolution_bwd_data_t) GPU_INSTANCE_INTEL(intel::ocl::convolution_deconvolution_bwd_weights_t) GPU_INSTANCE_NVIDIA(nvidia::cudnn_deconvolution_bwd_data_t) GPU_INSTANCE_NVIDIA(nvidia::cudnn_deconvolution_bwd_weights_t) GPU_INSTANCE_AMD(amd::miopen_deconvolution_bwd_data_t) GPU_INSTANCE_AMD(amd::miopen_deconvolution_bwd_weights_t) + GPU_INSTANCE_GENERIC(generic::convolution_deconvolution_bwd_data_t) GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_deconvolution_bwd_weights_t) nullptr, })}, diff --git a/src/gpu/intel/ocl/convolution_deconvolution.hpp b/src/gpu/intel/ocl/convolution_deconvolution.hpp index 4d157346f5f..de6f6a6edf9 100644 --- a/src/gpu/intel/ocl/convolution_deconvolution.hpp +++ b/src/gpu/intel/ocl/convolution_deconvolution.hpp @@ -83,297 +83,6 @@ static status_t conv_descr_create( dd->strides, dd->dilates, dd->padding[0], dd->padding[1]); } -struct convolution_deconvolution_fwd_t : public gpu_primitive_t { - using gpu_primitive_t::gpu_primitive_t; - struct pd_t : public gpu_deconvolution_fwd_pd_t { - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} - - pd_t(const pd_t &other) = default; - - ~pd_t() = default; - - DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t); - status_t init_convolution(impl::engine_t *engine) { - convolution_desc_t cd; - CHECK(conv_descr_create(desc(), &cd)); - primitive_attr_t conv_attr(*attr()); - if (!conv_attr.is_initialized()) return status::out_of_memory; - primitive_desc_iterator_t it( - engine, (op_desc_t *)&cd, &conv_attr, nullptr); - if (!it.is_initialized()) return status::out_of_memory; - conv_pd_ = *(++it); - - return (conv_pd_) ? status::success : status::unimplemented; - } - - status_t init(impl::engine_t *engine) { - using namespace format_tag; - using sm = primitive_attr_t::skip_mask_t; - - const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime - | sm::scales_runtime; - - VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND); - VDISPATCH_DECONVOLUTION( - desc()->alg_kind == alg_kind::deconvolution_direct, - VERBOSE_BAD_ALGORITHM); - VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask), - VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_DECONVOLUTION( - post_ops_with_binary_ok( - attr(), desc()->dst_desc.data_type, ndims()), - VERBOSE_UNSUPPORTED_POSTOP); - VDISPATCH_DECONVOLUTION( - (utils::everyone_is(data_type::f32, - desc()->src_desc.data_type, - desc()->weights_desc.data_type, - desc()->dst_desc.data_type) - || (utils::everyone_is(data_type::f64, - desc()->src_desc.data_type, - desc()->weights_desc.data_type, - desc()->dst_desc.data_type)) - || ((utils::everyone_is(data_type::f16, - desc()->src_desc.data_type, - desc()->weights_desc.data_type) - || utils::everyone_is(data_type::f32, - desc()->src_desc.data_type, - desc()->weights_desc.data_type) - || utils::everyone_is(data_type::bf16, - desc()->src_desc.data_type, - desc()->weights_desc.data_type)) - && utils::one_of(desc()->dst_desc.data_type, - data_type::f16, data_type::u8, - data_type::s8)) - || (utils::everyone_is(data_type::bf16, - desc()->src_desc.data_type, - desc()->weights_desc.data_type) - && utils::one_of(desc()->dst_desc.data_type, - data_type::f32, data_type::bf16)) - || (utils::everyone_is(data_type::f16, - desc()->src_desc.data_type, - desc()->weights_desc.data_type) - && utils::one_of(desc()->dst_desc.data_type, - data_type::f32, data_type::f16)) - || (desc()->weights_desc.data_type == data_type::s8 - && utils::one_of(desc()->src_desc.data_type, - data_type::u8, data_type::s8) - && desc()->dst_desc.data_type - != data_type::f64)), - VERBOSE_UNSUPPORTED_DT); - - VDISPATCH_DECONVOLUTION_SC( - init_convolution(engine), "init_convolution()"); - if (weights_md_.format_kind == format_kind::any) { - VDISPATCH_DECONVOLUTION_SC( - weights_axes_permutation(&weights_md_, - conv_pd_->weights_md(), with_groups()), - "weights_axes_permutation()"); - } - if (src_md_.format_kind == format_kind::any) - src_md_ = *conv_pd_->diff_dst_md(); - if (dst_md_.format_kind == format_kind::any) - dst_md_ = *conv_pd_->diff_src_md(); - if (bias_md_.format_kind == format_kind::any) { - VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x), - VERBOSE_UNSUPPORTED_TAG); - } - init_name(); - init_scratchpad(); - VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)), - VERBOSE_UNSUPPORTED_ATTR); - - return status::success; - } - - std::shared_ptr conv_pd_; - - private: - std::string name_ = "conv:any"; - - void init_name() { - name_.append("+"); - name_.append(conv_pd_->name()); - } - - void init_scratchpad() { - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(memory_tracking::names::key_nested, - conv_pd_->scratchpad_registry()); - } - }; - - status_t init(impl::engine_t *engine) override { - return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); - } - - status_t execute(const exec_ctx_t &ctx) const override { - using namespace memory_tracking::names; - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC); - conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); - conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST); - if (pd()->with_bias()) - conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS); - - for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) { - if (pd()->attr()->post_ops_.entry_[idx].is_binary()) { - conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1] - = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) - | DNNL_ARG_SRC_1); - } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) { - conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) - | DNNL_ARG_WEIGHTS] - = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) - | DNNL_ARG_WEIGHTS); - } - } - const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC; - const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST; - if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src); - if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst); - - for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - int key = DNNL_ARG_ATTR_SCALES | arg; - if (args.find(key) != args.end()) conv_args[key] = args.at(key); - } - - exec_ctx_t conv_ctx(ctx, std::move(conv_args)); - - nested_scratchpad_t ns(ctx, key_nested, conv_p_); - conv_ctx.set_scratchpad_grantor(ns.grantor()); - // Executing the convolution kernel - return conv_p_->execute(conv_ctx); - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::shared_ptr conv_p_; -}; - -struct convolution_deconvolution_bwd_data_t : public gpu_primitive_t { - using gpu_primitive_t::gpu_primitive_t; - struct pd_t : public gpu_deconvolution_bwd_data_pd_t { - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) {} - - pd_t(const pd_t &other) = default; - - ~pd_t() = default; - - DECLARE_COMMON_PD_T( - name_.c_str(), convolution_deconvolution_bwd_data_t); - - status_t init_convolution(impl::engine_t *engine) { - convolution_desc_t cd; - CHECK(conv_descr_create(desc(), &cd)); - primitive_attr_t conv_attr(*attr()); - if (!conv_attr.is_initialized()) return status::out_of_memory; - primitive_desc_iterator_t it( - engine, (op_desc_t *)&cd, &conv_attr, nullptr); - if (!it.is_initialized()) return status::out_of_memory; - conv_pd_ = *(++it); - return (conv_pd_) ? status::success : status::unimplemented; - } - - status_t init(impl::engine_t *engine) { - VDISPATCH_DECONVOLUTION( - desc()->prop_kind == prop_kind::backward_data, - VERBOSE_BAD_PROPKIND); - - VDISPATCH_DECONVOLUTION( - (utils::everyone_is(data_type::f32, - desc()->diff_src_desc.data_type, - desc()->weights_desc.data_type, - desc()->diff_dst_desc.data_type) - || (utils::everyone_is(data_type::f64, - desc()->diff_src_desc.data_type, - desc()->weights_desc.data_type, - desc()->diff_dst_desc.data_type)) - || utils::everyone_is(data_type::f16, - desc()->weights_desc.data_type, - desc()->diff_dst_desc.data_type) - || utils::everyone_is(data_type::bf16, - desc()->weights_desc.data_type, - desc()->diff_dst_desc.data_type)), - VERBOSE_UNSUPPORTED_DT); - - VDISPATCH_DECONVOLUTION( - utils::one_of(desc()->diff_src_desc.data_type, - data_type::bf16, data_type::f16, data_type::f32, - data_type::f64), - VERBOSE_UNSUPPORTED_DT); - VDISPATCH_DECONVOLUTION( - desc()->alg_kind == alg_kind::deconvolution_direct, - VERBOSE_BAD_ALGORITHM); - VDISPATCH_DECONVOLUTION( - attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); - - VDISPATCH_DECONVOLUTION_SC( - init_convolution(engine), "init_convolution()"); - if (weights_md_.format_kind == format_kind::any) - VDISPATCH_DECONVOLUTION_SC( - weights_axes_permutation(&weights_md_, - conv_pd_->weights_md(), with_groups()), - "weights_axes_permutation()"); - if (diff_src_md_.format_kind == format_kind::any) - diff_src_md_ = *conv_pd_->dst_md(); - if (diff_dst_md_.format_kind == format_kind::any) - diff_dst_md_ = *conv_pd_->src_md(); - - init_name(); - init_scratchpad(); - - return status::success; - } - - std::shared_ptr conv_pd_; - - private: - std::string name_ = "conv:any"; - - void init_name() { - name_.append("+"); - name_.append(conv_pd_->name()); - } - - void init_scratchpad() { - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(memory_tracking::names::key_nested, - conv_pd_->scratchpad_registry()); - } - }; - - status_t init(impl::engine_t *engine) override { - return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); - } - - status_t execute(const exec_ctx_t &ctx) const override { - using namespace memory_tracking::names; - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); - conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); - conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD); - exec_ctx_t conv_ctx(ctx, std::move(conv_args)); - - nested_scratchpad_t ns(ctx, key_nested, conv_p_); - conv_ctx.set_scratchpad_grantor(ns.grantor()); - // Executing the convolution kernel - return conv_p_->execute(conv_ctx); - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::shared_ptr conv_p_; -}; - struct convolution_deconvolution_bwd_weights_t : public gpu_primitive_t { using gpu_primitive_t::gpu_primitive_t; struct pd_t : public gpu_deconvolution_bwd_weights_pd_t { From 9ba247d9d0497880c0270bbbf3f4bde9175a1f4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 2 Oct 2024 16:36:23 +0200 Subject: [PATCH 2/2] generic: deconvolution: bugfix includes (#2140) --- src/gpu/generic/convolution_deconvolution.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gpu/generic/convolution_deconvolution.hpp b/src/gpu/generic/convolution_deconvolution.hpp index 283f1d4fa2d..ee3ad5e353b 100644 --- a/src/gpu/generic/convolution_deconvolution.hpp +++ b/src/gpu/generic/convolution_deconvolution.hpp @@ -19,9 +19,11 @@ #include "common/c_types_map.hpp" #include "common/primitive.hpp" +#include "common/primitive_desc_iterator.hpp" #include "common/type_helpers.hpp" #include "common/utils.hpp" #include "gpu/gpu_deconvolution_pd.hpp" +#include "gpu/gpu_primitive.hpp" namespace dnnl { namespace impl {