gpu: generic: deconvolution: move implementation from ocl folder #2144

Open · wants to merge 2 commits into base: rls-v3.6
376 changes: 376 additions & 0 deletions src/gpu/generic/convolution_deconvolution.hpp
@@ -0,0 +1,376 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP
#define GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "gpu/gpu_deconvolution_pd.hpp"
#include "gpu/gpu_primitive.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {

static status_t weights_axes_permutation(
memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
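    // Deconvolution weights swap the input- and output-channel axes
    // relative to the weights of the equivalent convolution: build an
    // identity permutation, then swap axes 0 and 1 (1 and 2 when a
    // groups dimension leads).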
int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
perm[d] = d;
nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);

return memory_desc_permute_axes(*o_md, *i_md, perm);
}

static status_t conv_descr_create(
const deconvolution_desc_t *dd, convolution_desc_t *cd) {
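    // Map the deconvolution onto the equivalent direct convolution:
    // deconv forward becomes conv backward_data (src and dst exchanged),
    // deconv backward_data becomes conv forward, and backward_weights
    // keeps its propagation kind with src and diff_dst exchanged.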
using namespace prop_kind;
alg_kind_t alg_kind = alg_kind::convolution_direct;

const memory_desc_t *src_md, *dst_md, *d_weights_d;
prop_kind_t prop_kind;

switch (dd->prop_kind) {
case forward:
case forward_inference:
prop_kind = backward_data;
src_md = &dd->dst_desc;
dst_md = &dd->src_desc;
d_weights_d = &dd->weights_desc;
break;
case backward_data:
prop_kind = forward_training;
src_md = &dd->diff_dst_desc;
dst_md = &dd->diff_src_desc;
d_weights_d = &dd->weights_desc;
break;
case backward_weights:
prop_kind = dd->prop_kind;
src_md = &dd->diff_dst_desc;
dst_md = &dd->src_desc;
d_weights_d = &dd->diff_weights_desc;
break;
default: assert(!"unknown prop kind"); return status::invalid_arguments;
}

// Create weights desc for convolution
memory_desc_t c_weights_d;
const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));

return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md,
dd->strides, dd->dilates, dd->padding[0], dd->padding[1]);
}

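// Forward deconvolution implemented as a thin wrapper: the pd translates
// the operation into a convolution backward_data, and the primitive
// forwards execution (including post-op, scale and zero-point arguments)
// to that nested convolution.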
struct convolution_deconvolution_fwd_t : public gpu::primitive_t {
using gpu::primitive_t::primitive_t;
struct pd_t : public gpu_deconvolution_fwd_pd_t {
pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

pd_t(const pd_t &other) = default;

~pd_t() = default;

DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t);
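        // Iterate over the convolution implementation list with the
        // translated descriptor; the first matching primitive descriptor
        // becomes the nested convolution.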
status_t init_convolution(impl::engine_t *engine) {
convolution_desc_t cd;
CHECK(conv_descr_create(desc(), &cd));
primitive_attr_t conv_attr(*attr());
if (!conv_attr.is_initialized()) return status::out_of_memory;
primitive_desc_iterator_t it(
engine, (op_desc_t *)&cd, &conv_attr, nullptr);
if (!it.is_initialized()) return status::out_of_memory;
conv_pd_ = *(++it);

return (conv_pd_) ? status::success : status::unimplemented;
}

status_t init(impl::engine_t *engine) {
using namespace format_tag;
using sm = primitive_attr_t::skip_mask_t;

const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime
| sm::scales_runtime;

VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND);
VDISPATCH_DECONVOLUTION(
desc()->alg_kind == alg_kind::deconvolution_direct,
VERBOSE_BAD_ALGORITHM);
VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask),
VERBOSE_UNSUPPORTED_ATTR);
VDISPATCH_DECONVOLUTION(
(utils::everyone_is(data_type::f32,
desc()->src_desc.data_type,
desc()->weights_desc.data_type,
desc()->dst_desc.data_type)
|| (utils::everyone_is(data_type::f64,
desc()->src_desc.data_type,
desc()->weights_desc.data_type,
desc()->dst_desc.data_type))
|| ((utils::everyone_is(data_type::f16,
desc()->src_desc.data_type,
desc()->weights_desc.data_type)
|| utils::everyone_is(data_type::f32,
desc()->src_desc.data_type,
desc()->weights_desc.data_type)
|| utils::everyone_is(data_type::bf16,
desc()->src_desc.data_type,
desc()->weights_desc.data_type))
&& utils::one_of(desc()->dst_desc.data_type,
data_type::f16, data_type::u8,
data_type::s8))
|| (utils::everyone_is(data_type::bf16,
desc()->src_desc.data_type,
desc()->weights_desc.data_type)
&& utils::one_of(desc()->dst_desc.data_type,
data_type::f32, data_type::bf16))
|| (utils::everyone_is(data_type::f16,
desc()->src_desc.data_type,
desc()->weights_desc.data_type)
&& utils::one_of(desc()->dst_desc.data_type,
data_type::f32, data_type::f16))
|| (desc()->weights_desc.data_type == data_type::s8
&& utils::one_of(desc()->src_desc.data_type,
data_type::u8, data_type::s8)
&& desc()->dst_desc.data_type
!= data_type::f64)),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_DECONVOLUTION_SC(
init_convolution(engine), "init_convolution()");
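            // Inherit "any" memory formats from the nested convolution;
            // it runs backward_data, so its diff_dst/diff_src correspond
            // to this deconvolution's src/dst.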
if (weights_md_.format_kind == format_kind::any) {
VDISPATCH_DECONVOLUTION_SC(
weights_axes_permutation(&weights_md_,
conv_pd_->weights_md(), with_groups()),
"weights_axes_permutation()");
}
if (src_md_.format_kind == format_kind::any)
src_md_ = *conv_pd_->diff_dst_md();
if (dst_md_.format_kind == format_kind::any)
dst_md_ = *conv_pd_->diff_src_md();
if (bias_md_.format_kind == format_kind::any) {
VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x),
VERBOSE_UNSUPPORTED_TAG);
}
init_name();
init_scratchpad();
VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)),
VERBOSE_UNSUPPORTED_ATTR);

return status::success;
}

std::shared_ptr<primitive_desc_t> conv_pd_;

private:
std::string name_ = "conv:any";

void init_name() {
name_.append("+");
name_.append(conv_pd_->name());
}

void init_scratchpad() {
auto scratchpad = scratchpad_registry().registrar();
scratchpad.book(memory_tracking::names::key_nested,
conv_pd_->scratchpad_registry());
}
};

status_t init(impl::engine_t *engine) override {
return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
}

status_t execute(const exec_ctx_t &ctx) const override {
using namespace memory_tracking::names;
const auto &args = ctx.args();
exec_args_t conv_args;
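        // Forward deconvolution executes as convolution backward_data,
        // so src/dst here map to the convolution's diff_dst/diff_src.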
conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
if (pd()->with_bias())
conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
= args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
| DNNL_ARG_SRC_1);
} else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) {
conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
| DNNL_ARG_WEIGHTS]
= args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
| DNNL_ARG_WEIGHTS);
}
}
const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src);
if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst);

for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
int key = DNNL_ARG_ATTR_SCALES | arg;
if (args.find(key) != args.end()) conv_args[key] = args.at(key);
}

exec_ctx_t conv_ctx(ctx, std::move(conv_args));

nested_scratchpad_t ns(ctx, key_nested, conv_p_);
conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Execute the nested convolution primitive
return conv_p_->execute(conv_ctx);
}

private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
std::shared_ptr<impl::primitive_t> conv_p_;
};

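// Backward-data deconvolution implemented as a forward convolution over
// the axis-swapped weights.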
struct convolution_deconvolution_bwd_data_t : public gpu::primitive_t {
using gpu::primitive_t::primitive_t;
struct pd_t : public gpu_deconvolution_bwd_data_pd_t {
pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
, conv_pd_(nullptr) {}

pd_t(const pd_t &other) = default;

~pd_t() = default;

DECLARE_COMMON_PD_T(
name_.c_str(), convolution_deconvolution_bwd_data_t);

status_t init_convolution(impl::engine_t *engine) {
convolution_desc_t cd;
CHECK(conv_descr_create(desc(), &cd));
primitive_attr_t conv_attr(*attr());
if (!conv_attr.is_initialized()) return status::out_of_memory;
primitive_desc_iterator_t it(
engine, (op_desc_t *)&cd, &conv_attr, nullptr);
if (!it.is_initialized()) return status::out_of_memory;
conv_pd_ = *(++it);
return (conv_pd_) ? status::success : status::unimplemented;
}

status_t init(impl::engine_t *engine) {
VDISPATCH_DECONVOLUTION(
desc()->prop_kind == prop_kind::backward_data,
VERBOSE_BAD_PROPKIND);

VDISPATCH_DECONVOLUTION(
(utils::everyone_is(data_type::f32,
desc()->diff_src_desc.data_type,
desc()->weights_desc.data_type,
desc()->diff_dst_desc.data_type)
|| (utils::everyone_is(data_type::f64,
desc()->diff_src_desc.data_type,
desc()->weights_desc.data_type,
desc()->diff_dst_desc.data_type))
|| utils::everyone_is(data_type::f16,
desc()->weights_desc.data_type,
desc()->diff_dst_desc.data_type)
|| utils::everyone_is(data_type::bf16,
desc()->weights_desc.data_type,
desc()->diff_dst_desc.data_type)),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_DECONVOLUTION(
utils::one_of(desc()->diff_src_desc.data_type,
data_type::bf16, data_type::f16, data_type::f32,
data_type::f64),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_DECONVOLUTION(
desc()->alg_kind == alg_kind::deconvolution_direct,
VERBOSE_BAD_ALGORITHM);
VDISPATCH_DECONVOLUTION(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);

VDISPATCH_DECONVOLUTION_SC(
init_convolution(engine), "init_convolution()");
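            // The nested convolution runs forward here, so its src/dst
            // correspond to this deconvolution's diff_dst/diff_src.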
if (weights_md_.format_kind == format_kind::any)
VDISPATCH_DECONVOLUTION_SC(
weights_axes_permutation(&weights_md_,
conv_pd_->weights_md(), with_groups()),
"weights_axes_permutation()");
if (diff_src_md_.format_kind == format_kind::any)
diff_src_md_ = *conv_pd_->dst_md();
if (diff_dst_md_.format_kind == format_kind::any)
diff_dst_md_ = *conv_pd_->src_md();

init_name();
init_scratchpad();

return status::success;
}

std::shared_ptr<primitive_desc_t> conv_pd_;

private:
std::string name_ = "conv:any";

void init_name() {
name_.append("+");
name_.append(conv_pd_->name());
}

void init_scratchpad() {
auto scratchpad = scratchpad_registry().registrar();
scratchpad.book(memory_tracking::names::key_nested,
conv_pd_->scratchpad_registry());
}
};

status_t init(impl::engine_t *engine) override {
return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
}

status_t execute(const exec_ctx_t &ctx) const override {
using namespace memory_tracking::names;
const auto &args = ctx.args();
exec_args_t conv_args;
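        // Backward-data deconvolution executes as convolution forward,
        // so diff_dst/diff_src here map to the convolution's src/dst.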
conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
if (!types::is_zero_md(pd()->scratchpad_md()))
conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD);
exec_ctx_t conv_ctx(ctx, std::move(conv_args));

nested_scratchpad_t ns(ctx, key_nested, conv_p_);
conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Execute the nested convolution primitive
return conv_p_->execute(conv_ctx);
}

private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
std::shared_ptr<impl::primitive_t> conv_p_;
};

} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
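
For context, here is a minimal sketch of how this wrapper is reached through the public oneDNN C++ API: a direct deconvolution created on a GPU engine can dispatch to the implementation above, which answers the query with an equivalent convolution. The shapes are illustrative and the dnnl.hpp entry points are assumed from the v3.x API; this snippet is not part of the patch.

#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::gpu, 0); // assumes a GPU-enabled build
    stream strm(eng);

    // f32 everywhere so the data-type dispatch in pd_t::init() passes;
    // format_tag::any lets the pd inherit formats from the nested conv.
    memory::desc src_md({2, 16, 30, 30}, memory::data_type::f32,
            memory::format_tag::any);
    memory::desc wei_md({32, 16, 3, 3}, memory::data_type::f32,
            memory::format_tag::any);
    memory::desc dst_md({2, 32, 32, 32}, memory::data_type::f32,
            memory::format_tag::any);

    // With stride 1, kernel 3 and no padding: (30 - 1) * 1 + 3 = 32.
    auto pd = deconvolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::deconvolution_direct,
            src_md, wei_md, dst_md,
            /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});

    auto deconv = deconvolution_forward(pd);
    // Memory objects would be created from pd.src_desc() etc., and
    // deconv.execute(strm, {...}) would then run the nested convolution
    // through the execute() override above.
    (void)deconv;
    return 0;
}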