diff --git a/src/gpu/nvidia/cudnn_matmul.cpp b/src/gpu/nvidia/cudnn_matmul.cpp index 35923ef62fe..6665a83db95 100644 --- a/src/gpu/nvidia/cudnn_matmul.cpp +++ b/src/gpu/nvidia/cudnn_matmul.cpp @@ -41,31 +41,17 @@ status_t cudnn_matmul_t::execute(const exec_ctx_t &ctx) const { const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1)); - status_t status; - size_t bias_scratchpad_size - = 0; // To avoid extra allocation in an executor. - - bool has_runtime_args = matmul_impl_->has_runtime_params(); - if (has_runtime_args) { - // Initialise all runtime parameters - status = matmul_impl_->init_parameters(src_d, weights_d, dst_d, bias_d); - if (status != status::success) return status; - - bias_scratchpad_size = matmul_impl_->bias_scratch_size(); - } - nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); - status = executor_->execute( - ctx, ctx.stream()->engine(), matmul_impl_, bias_scratchpad_size); + status_t status = executor_->execute(ctx, ctx.stream()->engine(), + matmul_impl_, pd()->params_, src_d, weights_d, dst_d, bias_d); - if (has_runtime_args) { + if (pd()->params_->has_runtime_params_) { auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events; for (auto e : evts) { e.wait(); } - matmul_impl_->cleanup(); } return status; } @@ -76,32 +62,6 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md()); const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); - const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1)); - - // To avoid extra allocation in an executor. - size_t algo_scratchpad_size = 0; - size_t bias_scratchpad_size = 0; - size_t block_a_scratchpad_size = 0; - size_t block_b_scratchpad_size = 0; - size_t block_c_scratchpad_size = 0; - size_t src_scale_scratchpad_size = 0; - size_t wei_scale_scratchpad_size = 0; - - bool has_runtime_args = matmul_impl_->has_runtime_params(); - if (has_runtime_args) { - // Initialise all runtime parameters - auto engine = ctx.stream()->engine(); - CHECK(matmul_impl_->init_parameters( - src_d, weights_d, dst_d, bias_d, engine)); - - algo_scratchpad_size = matmul_impl_->algo_scratch_size(); - bias_scratchpad_size = matmul_impl_->bias_scratch_size(); - block_a_scratchpad_size = matmul_impl_->block_a_scratch_size(); - block_b_scratchpad_size = matmul_impl_->block_b_scratch_size(); - block_c_scratchpad_size = matmul_impl_->block_c_scratch_size(); - src_scale_scratchpad_size = matmul_impl_->src_scale_size(); - wei_scale_scratchpad_size = matmul_impl_->wei_scale_size(); - } nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); @@ -117,8 +77,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { != ctx.args().end(); if (has_src_scales - && (matmul_impl_->multi_src_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_src_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { // src scale sycl binary exec_args_t src_scale_binary_args; src_scale_binary_args[DNNL_ARG_SRC_0] @@ -141,8 +101,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { CHECK(src_scale_binary_->execute(binary_ctx)); } if (has_wei_scales - && (matmul_impl_->multi_wei_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_wei_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { // wei scale sycl binary exec_args_t 
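// Editorial sketch of the host-side wait that both execute() overloads above
// keep for the runtime-dimension case: once the interop task is submitted,
// the host blocks on the accumulated SYCL events so that any per-execution
// resources can be released safely. The types below are simplified stand-ins
// for the oneDNN stream/params classes, shown only to illustrate the control
// flow, not the actual implementation.
#include <memory>
#include <vector>
#include <sycl/sycl.hpp>

struct exec_params {
    bool has_runtime_params_ = false;
};

inline void wait_if_runtime(const std::shared_ptr<exec_params> &params,
        std::vector<::sycl::event> &deps) {
    if (!params->has_runtime_params_) return; // cached path: nothing to do
    for (auto &e : deps)
        e.wait(); // per-execution path: drain the dependency list
}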
wei_scale_binary_args; wei_scale_binary_args[DNNL_ARG_SRC_0] @@ -167,11 +127,9 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { } CHECK(executor_->execute(ctx, ctx.stream()->engine(), matmul_impl_, - algo_scratchpad_size, bias_scratchpad_size, block_a_scratchpad_size, - block_b_scratchpad_size, block_c_scratchpad_size, - src_scale_scratchpad_size, wei_scale_scratchpad_size)); + pd()->params_, src_d, weights_d, dst_d)); - if (matmul_impl_->with_bias()) { + if (pd()->params_->with_bias_) { // bias sycl binary exec_args_t binary_args; std::unique_ptr scratch_mem; @@ -198,8 +156,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { } if (has_dst_scales - && (matmul_impl_->multi_dst_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_dst_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { // dst scale sycl binary exec_args_t dst_scale_binary_args; dst_scale_binary_args[DNNL_ARG_SRC_0] @@ -213,13 +171,11 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const { CHECK(dst_scale_binary_->execute(binary_ctx)); } - if (has_runtime_args) { + if (pd()->params_->has_runtime_params_) { auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events; for (auto e : evts) { e.wait(); } - - matmul_impl_->rt_cleanup(); } return status::success; diff --git a/src/gpu/nvidia/cudnn_matmul.hpp b/src/gpu/nvidia/cudnn_matmul.hpp index cfd919f9031..f1674c860fe 100644 --- a/src/gpu/nvidia/cudnn_matmul.hpp +++ b/src/gpu/nvidia/cudnn_matmul.hpp @@ -20,9 +20,12 @@ #include "gpu/gpu_matmul_pd.hpp" -#include "gpu/nvidia/cudnn_matmul_base.hpp" +#include "common/primitive.hpp" +#include "common/primitive_desc_iterator.hpp" +#include "gpu/gpu_primitive.hpp" #include "gpu/nvidia/cudnn_matmul_executor.hpp" #include "gpu/nvidia/cudnn_matmul_impl.hpp" +#include "gpu/nvidia/cudnn_matmul_lt_impl.hpp" #include "gpu/nvidia/sycl_cuda_utils.hpp" namespace dnnl { @@ -30,8 +33,8 @@ namespace impl { namespace gpu { namespace nvidia { -struct cudnn_matmul_t : cudnn_matmul_base_t { - using cudnn_matmul_base_t::cudnn_matmul_base_t; +struct cudnn_matmul_t : public gpu::primitive_t { + using primitive_t::primitive_t; struct pd_t : public gpu_matmul_pd_t { using gpu_matmul_pd_t::gpu_matmul_pd_t; @@ -79,12 +82,15 @@ struct cudnn_matmul_t : cudnn_matmul_base_t { if (src_md()->ndims > 3) return status::unimplemented; - return status::success; - } + params_ = std::make_shared(); + CHECK(params_->init(src_md(), weights_md(), dst_md(), weights_md(1), + attr(), batched(), with_bias())); - size_t scratchpad_size(const memory_desc_t *dst_md) const { - const auto dst_nelems = memory_desc_wrapper(dst_md).nelems(true); - return dst_nelems * sizeof(float); + if (!params_->has_runtime_params_) { + auto scratchpad = scratchpad_registry().registrar(); + params_->init_scratchpad(dst_md(), scratchpad); + } + return status::success; } bool scales_ok() const { @@ -116,21 +122,22 @@ struct cudnn_matmul_t : cudnn_matmul_base_t { } return true; } + + std::shared_ptr params_; }; status_t init(impl::engine_t *engine) override { matmul_impl_.reset(new cudnn_matmul_impl_t()); - auto status = matmul_impl_->init((matmul_pd_t *)pd()); - if (status != status::success) return status; - bool has_runtime_args = matmul_impl_->has_runtime_params(); + bool has_runtime_args = pd()->params_->has_runtime_params_; if (has_runtime_args) { executor_.reset(new cudnn_matmul_runtime_args_exec_t); } else { executor_.reset(new cudnn_matmul_exec_t); + matmul_impl_->set_non_runtime_params(pd()->params_); } - 
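// A minimal illustration of the new pd-time flow above: descriptor-derived
// parameters are computed once in pd_t::init(), and the reorder scratchpad is
// booked only when every dimension is known at primitive-descriptor creation
// time. The registrar/params types here are hypothetical placeholders, not
// the oneDNN memory_tracking API.
#include <cstddef>

struct toy_registrar {
    std::size_t booked = 0;
    void book(std::size_t bytes, std::size_t elems) { booked += bytes * elems; }
};

struct toy_pd_params {
    bool has_runtime_params_ = false;
    std::size_t dst_nelems = 0;
    std::size_t reorder_scratch_bytes() const {
        return dst_nelems * sizeof(float); // dst kept in accumulation type
    }
};

inline void init_pd_scratchpad(const toy_pd_params &params, toy_registrar &scratchpad) {
    if (!params.has_runtime_params_ && params.reorder_scratch_bytes() > 0)
        scratchpad.book(params.reorder_scratch_bytes(), 1);
    // With runtime dims the scratch is allocated per execution instead.
}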
return status; + return status::success; } status_t execute(const exec_ctx_t &ctx) const override; diff --git a/src/gpu/nvidia/cudnn_matmul_base.hpp b/src/gpu/nvidia/cudnn_matmul_base.hpp deleted file mode 100644 index fdf1d2caaf6..00000000000 --- a/src/gpu/nvidia/cudnn_matmul_base.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* -* Copyright 2024 Intel Corporation -* Copyright 2024 Codeplay Software Limited -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef GPU_NVIDIA_CUDNN_MATMUL_BASE_HPP -#define GPU_NVIDIA_CUDNN_MATMUL_BASE_HPP - -#include "gpu/gpu_matmul_pd.hpp" - -#include "common/primitive.hpp" -#include "common/primitive_desc_iterator.hpp" -#include "gpu/gpu_primitive.hpp" -#include "gpu/nvidia/cudnn_matmul_executor.hpp" -#include "gpu/nvidia/cudnn_matmul_impl.hpp" -#include "gpu/nvidia/cudnn_matmul_lt_impl.hpp" -#include "gpu/nvidia/sycl_cuda_utils.hpp" - -namespace dnnl { -namespace impl { -namespace gpu { -namespace nvidia { - -struct cudnn_matmul_base_t : public gpu::primitive_t { - using primitive_t::primitive_t; - - struct pd_t : public gpu_matmul_pd_t { - using gpu_matmul_pd_t::gpu_matmul_pd_t; - virtual status_t init(impl::engine_t *engine) = 0; - }; -}; - -} // namespace nvidia -} // namespace gpu -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/gpu/nvidia/cudnn_matmul_base_impl.hpp b/src/gpu/nvidia/cudnn_matmul_base_impl.hpp index bb878192e2d..ab5b7a24c6f 100644 --- a/src/gpu/nvidia/cudnn_matmul_base_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_base_impl.hpp @@ -30,124 +30,38 @@ namespace impl { namespace gpu { namespace nvidia { -struct cudnn_matmul_base_impl_t { - virtual status_t init_gemm_parameters(const memory_desc_wrapper src_d, - const memory_desc_wrapper weights_d, - const memory_desc_wrapper dst_d) - = 0; - virtual void init_scratchpad(matmul_pd_t *pd) = 0; - virtual void cleanup() = 0; +struct cublas_base_params { - bool isbatched() { return isbatched_; } - bool with_bias() { return with_separate_bias_; } - bool bias_dt_mismatch() { return bias_dt_mismatch_; } - bool has_runtime_params() { return has_runtime_params_; } - - bool with_eltwise(int position, const matmul_pd_t *pd) { - return pd->attr()->post_ops_.contain(primitive_kind::eltwise, position); + bool with_eltwise(int position, const primitive_attr_t *attr) { + return attr->post_ops_.contain(primitive_kind::eltwise, position); } - bool with_sum(const matmul_pd_t *pd) { - return pd->attr()->post_ops_.contain(primitive_kind::sum, 0) - || pd->attr()->post_ops_.contain(primitive_kind::sum, 1); + bool with_sum(const primitive_attr_t *attr) { + return attr->post_ops_.contain(primitive_kind::sum, 0) + || attr->post_ops_.contain(primitive_kind::sum, 1); } // Returns scaling factor for post-ops=sum operation - float sum_scale(const matmul_pd_t *pd) { - int sum_idx_ = pd->attr()->post_ops_.find(primitive_kind::sum); - return 
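// The post-op helpers above now inspect the attribute's post-op list directly
// instead of going through matmul_pd_t. A compact stand-alone version of that
// lookup logic, using a toy post-op representation (not the oneDNN one):
#include <vector>

enum class kind_t { sum, eltwise };
struct post_op { kind_t kind; float scale; };
struct toy_attr { std::vector<post_op> post_ops; };

inline bool with_eltwise(int position, const toy_attr &attr) {
    return position < (int)attr.post_ops.size()
            && attr.post_ops[position].kind == kind_t::eltwise;
}
inline bool with_sum(const toy_attr &attr) {
    auto is_sum_at = [&](int pos) {
        return pos < (int)attr.post_ops.size()
                && attr.post_ops[pos].kind == kind_t::sum;
    };
    // Sum may be the first or the second entry, mirroring the check above.
    return is_sum_at(0) || is_sum_at(1);
}
inline float sum_scale(const toy_attr &attr) {
    for (const auto &po : attr.post_ops)
        if (po.kind == kind_t::sum) return po.scale;
    return 0.f;
}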
pd->attr()->post_ops_.entry_[sum_idx_].sum.scale; - } - - float eltwise_alpha(const matmul_pd_t *pd) { - int eltwise_idx_ = pd->attr()->post_ops_.find(primitive_kind::eltwise); - return with_eltwise(0, pd) || with_eltwise(1, pd) - ? pd->attr()->post_ops_.entry_[eltwise_idx_].eltwise.alpha - : 1.0f; - } - - float eltwise_beta(const matmul_pd_t *pd) { - int eltwise_idx_ = pd->attr()->post_ops_.find(primitive_kind::eltwise); - return with_eltwise(0, pd) || with_eltwise(1, pd) - ? pd->attr()->post_ops_.entry_[eltwise_idx_].eltwise.beta - : 0.0f; + float sum_scale(const primitive_attr_t *attr) { + int sum_idx_ = attr->post_ops_.find(primitive_kind::sum); + return attr->post_ops_.entry_[sum_idx_].sum.scale; } - alg_kind_t eltwise_algo(const matmul_pd_t *pd) { - int eltwise_idx_ = pd->attr()->post_ops_.find(primitive_kind::eltwise); - return with_eltwise(0, pd) || with_eltwise(1, pd) - ? pd->attr()->post_ops_.entry_[eltwise_idx_].eltwise.alg + alg_kind_t eltwise_algo(const primitive_attr_t *attr) { + int eltwise_idx_ = attr->post_ops_.find(primitive_kind::eltwise); + return with_eltwise(0, attr) || with_eltwise(1, attr) + ? attr->post_ops_.entry_[eltwise_idx_].eltwise.alg : dnnl_alg_kind_undef; } - int get_ld(const memory_desc_wrapper desc, cublasOperation_t trans) { - const int ndims = desc.ndims(); - const auto *strides = &desc.blocking_desc().strides[ndims - 2]; - const int ld = strides[trans == cublasOperation_t::CUBLAS_OP_N ? 0 : 1]; - return ld; - } - // creates operation descriptor based on the elemen-wise operation specified - status_t create_and_set_op_descriptor( - const matmul_pd_t *pd, cudnnActivationDescriptor_t &act_desc) { - CHECK(CUDNN_EXECUTE_FUNC_S(cudnnCreateActivationDescriptor, &act_desc)); - - cudnnActivationMode_t mode; - - switch (eltwise_algo(pd)) { - case alg_kind::eltwise_relu: - mode = cudnnActivationMode_t::CUDNN_ACTIVATION_RELU; - break; - case alg_kind::eltwise_tanh: - mode = cudnnActivationMode_t::CUDNN_ACTIVATION_TANH; - break; - case alg_kind::eltwise_elu: - mode = cudnnActivationMode_t::CUDNN_ACTIVATION_ELU; - break; - case alg_kind::eltwise_logistic: - mode = cudnnActivationMode_t::CUDNN_ACTIVATION_SIGMOID; - break; - default: return status::unimplemented; - } - - // NaNs by default are propagated in oneDNN, although the forward - // convolution routine does not support this. - auto propagate_nan = cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN; - - // For ReLU, a ceiling of 0 means no limit. - double ceiling = eltwise_alpha(pd); - - CHECK(CUDNN_EXECUTE_FUNC_S(cudnnSetActivationDescriptor, act_desc, mode, - propagate_nan, ceiling)); - - return status::success; - } - - void convert_dims_matmul( - const dnnl_dim_t *dims, int *new_dims, int n_dims) { - // Moving the dimensions because cudnnAddTensor doesn't work when - // bia_mask=1 - if (n_dims == 3) { return convert_dims(dims, new_dims, n_dims); } - new_dims[0] = 1; - for (size_t i = 0; i < n_dims; i++) { - new_dims[i + 1] = static_cast(dims[i]); - } - for (size_t i = n_dims; i < 4; i++) { - new_dims[i + 1] = 1; - } - } - int get_batch_stride(const memory_desc_wrapper desc) { auto dims = desc.dims(); auto strides = desc.blocking_desc().strides; return dims[0] == 1 ? 
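// Stand-alone sketch of the leading-dimension / batch-stride helpers used by
// the matmul parameters: the leading dimension is read from the last two
// strides depending on whether the matrix is passed transposed, and a batch
// of size 1 gets a zero batch stride so cuBLAS broadcasts it. Plain arrays
// replace the memory_desc_wrapper here.
#include <cstdint>

inline int64_t get_ld(const int64_t *strides, int ndims, bool transposed) {
    const int64_t *last2 = strides + (ndims - 2);
    return last2[transposed ? 1 : 0];
}

inline int64_t get_batch_stride(const int64_t *dims, const int64_t *strides) {
    return dims[0] == 1 ? 0 : strides[0];
}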
0 : strides[0]; } - virtual ~cudnn_matmul_base_impl_t() = default; - - size_t bias_scratch_size() { return reorder_scratch_size_; } + uint64_t M_, N_, K_; -protected: - int lda_, ldb_, ldc_; - long long int stride_a_, stride_b_, stride_c_; bool isbatched_ = false, with_separate_bias_ = false, bias_dt_mismatch_ = false, with_dst_scale_ = false; bool reorder_required_ = false, with_separate_eltwise_ = false; @@ -160,37 +74,13 @@ struct cudnn_matmul_base_impl_t { temp_mem_desc_ = nullptr; cudnnActivationDescriptor_t act_desc_ = nullptr; float post_op_sum_ = 0; - size_t algo_scratch_size_ = 0; size_t reorder_scratch_size_ = 0; - status_t handle_post_ops(cudnnHandle_t cudnn_handle, void *dst, void *bias, - void *reorder_scratch, float host_dst_scale) { - if (with_separate_bias_) { - // When bias is specified call cudnnAddTensor() - float bias_beta = 1; - auto scale = (with_separate_eltwise_ ? 1 : 1.0f / host_dst_scale); - CUDNN_EXECUTE_FUNC(cudnnAddTensor, cudnn_handle, &scale, - tensor_descs_[io::bias], bias, &bias_beta, temp_mem_desc_, - reorder_scratch); - } - if (with_separate_eltwise_) { - // Perform elementwise operation if specified - float alpha = 1.0f / host_dst_scale; - float beta = 0; - CUDNN_EXECUTE_FUNC(cudnnActivationForward, cudnn_handle, act_desc_, - &alpha, temp_mem_desc_, reorder_scratch, &beta, - temp_mem_desc_, reorder_scratch); - } - if (reorder_required_) { - // Reorder from scratchpad to destination if required - float reorder_alpha = 1; - CUDNN_EXECUTE_FUNC(cudnnTransformTensor, cudnn_handle, - &reorder_alpha, temp_mem_desc_, reorder_scratch, - &post_op_sum_, tensor_descs_[io::dst], dst); - } - - return status::success; - } + cublasOperation_t transA_; + cublasOperation_t transB_; + cublasOperation_t transC_; + cublasGemmAlgo_t gemm_algo_ + = cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP; }; } // namespace nvidia diff --git a/src/gpu/nvidia/cudnn_matmul_executor.hpp b/src/gpu/nvidia/cudnn_matmul_executor.hpp index a8e790d7193..f78cd853d0a 100644 --- a/src/gpu/nvidia/cudnn_matmul_executor.hpp +++ b/src/gpu/nvidia/cudnn_matmul_executor.hpp @@ -38,14 +38,17 @@ struct cudnn_matmul_base_exec_t { virtual status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t bias_scratch_size) + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d) = 0; protected: template <::sycl::access::mode bias_m, ::sycl::access::mode scratch_m> void interop_task(std::shared_ptr matmul_impl_, - impl::engine_t *engine, ::sycl::handler &cgh, - nvidia::stream_t *cuda_stream, + const std::shared_ptr params, impl::engine_t *engine, + ::sycl::handler &cgh, nvidia::stream_t *cuda_stream, xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read> arg_weights, xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read> arg_src, @@ -87,22 +90,22 @@ struct cudnn_matmul_base_exec_t { void *wei_scale = arg_wei_scale.get_native_pointer(ih); void *dst_scale = arg_dst_scale.get_native_pointer(ih); - matmul_impl_->execute(cublas_handle, cudnn_handle, weights, - src, dst, bias, reorder_scratch, src_scale, + matmul_impl_->execute(cublas_handle, cudnn_handle, params, + weights, src, dst, bias, reorder_scratch, src_scale, wei_scale, dst_scale); - free_runtime_scratch(matmul_impl_->has_runtime_params(), - cublas_handle, cuda_stream, bias_scratch_ptr); + if (params->has_runtime_params_) { + sync_device(); + free_runtime_scratch( + cublas_handle, 
cuda_stream, bias_scratch_ptr); + params->cleanup(); + } }); } - void free_runtime_scratch(bool has_runtime_params, - cublasHandle_t cublas_handle, nvidia::stream_t *cuda_stream, - uint8_t *bias_scratch_ptr) { - if (has_runtime_params && bias_scratch_ptr) { - cudaStream_t streamId; - cublasGetStream(cublas_handle, &streamId); - cudaStreamSynchronize(streamId); + void free_runtime_scratch(cublasHandle_t cublas_handle, + nvidia::stream_t *cuda_stream, uint8_t *bias_scratch_ptr) { + if (bias_scratch_ptr) { ::sycl::free(bias_scratch_ptr, cuda_stream->queue()); } } @@ -123,19 +126,27 @@ struct cudnn_matmul_exec_t final : cudnn_matmul_base_exec_t { status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t bias_scratch_size) override { + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d) override { nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); - return cuda_stream->interop_task([=, this](::sycl::handler &cgh) { + return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE]( + ::sycl::handler &cgh) { auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC); auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS); auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST); auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS); - auto arg_bias_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_dst_in_acc_dt); + auto arg_bias_scratch = params->reorder_scratch_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_dst_in_acc_dt) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); auto arg_src_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); @@ -144,7 +155,7 @@ struct cudnn_matmul_exec_t final : cudnn_matmul_base_exec_t { auto arg_dst_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - interop_task(matmul_impl_, engine, cgh, cuda_stream, arg_wt, + interop_task(matmul_impl_, params, engine, cgh, cuda_stream, arg_wt, arg_src, arg_dst, arg_bias, arg_bias_scratch, arg_src_scale, arg_wei_scale, arg_dst_scale, nullptr); }); @@ -157,27 +168,37 @@ struct cudnn_matmul_runtime_args_exec_t final : public cudnn_matmul_base_exec_t { status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t bias_scratch_size) override { + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d) override { nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); + std::shared_ptr matmul_params + = std::make_shared(); + matmul_params->init_from_params(params); + if (matmul_params->has_runtime_params_) { + matmul_params->cleanup(); + matmul_params->set_params(src_d, weights_d, dst_d, bias_d); + } + uint8_t *bias_scratch_ptr = nullptr; - if (bias_scratch_size > 0) { + if (matmul_params->reorder_scratch_size_ > 0) { bias_scratch_ptr = ::sycl::malloc_device( - bias_scratch_size, cuda_stream->queue()); + matmul_params->reorder_scratch_size_, cuda_stream->queue()); } - auto status = cuda_stream->interop_task([=, this]( - ::sycl::handler &cgh) { + return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE]( + ::sycl::handler &cgh) { auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC); auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS); auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST); auto arg_bias = 
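// The per-execution bias scratch in the runtime-args path is a USM device
// allocation whose lifetime has to outlive the submitted GEMM. A sketch of
// one way to implement the allocate / synchronize / free sequence shown
// above, using public cuBLAS, CUDA and SYCL entry points (error handling
// omitted for brevity):
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>
#include <cublas_v2.h>
#include <cuda_runtime.h>

inline uint8_t *alloc_scratch(std::size_t bytes, ::sycl::queue &q) {
    return bytes ? ::sycl::malloc_device<uint8_t>(bytes, q) : nullptr;
}

inline void release_scratch(
        cublasHandle_t handle, ::sycl::queue &q, uint8_t *scratch) {
    if (!scratch) return;
    cudaStream_t stream = nullptr;
    cublasGetStream(handle, &stream); // stream the GEMM was enqueued on
    cudaStreamSynchronize(stream);    // make sure the GEMM has consumed it
    ::sycl::free(scratch, q);
}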
CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS); - //auto arg_bias_scratch = init_scratch_from_buffer( - // bias_scratch_size, bias_scratch_buff_, cgh); auto arg_bias_scratch = init_scratch_from_ptr( - bias_scratch_size, bias_scratch_ptr); + matmul_params->reorder_scratch_size_, bias_scratch_ptr); auto arg_src_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); @@ -186,12 +207,11 @@ struct cudnn_matmul_runtime_args_exec_t final auto arg_dst_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - interop_task(matmul_impl_, engine, cgh, cuda_stream, arg_wt, - arg_src, arg_dst, arg_bias, arg_bias_scratch, arg_src_scale, - arg_wei_scale, arg_dst_scale, bias_scratch_ptr); + interop_task(matmul_impl_, matmul_params, engine, cgh, cuda_stream, + arg_wt, arg_src, arg_dst, arg_bias, arg_bias_scratch, + arg_src_scale, arg_wei_scale, arg_dst_scale, + bias_scratch_ptr); }); - - return status; } ~cudnn_matmul_runtime_args_exec_t() = default; @@ -202,16 +222,16 @@ struct cudnn_matmul_lt_base_exec_t { virtual status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t algo_scratch_size, std::size_t bias_scratch_size, - std::size_t block_a_scratch_size, std::size_t block_b_scratch_size, - std::size_t block_c_scratch_size, - std::size_t src_scale_scratchpad_size, - std::size_t wei_scale_scratchpad_size) + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) = 0; protected: template <::sycl::access::mode bias_m, ::sycl::access::mode scratch_m> void interop_task(std::shared_ptr matmul_impl_, + const std::shared_ptr params, impl::engine_t *engine, ::sycl::handler &cgh, nvidia::stream_t *cuda_stream, xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read> @@ -250,8 +270,6 @@ struct cudnn_matmul_lt_base_exec_t { auto native_stream = cuda_stream->get_underlying_stream(); auto cublas_handle = cuda_stream->get_cublas_handle(native_stream); - auto cudnn_handle - = cuda_stream->get_cudnn_handle(native_stream); void *reorder_scratch = arg_bias_scratch.get_native_pointer(ih); @@ -276,17 +294,18 @@ struct cudnn_matmul_lt_base_exec_t { void *wei_scale = arg_wei_scale.get_native_pointer(ih); void *dst_scale = arg_dst_scale.get_native_pointer(ih); - matmul_impl_->execute(cublas_handle, cudnn_handle, weights, - src, dst, bias, algo_scratch, reorder_scratch, + matmul_impl_->execute(cublas_handle, params, weights, src, + dst, bias, algo_scratch, reorder_scratch, block_a_scratch, block_b_scratch, block_c_scratch, scaled_src, scaled_wt, src_scale, wei_scale, dst_scale); - free_runtime_scratch(matmul_impl_->has_runtime_params(), + free_runtime_scratch(params->has_runtime_params_, cublas_handle, cuda_stream, algo_scratch_ptr, bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr, block_c_scratch_ptr, src_scale_scratch_ptr, wei_scale_scratch_ptr); + if (params->has_runtime_params_) { params->rt_cleanup(); } }); } @@ -341,16 +360,15 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t { status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t algo_scratch_size, std::size_t bias_scratch_size, - std::size_t block_a_scratch_size, std::size_t block_b_scratch_size, - std::size_t block_c_scratch_size, - std::size_t src_scale_scratchpad_size, - std::size_t wei_scale_scratchpad_size) override { + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + 
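// Sketch of the copy-and-specialize pattern the runtime-args executors use
// above: immutable, attribute-derived fields are copied from the pd-time
// params, and only the shape-dependent fields are recomputed from the runtime
// memory descriptors for this particular execution. Toy types stand in for
// cublas_params / memory_desc_wrapper.
#include <memory>

struct toy_md { long M, N, K; };
struct toy_params {
    bool has_runtime_params_ = false;
    float post_op_sum_ = 0.f;          // attr-derived, copied as-is
    long M_ = 0, N_ = 0, K_ = 0;       // shape-derived, recomputed per run

    void init_from_params(const toy_params &other) {
        has_runtime_params_ = other.has_runtime_params_;
        post_op_sum_ = other.post_op_sum_;
    }
    void set_params(const toy_md &src, const toy_md &dst) {
        M_ = dst.M; N_ = dst.N; K_ = src.K;
    }
};

inline std::shared_ptr<toy_params> make_exec_params(
        const toy_params &cached, const toy_md &src, const toy_md &dst) {
    auto p = std::make_shared<toy_params>();
    p->init_from_params(cached);
    if (p->has_runtime_params_) p->set_params(src, dst);
    return p;
}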
const memory_desc_wrapper &dst_d) override { nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); - return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE]( + return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE, ¶ms]( ::sycl::handler &cgh) { auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC); auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS); @@ -364,22 +382,43 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t { DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); auto arg_dst_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - auto arg_algo_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_lt_algo_scratch); - auto arg_bias_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_dst_in_acc_dt); - auto arg_block_a_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_gemm_blocked_a); - auto arg_block_b_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_gemm_blocked_b); - auto arg_block_c_scratch = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_lt_block_c); - auto scaled_arg_src = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_lt_src_scale); - auto scaled_arg_wt = CTX_SCRATCH_SYCL_MEMORY( - memory_tracking::names::key_matmul_lt_wei_scale); - - interop_task(matmul_impl_, engine, cgh, cuda_stream, arg_wt, + auto arg_algo_scratch = params->algo_scratch_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_lt_algo_scratch) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto arg_bias_scratch = params->reorder_scratch_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_dst_in_acc_dt) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto arg_block_a_scratch = params->source_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_gemm_blocked_a) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto arg_block_b_scratch = params->weight_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_gemm_blocked_b) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto arg_block_c_scratch = params->dest_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_lt_block_c) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto scaled_arg_src = params->src_scale_size_ != 0 + ? CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_lt_src_scale) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + auto scaled_arg_wt = params->wei_scale_size_ != 0 + ? 
CTX_SCRATCH_SYCL_MEMORY( + memory_tracking::names::key_matmul_lt_wei_scale) + : xpu::sycl::interop_memory_arg_t< + ::sycl::access::mode::read_write>(); + + interop_task(matmul_impl_, params, engine, cgh, cuda_stream, arg_wt, arg_src, arg_dst, arg_bias, arg_algo_scratch, arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch, arg_block_c_scratch, scaled_arg_src, scaled_arg_wt, @@ -395,35 +434,42 @@ struct cudnn_matmul_lt_runtime_args_exec_t final : public cudnn_matmul_lt_base_exec_t { status_t execute(const exec_ctx_t &ctx, impl::engine_t *engine, const std::shared_ptr matmul_impl_, - std::size_t algo_scratch_size, std::size_t bias_scratch_size, - std::size_t block_a_scratch_size, std::size_t block_b_scratch_size, - std::size_t block_c_scratch_size, - std::size_t src_scale_scratchpad_size, - std::size_t wei_scale_scratchpad_size) { + const std::shared_ptr params, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) { + + std::shared_ptr matmul_params + = std::make_shared(); + matmul_params->init_from_params(params); + if (matmul_params->has_runtime_params_) { + matmul_params->rt_cleanup(); + matmul_params->set_params(src_d, weights_d, dst_d, engine); + } nvidia::stream_t *cuda_stream = utils::downcast(ctx.stream()); - uint8_t *bias_scratch_ptr - = alloc_ptr(algo_scratch_size, cuda_stream->queue()); + uint8_t *bias_scratch_ptr = alloc_ptr( + matmul_params->algo_scratch_size_, cuda_stream->queue()); - uint8_t *algo_scratch_ptr - = alloc_ptr(bias_scratch_size, cuda_stream->queue()); + uint8_t *algo_scratch_ptr = alloc_ptr( + matmul_params->reorder_scratch_size_, cuda_stream->queue()); uint8_t *block_a_scratch_ptr - = alloc_ptr(block_a_scratch_size, cuda_stream->queue()); + = alloc_ptr(matmul_params->source_size_, cuda_stream->queue()); uint8_t *block_b_scratch_ptr - = alloc_ptr(block_b_scratch_size, cuda_stream->queue()); + = alloc_ptr(matmul_params->weight_size_, cuda_stream->queue()); uint8_t *block_c_scratch_ptr - = alloc_ptr(block_c_scratch_size, cuda_stream->queue()); + = alloc_ptr(matmul_params->dest_size_, cuda_stream->queue()); - uint8_t *src_scale_scratch_ptr - = alloc_ptr(src_scale_scratchpad_size, cuda_stream->queue()); + uint8_t *src_scale_scratch_ptr = alloc_ptr( + matmul_params->src_scale_size_, cuda_stream->queue()); - uint8_t *wei_scale_scratch_ptr - = alloc_ptr(wei_scale_scratchpad_size, cuda_stream->queue()); + uint8_t *wei_scale_scratch_ptr = alloc_ptr( + matmul_params->wei_scale_size_, cuda_stream->queue()); return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE]( ::sycl::handler &cgh) { @@ -433,19 +479,19 @@ struct cudnn_matmul_lt_runtime_args_exec_t final auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS); auto arg_algo_scratch = init_scratch_from_ptr( - algo_scratch_size, algo_scratch_ptr); + matmul_params->algo_scratch_size_, algo_scratch_ptr); auto arg_bias_scratch = init_scratch_from_ptr( - bias_scratch_size, bias_scratch_ptr); + matmul_params->reorder_scratch_size_, bias_scratch_ptr); auto arg_block_a_scratch = init_scratch_from_ptr( - block_a_scratch_size, block_a_scratch_ptr); + matmul_params->source_size_, block_a_scratch_ptr); auto arg_block_b_scratch = init_scratch_from_ptr( - block_b_scratch_size, block_b_scratch_ptr); + matmul_params->weight_size_, block_b_scratch_ptr); auto arg_block_c_scratch = init_scratch_from_ptr( - block_c_scratch_size, block_c_scratch_ptr); + matmul_params->dest_size_, block_c_scratch_ptr); auto scaled_arg_src = init_scratch_from_ptr( - src_scale_scratchpad_size, 
src_scale_scratch_ptr); + matmul_params->src_scale_size_, src_scale_scratch_ptr); auto scaled_arg_wt = init_scratch_from_ptr( - wei_scale_scratchpad_size, wei_scale_scratch_ptr); + matmul_params->wei_scale_size_, wei_scale_scratch_ptr); auto arg_src_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); @@ -454,8 +500,8 @@ struct cudnn_matmul_lt_runtime_args_exec_t final auto arg_dst_scale = CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - interop_task(matmul_impl_, engine, cgh, cuda_stream, arg_wt, - arg_src, arg_dst, arg_bias, arg_algo_scratch, + interop_task(matmul_impl_, matmul_params, engine, cgh, cuda_stream, + arg_wt, arg_src, arg_dst, arg_bias, arg_algo_scratch, arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch, arg_block_c_scratch, scaled_arg_src, scaled_arg_wt, arg_src_scale, arg_wei_scale, arg_dst_scale, diff --git a/src/gpu/nvidia/cudnn_matmul_impl.hpp b/src/gpu/nvidia/cudnn_matmul_impl.hpp index 7cae3f96fd0..cf3a0705cb6 100644 --- a/src/gpu/nvidia/cudnn_matmul_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_impl.hpp @@ -30,37 +30,37 @@ namespace impl { namespace gpu { namespace nvidia { -struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { +struct cublas_params : cublas_base_params { - status_t init(matmul_pd_t *pd) { + status_t init(const memory_desc_t *src_md, const memory_desc_t *weights_md, + const memory_desc_t *dst_md, const memory_desc_t *bias_md, + const primitive_attr_t *attr, bool batched, bool with_bias) { - CHECK(get_cublas_data_type(pd->src_md()->data_type, src_type_)); + CHECK(get_cublas_data_type(src_md->data_type, src_type_)); - CHECK(get_cublas_data_type(pd->weights_md()->data_type, weights_type_)); + CHECK(get_cublas_data_type(weights_md->data_type, weights_type_)); - isbatched_ = pd->batched(); + isbatched_ = batched; - memory_desc_wrapper src_d = memory_desc_wrapper(pd->src_md()); - memory_desc_wrapper weights_d = memory_desc_wrapper(pd->weights_md()); - memory_desc_wrapper dst_d = memory_desc_wrapper(pd->dst_md()); + memory_desc_wrapper src_d = memory_desc_wrapper(src_md); + memory_desc_wrapper weights_d = memory_desc_wrapper(weights_md); + memory_desc_wrapper dst_d = memory_desc_wrapper(dst_md); if (!(src_d.is_plain() && weights_d.is_plain() && dst_d.is_plain())) { return status::unimplemented; } - with_dst_scale_ - = !pd->attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - with_separate_bias_ = pd->with_bias(); + with_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values(); + with_separate_bias_ = with_bias; if ((with_separate_bias_) - && (pd->weights_md(1)->data_type != pd->dst_md()->data_type)) { + && (bias_md->data_type != dst_md->data_type)) { // When datatype of bias is different from the dst, // we need to reorder the output. bias_dt_mismatch_ = true; reorder_required_ = true; - CHECK(get_cublas_data_type( - pd->weights_md(1)->data_type, dst_type_)); + CHECK(get_cublas_data_type(bias_md->data_type, dst_type_)); } else { - CHECK(get_cublas_data_type(pd->dst_md()->data_type, dst_type_)); + CHECK(get_cublas_data_type(dst_md->data_type, dst_type_)); } // cuBLAS only supports s8s8f32 configuration. 
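// Compact illustration of the descriptor-only init above: the params struct
// is built from raw memory descriptors plus the attribute, so the same code
// can run both at pd-creation time and, for runtime shapes, inside execute.
// A bias whose data type differs from dst forces the accumulate-then-reorder
// path. The types below are simplified placeholders.
enum class dtype { f32, f16, s8, s32 };

struct bias_setup {
    bool with_separate_bias = false;
    bool bias_dt_mismatch = false;
    bool reorder_required = false;
};

inline bias_setup classify_bias(bool with_bias, dtype bias_dt, dtype dst_dt) {
    bias_setup r;
    r.with_separate_bias = with_bias;
    if (with_bias && bias_dt != dst_dt) {
        // Accumulate into a scratch tensor of the bias type, then reorder
        // (and apply any sum scale) into the real destination.
        r.bias_dt_mismatch = true;
        r.reorder_required = true;
    }
    return r;
}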
@@ -70,13 +70,13 @@ struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { dst_type_ = cudaDataType_t::CUDA_R_32F; } - if (with_eltwise(0, pd) || with_eltwise(1, pd)) { + if (with_eltwise(0, attr) || with_eltwise(1, attr)) { with_separate_eltwise_ = true; - CHECK(create_and_set_op_descriptor(pd, act_desc_)); + CHECK(create_and_set_op_descriptor(attr, act_desc_)); } // Set parameter when post-op sum is specified - if (with_sum(pd)) { post_op_sum_ = sum_scale(pd); } + if (with_sum(attr)) { post_op_sum_ = sum_scale(attr); } has_runtime_params_ = src_d.has_runtime_dims_or_strides() || dst_d.has_runtime_dims_or_strides() @@ -84,17 +84,30 @@ struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { if (!has_runtime_params_) { // Initialise all gemm parameters if there are no runtime parameters - init_parameters(src_d, weights_d, dst_d, - memory_desc_wrapper(pd->weights_md(1))); - init_scratchpad(pd); + set_params(src_d, weights_d, dst_d, memory_desc_wrapper(bias_md)); } return status::success; } - status_t init_gemm_parameters(const memory_desc_wrapper src_d, - const memory_desc_wrapper weights_d, - const memory_desc_wrapper dst_d) override { + status_t init_from_params(const std::shared_ptr &other) { + if (!other) { return status::invalid_arguments; } + src_type_ = other->src_type_; + weights_type_ = other->weights_type_; + isbatched_ = other->isbatched_; + with_dst_scale_ = other->with_dst_scale_; + with_separate_bias_ = other->with_separate_bias_; + bias_dt_mismatch_ = other->bias_dt_mismatch_; + reorder_required_ = other->reorder_required_; + dst_type_ = other->dst_type_; + with_separate_eltwise_ = other->with_separate_eltwise_; + has_runtime_params_ = other->has_runtime_params_; + return status::success; + } + + status_t set_gemm_params(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) { if (isbatched_) batch_count_ = dst_d.dims()[0]; const dim_t M = dst_d.dims()[isbatched_ + 1]; @@ -144,12 +157,13 @@ struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { return status::success; } - status_t init_parameters(const memory_desc_wrapper src_d, - const memory_desc_wrapper weights_d, - const memory_desc_wrapper dst_d, const memory_desc_wrapper bias_d) { + status_t set_params(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d) { // Matmul supports runtime paramters for dimensions and scales. // We need to initialize them in the execute function. 
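// The runtime-parameter decision above boils down to: if any of src/weights/
// dst still carries runtime dimensions or strides at pd time, defer all
// shape-dependent setup to execute(); otherwise do it once and cache it. A
// minimal stand-in for that check (the sentinel mirrors DNNL_RUNTIME_DIM_VAL
// but is defined locally here):
#include <cstdint>

constexpr int64_t RUNTIME_DIM = INT64_MIN; // stand-in for DNNL_RUNTIME_DIM_VAL

inline bool has_runtime_dims(const int64_t *dims, int ndims) {
    for (int i = 0; i < ndims; ++i)
        if (dims[i] == RUNTIME_DIM) return true;
    return false;
}

inline bool has_runtime_params(const int64_t *src_dims, int src_ndims,
        const int64_t *wei_dims, int wei_ndims, const int64_t *dst_dims,
        int dst_ndims) {
    return has_runtime_dims(src_dims, src_ndims)
            || has_runtime_dims(wei_dims, wei_ndims)
            || has_runtime_dims(dst_dims, dst_ndims);
}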
- CHECK(init_gemm_parameters(src_d, weights_d, dst_d)); + CHECK(set_gemm_params(src_d, weights_d, dst_d)); if (with_separate_bias_ || reorder_required_ || with_separate_eltwise_ || with_dst_scale_) { @@ -195,23 +209,170 @@ struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { return status::success; } - void init_scratchpad(matmul_pd_t *pd) override { - if (reorder_scratch_size_ > 0) { - auto scratchpad = pd->scratchpad_registry().registrar(); + size_t scratchpad_size(const memory_desc_t *dst_md) const { + const auto dst_nelems = memory_desc_wrapper(dst_md).nelems(true); + return dst_nelems * sizeof(float); + } + + void init_scratchpad(const memory_desc_t *dst_md, + memory_tracking::registrar_t scratchpad) { + auto reorder_scratch_size = scratchpad_size(dst_md); + if (reorder_scratch_size > 0) { scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt, - reorder_scratch_size_, 1); + reorder_scratch_size, 1); + } + } + + void convert_dims_matmul( + const dnnl_dim_t *dims, int *new_dims, int n_dims) { + // Moving the dimensions because cudnnAddTensor doesn't work when + // bia_mask=1 + if (n_dims == 3) { return convert_dims(dims, new_dims, n_dims); } + new_dims[0] = 1; + for (size_t i = 0; i < n_dims; i++) { + new_dims[i + 1] = static_cast(dims[i]); + } + for (size_t i = n_dims; i < 4; i++) { + new_dims[i + 1] = 1; + } + } + + int get_ld(const memory_desc_wrapper desc, cublasOperation_t trans) { + const int ndims = desc.ndims(); + const auto *strides = &desc.blocking_desc().strides[ndims - 2]; + const int ld = strides[trans == cublasOperation_t::CUBLAS_OP_N ? 0 : 1]; + return ld; + } + + // creates operation descriptor based on the elemen-wise operation specified + status_t create_and_set_op_descriptor(const primitive_attr_t *attr, + cudnnActivationDescriptor_t &act_desc) { + CHECK(CUDNN_EXECUTE_FUNC_S(cudnnCreateActivationDescriptor, &act_desc)); + + cudnnActivationMode_t mode; + + switch (eltwise_algo(attr)) { + case alg_kind::eltwise_relu: + mode = cudnnActivationMode_t::CUDNN_ACTIVATION_RELU; + break; + case alg_kind::eltwise_tanh: + mode = cudnnActivationMode_t::CUDNN_ACTIVATION_TANH; + break; + case alg_kind::eltwise_elu: + mode = cudnnActivationMode_t::CUDNN_ACTIVATION_ELU; + break; + case alg_kind::eltwise_logistic: + mode = cudnnActivationMode_t::CUDNN_ACTIVATION_SIGMOID; + break; + default: return status::unimplemented; + } + + // NaNs by default are propagated in oneDNN, although the forward + // convolution routine does not support this. + auto propagate_nan = cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN; + + // For ReLU, a ceiling of 0 means no limit. + double ceiling = eltwise_alpha(attr); + + CHECK(CUDNN_EXECUTE_FUNC_S(cudnnSetActivationDescriptor, act_desc, mode, + propagate_nan, ceiling)); + + return status::success; + } + + float eltwise_alpha(const primitive_attr_t *attr) { + int eltwise_idx_ = attr->post_ops_.find(primitive_kind::eltwise); + return with_eltwise(0, attr) || with_eltwise(1, attr) + ? attr->post_ops_.entry_[eltwise_idx_].eltwise.alpha + : 1.0f; + } + + status_t handle_post_ops(cudnnHandle_t cudnn_handle, void *dst, void *bias, + void *reorder_scratch, float host_dst_scale) { + if (with_separate_bias_) { + // When bias is specified call cudnnAddTensor() + float bias_beta = 1; + auto scale = (with_separate_eltwise_ ? 
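// A self-contained version of the activation-descriptor setup performed by
// create_and_set_op_descriptor() above, using the public cuDNN API directly
// (status checks reduced to early returns). The eltwise alpha doubles as the
// ReLU ceiling, and NaNs are not propagated, matching the code above.
#include <cudnn.h>

inline cudnnStatus_t make_relu_descriptor(
        cudnnActivationDescriptor_t *act_desc, double ceiling) {
    cudnnStatus_t st = cudnnCreateActivationDescriptor(act_desc);
    if (st != CUDNN_STATUS_SUCCESS) return st;
    return cudnnSetActivationDescriptor(*act_desc, CUDNN_ACTIVATION_RELU,
            CUDNN_NOT_PROPAGATE_NAN, ceiling);
}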
1 : 1.0f / host_dst_scale); + CUDNN_EXECUTE_FUNC(cudnnAddTensor, cudnn_handle, &scale, + tensor_descs_[io::bias], bias, &bias_beta, temp_mem_desc_, + reorder_scratch); + } + if (with_separate_eltwise_) { + // Perform elementwise operation if specified + float alpha = 1.0f / host_dst_scale; + float beta = 0; + CUDNN_EXECUTE_FUNC(cudnnActivationForward, cudnn_handle, act_desc_, + &alpha, temp_mem_desc_, reorder_scratch, &beta, + temp_mem_desc_, reorder_scratch); + } + if (reorder_required_) { + // Reorder from scratchpad to destination if required + float reorder_alpha = 1; + CUDNN_EXECUTE_FUNC(cudnnTransformTensor, cudnn_handle, + &reorder_alpha, temp_mem_desc_, reorder_scratch, + &post_op_sum_, tensor_descs_[io::dst], dst); + } + + return status::success; + } + + void cleanup() const { + if (act_desc_) { + CUDNN_EXECUTE_FUNC_V(cudnnDestroyActivationDescriptor, act_desc_); + } + if ((reorder_required_ && !bias_dt_mismatch_) + || ((with_separate_bias_ && bias_dt_mismatch_) + && temp_mem_desc_)) { + CUDNN_EXECUTE_FUNC_V(cudnnDestroyTensorDescriptor, temp_mem_desc_); + } + for (size_t i = 0; i < NUM_IO; i++) { + if (tensor_descs_[i]) { + CUDNN_EXECUTE_FUNC_V( + cudnnDestroyTensorDescriptor, tensor_descs_[i]); + } } } + int lda_, ldb_, ldc_; + + int64_t stride_a_, stride_b_, stride_c_; + + enum io { bias = 0, dst, NUM_IO }; + cudnnTensorDescriptor_t tensor_descs_[NUM_IO] = {}, + temp_mem_desc_ = nullptr; + cudnnActivationDescriptor_t act_desc_ = nullptr; + + cublasOperation_t transA_; + cublasOperation_t transB_; + cublasOperation_t transC_; + cublasGemmAlgo_t gemm_algo_ + = cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP; +}; + +struct cudnn_matmul_impl_t { + + void set_non_runtime_params( + const std::shared_ptr &matmul_params) { + matmul_params_ = matmul_params; + } + void execute(cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle, - void *a, void *b, void *c, void *bias, void *reorder_scratch, + const std::shared_ptr &matmul_params, void *a, + void *b, void *c, void *bias, void *reorder_scratch, void *src_scale, void *wei_scale, void *dst_scale) { + + // use cached params unless using runtime dimensions + std::shared_ptr params + = matmul_params->has_runtime_params_ ? 
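// The epilogue order in handle_post_ops() above is: add the bias, apply the
// element-wise post-op, then (if the GEMM accumulated into scratch) transform
// back into the destination while folding in the sum scale. A condensed
// sketch with the same cuDNN calls; descriptors and device pointers are
// assumed to be prepared by the caller and status checks are omitted.
#include <cudnn.h>

inline void apply_epilogue(cudnnHandle_t h, cudnnTensorDescriptor_t bias_desc,
        cudnnTensorDescriptor_t acc_desc, cudnnTensorDescriptor_t dst_desc,
        void *bias, void *acc, void *dst, bool with_bias, bool with_eltwise,
        cudnnActivationDescriptor_t act_desc, bool reorder_required,
        float dst_scale, float sum_scale) {
    if (with_bias) {
        float one = 1.f;
        float scale = with_eltwise ? 1.f : 1.f / dst_scale;
        cudnnAddTensor(h, &scale, bias_desc, bias, &one, acc_desc, acc);
    }
    if (with_eltwise) {
        float alpha = 1.f / dst_scale, beta = 0.f;
        cudnnActivationForward(
                h, act_desc, &alpha, acc_desc, acc, &beta, acc_desc, acc);
    }
    if (reorder_required) {
        float alpha = 1.f;
        // The sum scale plays the role of beta, blending with existing dst.
        cudnnTransformTensor(
                h, &alpha, acc_desc, acc, &sum_scale, dst_desc, dst);
    }
}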
matmul_params + : matmul_params_; + float gemm_beta = 0; - if (!bias_dt_mismatch_ && !reorder_required_) { + if (!params->bias_dt_mismatch_ && !params->reorder_required_) { // Case where no reorder is required, scratchpad points to dst (c) reorder_scratch = c; - temp_mem_desc_ = tensor_descs_[io::dst]; - gemm_beta = post_op_sum_; + params->temp_mem_desc_ + = params->tensor_descs_[cublas_params::io::dst]; + gemm_beta = params->post_op_sum_; } auto flip_op = [](cublasOperation_t op) { return (op == cublasOperation_t::CUBLAS_OP_T) @@ -237,73 +398,73 @@ struct cudnn_matmul_impl_t : cudnn_matmul_base_impl_t { CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_dst_scale, (CUdeviceptr)dst_scale, sizeof(float)); // For eltwise post-ops, apply the dst scale afterward - if (!with_separate_eltwise_) scale /= host_dst_scale; + if (!params->with_separate_eltwise_) scale /= host_dst_scale; } - if (isbatched_) { + auto M = params->M_; + auto N = params->N_; + auto K = params->K_; + + auto lda = params->lda_; + auto ldb = params->ldb_; + auto ldc = params->ldc_; + + auto src_type = params->src_type_; + auto weights_type = params->weights_type_; + auto dst_type = params->dst_type_; + + auto stride_a = params->stride_a_; + auto stride_b = params->stride_b_; + auto stride_c = params->stride_c_; + + auto batch_count = params->batch_count_; + auto acc_type = params->acc_type_; + auto gemm_algo = params->gemm_algo_; + + auto transA = params->transA_; + auto transB = params->transB_; + auto transC = params->transC_; + + if (params->isbatched_) { // Calls cublasGemmStridedBatchedEx() - if (transC_ == cublasOperation_t::CUBLAS_OP_T) { + if (transC == cublasOperation_t::CUBLAS_OP_T) { CUBLAS_EXECUTE_FUNC(cublasGemmStridedBatchedEx, cublas_handle, - flip_op(transB_), flip_op(transA_), N_, M_, K_, &scale, - b, src_type_, ldb_, stride_b_, a, weights_type_, lda_, - stride_a_, &gemm_beta, reorder_scratch, dst_type_, ldc_, - stride_c_, batch_count_, acc_type_, gemm_algo_); + flip_op(transB), flip_op(transA), N, M, K, &scale, b, + src_type, ldb, stride_b, a, weights_type, lda, stride_a, + &gemm_beta, reorder_scratch, dst_type, ldc, stride_c, + batch_count, acc_type, gemm_algo); } else { CUBLAS_EXECUTE_FUNC(cublasGemmStridedBatchedEx, cublas_handle, - transA_, transB_, M_, N_, K_, &scale, a, weights_type_, - lda_, stride_a_, b, src_type_, ldb_, stride_b_, - &gemm_beta, reorder_scratch, dst_type_, ldc_, stride_c_, - batch_count_, acc_type_, gemm_algo_); + transA, transB, M, N, K, &scale, a, weights_type, lda, + stride_a, b, src_type, ldb, stride_b, &gemm_beta, + reorder_scratch, dst_type, ldc, stride_c, batch_count, + acc_type, gemm_algo); } } else { // Calls cublasGemmEx() - if (transC_ == cublasOperation_t::CUBLAS_OP_T) { + if (transC == cublasOperation_t::CUBLAS_OP_T) { CUBLAS_EXECUTE_FUNC(cublasGemmEx, cublas_handle, - flip_op(transB_), flip_op(transA_), N_, M_, K_, &scale, - b, src_type_, ldb_, a, weights_type_, lda_, &gemm_beta, - reorder_scratch, dst_type_, ldc_, acc_type_, - gemm_algo_); + flip_op(transB), flip_op(transA), N, M, K, &scale, b, + src_type, ldb, a, weights_type, lda, &gemm_beta, + reorder_scratch, dst_type, ldc, acc_type, gemm_algo); } else { - CUBLAS_EXECUTE_FUNC(cublasGemmEx, cublas_handle, transA_, - transB_, M_, N_, K_, &scale, a, weights_type_, lda_, b, - src_type_, ldb_, &gemm_beta, reorder_scratch, dst_type_, - ldc_, acc_type_, gemm_algo_); + CUBLAS_EXECUTE_FUNC(cublasGemmEx, cublas_handle, transA, transB, + M, N, K, &scale, a, weights_type, lda, b, src_type, ldb, + &gemm_beta, reorder_scratch, 
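// When the destination is effectively row-major (transC == CUBLAS_OP_T in the
// code above), the call computes C^T = op(B)^T * op(A)^T by swapping the
// operands, flipping their transpose flags and exchanging M and N, so that
// column-major cuBLAS writes exactly the row-major result. A reduced f32
// example of that dispatch (single GEMM, no batching or mixed precision):
#include <cublas_v2.h>

inline cublasStatus_t gemm_row_or_col_major(cublasHandle_t h,
        bool dst_row_major, cublasOperation_t transA, cublasOperation_t transB,
        int M, int N, int K, const float *A, int lda, const float *B, int ldb,
        float *C, int ldc) {
    const float alpha = 1.f, beta = 0.f;
    auto flip = [](cublasOperation_t op) {
        return op == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T;
    };
    if (dst_row_major)
        return cublasSgemm(h, flip(transB), flip(transA), N, M, K, &alpha, B,
                ldb, A, lda, &beta, C, ldc);
    return cublasSgemm(
            h, transA, transB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
}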
dst_type, ldc, acc_type, + gemm_algo); } } - handle_post_ops(cudnn_handle, c, bias, reorder_scratch, host_dst_scale); + params->handle_post_ops( + cudnn_handle, c, bias, reorder_scratch, host_dst_scale); } - ~cudnn_matmul_impl_t() { cleanup(); } - - void cleanup() override { - if (act_desc_) { - CUDNN_EXECUTE_FUNC_V(cudnnDestroyActivationDescriptor, act_desc_); - act_desc_ = nullptr; - } - if ((reorder_required_ && !bias_dt_mismatch_) - || ((with_separate_bias_ && bias_dt_mismatch_) - && temp_mem_desc_)) { - CUDNN_EXECUTE_FUNC_V(cudnnDestroyTensorDescriptor, temp_mem_desc_); - temp_mem_desc_ = nullptr; - } - for (size_t i = 0; i < NUM_IO; i++) { - if (tensor_descs_[i]) { - CUDNN_EXECUTE_FUNC_V( - cudnnDestroyTensorDescriptor, tensor_descs_[i]); - tensor_descs_[i] = nullptr; - } - } + ~cudnn_matmul_impl_t() { + if (matmul_params_) { matmul_params_->cleanup(); } } private: - cublasOperation_t transA_; - cublasOperation_t transB_; - cublasOperation_t transC_; - int M_, N_, K_; - cudaDataType_t src_type_, weights_type_, dst_type_; - cublasGemmAlgo_t gemm_algo_ - = cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP; + std::shared_ptr matmul_params_; }; } // namespace nvidia diff --git a/src/gpu/nvidia/cudnn_matmul_lt.hpp b/src/gpu/nvidia/cudnn_matmul_lt.hpp index 20883eb56a4..5f07bfd708e 100644 --- a/src/gpu/nvidia/cudnn_matmul_lt.hpp +++ b/src/gpu/nvidia/cudnn_matmul_lt.hpp @@ -18,15 +18,24 @@ #ifndef GPU_NVIDIA_CUDNN_MATMUL_LT_HPP #define GPU_NVIDIA_CUDNN_MATMUL_LT_HPP -#include "gpu/nvidia/cudnn_matmul_base.hpp" +#include + +#include "gpu/gpu_matmul_pd.hpp" + +#include "common/primitive.hpp" +#include "common/primitive_desc_iterator.hpp" +#include "gpu/gpu_primitive.hpp" +#include "gpu/nvidia/cudnn_matmul_executor.hpp" +#include "gpu/nvidia/cudnn_matmul_lt_impl.hpp" +#include "gpu/nvidia/sycl_cuda_utils.hpp" namespace dnnl { namespace impl { namespace gpu { namespace nvidia { -struct cudnn_matmul_lt_t : cudnn_matmul_base_t { - using cudnn_matmul_base_t::cudnn_matmul_base_t; +struct cudnn_matmul_lt_t : public gpu::primitive_t { + using primitive_t::primitive_t; struct pd_t : public gpu_matmul_pd_t { using gpu_matmul_pd_t::gpu_matmul_pd_t; @@ -160,6 +169,16 @@ struct cudnn_matmul_lt_t : cudnn_matmul_base_t { if (is_scale_ok(DNNL_ARG_DST)) { CHECK(create_scale_binary_pd(engine, DNNL_ARG_DST)); } + + params_ = std::make_shared(); + CHECK(params_->init(engine, src_md(), weights_md(), dst_md(), + weights_md(1), attr(), batched(), with_bias())); + + if (!params_->has_runtime_params()) { + auto scratchpad = scratchpad_registry().registrar(); + params_->init_scratchpad(scratchpad); + } + return status::success; } @@ -167,6 +186,7 @@ struct cudnn_matmul_lt_t : cudnn_matmul_base_t { std::shared_ptr wei_scale_binary_pd_; std::shared_ptr dst_scale_binary_pd_; std::shared_ptr binary_pd_; + std::shared_ptr params_; memory_desc_t s32_dst_md_; @@ -458,42 +478,42 @@ struct cudnn_matmul_lt_t : cudnn_matmul_base_t { status_t init(impl::engine_t *engine) override { // LT matmul matmul_impl_.reset(new cudnn_matmul_lt_impl_t()); - auto status = matmul_impl_->init((matmul_pd_t *)pd(), engine); - bool has_runtime_args = matmul_impl_->has_runtime_params(); + bool has_runtime_args = pd()->params_->has_runtime_params(); if (has_runtime_args) { executor_.reset(new cudnn_matmul_lt_runtime_args_exec_t); } else if (!has_runtime_args) { executor_.reset(new cudnn_matmul_lt_exec_t); + matmul_impl_->set_non_runtime_params(pd()->params_); } - if (matmul_impl_->with_bias()) { + if (pd()->params_->with_bias_) { 
CHECK(create_nested_primitive(binary_, pd()->binary_pd_, engine)); } if (!memory_desc_wrapper(pd()->src_md()).is_cublaslt_blocked_desc() && !pd()->default_scale(DNNL_ARG_SRC) - && (matmul_impl_->multi_src_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_src_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { CHECK(create_nested_primitive( src_scale_binary_, pd()->src_scale_binary_pd_, engine)); } if (!pd()->default_scale(DNNL_ARG_WEIGHTS) - && (matmul_impl_->multi_wei_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_wei_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { CHECK(create_nested_primitive( wei_scale_binary_, pd()->wei_scale_binary_pd_, engine)); } if (!pd()->default_scale(DNNL_ARG_DST) - && (matmul_impl_->multi_dst_scale() - || matmul_impl_->scale_type() == CUDA_R_32I)) { + && (pd()->params_->multi_dst_scale_ + || pd()->params_->acc_type_ == CUDA_R_32I)) { CHECK(create_nested_primitive( dst_scale_binary_, pd()->dst_scale_binary_pd_, engine)); } - return status; + return status::success; } status_t execute(const exec_ctx_t &ctx) const override; diff --git a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp index 6670d4c6fd3..69ef8190234 100644 --- a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp +++ b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp @@ -35,52 +35,46 @@ namespace impl { namespace gpu { namespace nvidia { -template >::type> -T ceildiv(T n, T d) { - return (n + d - 1) / d; -} +struct cublas_lt_params : cublas_base_params { -struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { + status_t init(impl::engine_t *engine, const memory_desc_t *src_md, + const memory_desc_t *weights_md, const memory_desc_t *dst_md, + const memory_desc_t *bias_md, const primitive_attr_t *attr, + bool batched, bool with_bias) { + CHECK(get_cublas_data_type(src_md->data_type, src_type_)); + CHECK(get_cublas_data_type(weights_md->data_type, weights_type_)); + CHECK(get_cublas_data_type(dst_md->data_type, dst_type_)); - status_t init(matmul_pd_t *pd, impl::engine_t *engine) { + auto src_d = memory_desc_wrapper(*src_md); + auto weights_d = memory_desc_wrapper(*weights_md); + auto dst_d = memory_desc_wrapper(*dst_md); - CHECK(get_cublas_data_type(pd->src_md()->data_type, src_type_)); - CHECK(get_cublas_data_type(pd->weights_md()->data_type, weights_type_)); - CHECK(get_cublas_data_type(pd->dst_md()->data_type, dst_type_)); - - auto src_d = memory_desc_wrapper(pd->src_md()); - auto weights_d = memory_desc_wrapper(pd->weights_md()); - auto dst_d = memory_desc_wrapper(pd->dst_md()); - auto bias_d = memory_desc_wrapper(pd->weights_md(1)); - - isbatched_ = pd->batched() && dst_d.dims()[0]; + isbatched_ = batched && dst_d.dims()[0]; has_runtime_params_ = src_d.has_runtime_dims_or_strides() || dst_d.has_runtime_dims_or_strides() || weights_d.has_runtime_dims_or_strides(); - if (!pd->attr()->scales_.get(DNNL_ARG_SRC).has_default_values()) { - auto src_scale = pd->attr()->scales_.get(DNNL_ARG_SRC); + if (attr->scales_.get(DNNL_ARG_SRC).has_default_values()) { + auto src_scale = attr->scales_.get(DNNL_ARG_SRC); if (src_scale.mask_ != 0) { multi_src_scale_ = true; } } - if (!pd->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { - auto wei_scale = pd->attr()->scales_.get(DNNL_ARG_WEIGHTS); + if (attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { + auto wei_scale = attr->scales_.get(DNNL_ARG_WEIGHTS); if (wei_scale.mask_ != 0) { multi_wei_scale_ = true; } } - with_dst_scale_ - = 
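// Condensed form of the gating logic above: a separate scale binary primitive
// is created only when the scale argument is per-channel (non-zero mask) or
// when the accumulator type is s32, since a single host-side scalar can
// otherwise be folded into alpha. Plain booleans replace the pd queries.
inline bool needs_scale_binary(
        bool default_scale, bool multi_scale, bool s32_accumulation) {
    if (default_scale) return false; // no scale attribute set at all
    return multi_scale || s32_accumulation;
}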
!pd->attr()->scales_.get(DNNL_ARG_DST).has_default_values(); + with_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values(); if (with_dst_scale_) { - auto dst_scale = pd->attr()->scales_.get(DNNL_ARG_DST); + auto dst_scale = attr->scales_.get(DNNL_ARG_DST); if (dst_scale.mask_ != 0) { multi_dst_scale_ = true; } } // Initialise flags and variables for the imma case (E.g. imma_case_ flag). check_imma_case(src_d, weights_d, dst_d); - with_bias_ = pd->with_bias(); + with_bias_ = with_bias; bool dst_row_major = !is_md_col_major(dst_d); @@ -89,8 +83,7 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { if (has_runtime_params_) { return status::unimplemented; } else { - bias_dt_mismatch_ = (pd->weights_md(1)->data_type - != pd->dst_md()->data_type); + bias_dt_mismatch_ = (bias_md->data_type != dst_md->data_type); if (imma_case_) { with_separate_bias_ = true; if (dst_d.data_type() == dnnl_s8) { @@ -105,6 +98,8 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { } if (!with_separate_bias_ && !dst_row_major) { // bias epilogue not supported for dst dim = 1 + memory_desc_wrapper bias_d + = memory_desc_wrapper(*bias_md); if ((bias_d.dims()[1 + isbatched_] != M_ || bias_d.dims()[0 + isbatched_] != 1) || M_ == 1 || N_ == 1 || has_runtime_params_) { @@ -118,11 +113,11 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { with_bias_epilogue_ = with_bias_ && !with_separate_bias_; // Check if activation can be used in epilogue - if (with_eltwise(0, pd) || with_eltwise(1, pd)) { + if (with_eltwise(0, attr) || with_eltwise(1, attr)) { if (dst_d.has_runtime_dims_or_strides()) { return status::unimplemented; } else { - with_relu_ = eltwise_algo(pd) == alg_kind::eltwise_relu; + with_relu_ = eltwise_algo(attr) == alg_kind::eltwise_relu; if (!with_relu_ || dst_row_major || with_separate_bias_) { with_separate_eltwise_ = true; } @@ -163,7 +158,7 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { imma_plain_case_ = imma_case_ && !imma_ampere_case_; // Set parameter when post-op sum is specified - if (with_sum(pd)) { post_op_sum_ = sum_scale(pd); } + if (with_sum(attr)) { post_op_sum_ = sum_scale(attr); } // Initialise scaling parameters alpha_beta_size_bytes_ = dst_d.data_type_size(); @@ -175,48 +170,46 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { // Initialise all gemm parameters if (!has_runtime_params_) { - CHECK(init_parameters(src_d, weights_d, dst_d, bias_d, engine)); - init_scratchpad(pd); + CHECK(set_params(src_d, weights_d, dst_d, engine)); } return status::success; } - void check_imma_case(const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d) { - if (src_d.data_type() == dnnl_s8 && weights_d.data_type() == dnnl_s8 - && (dst_d.data_type() == dnnl_s32 - || dst_d.data_type() == dnnl_s8)) { - // weights blocked in Ab32a - w_blocked_ = is_md_col32(weights_d); - bool weights_supported = weights_d.has_runtime_dims_or_strides() - || w_blocked_ || weights_d.is_plain(); - - // src not blocked - src_blocked_ = src_d.is_cublaslt_blocked_desc(); - bool src_supported = src_d.has_runtime_dims_or_strides() - || src_blocked_ || src_d.is_plain(); - - // dst blocked in Ab32a - dst_blocked_ = is_md_col32(dst_d); - bool dst_supported = dst_d.has_runtime_dims_or_strides() - || dst_blocked_ || dst_d.is_plain(); - - imma_case_ = weights_supported && src_supported && dst_supported; - } + status_t init_from_params(const std::shared_ptr &other) { + if (!other) { return status::invalid_arguments; } + src_type_ = 
other->src_type_; + weights_type_ = other->weights_type_; + dst_type_ = other->dst_type_; + isbatched_ = other->isbatched_; + has_runtime_params_ = other->has_runtime_params_; + multi_src_scale_ = other->multi_src_scale_; + multi_wei_scale_ = other->multi_wei_scale_; + with_dst_scale_ = other->with_dst_scale_; + multi_dst_scale_ = other->multi_dst_scale_; + with_bias_ = other->with_bias_; + bias_dt_mismatch_ = other->bias_dt_mismatch_; + with_separate_bias_ = other->with_separate_bias_; + reorder_required_ = other->reorder_required_; + with_bias_epilogue_ = other->with_bias_epilogue_; + with_relu_epilogue_ = other->with_relu_epilogue_; + imma_ampere_case_ = other->imma_ampere_case_; + imma_plain_case_ = other->imma_plain_case_; + alpha_beta_size_bytes_ = other->alpha_beta_size_bytes_; + alpha_ = std::malloc(alpha_beta_size_bytes_); + beta_ = std::malloc(alpha_beta_size_bytes_); + return status::success; } - status_t init_parameters(const memory_desc_wrapper src_d, - const memory_desc_wrapper weights_d, - const memory_desc_wrapper dst_d, const memory_desc_wrapper bias_d, - impl::engine_t *engine) { + status_t set_params(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, impl::engine_t *engine) { batch_count_ = isbatched_ ? dst_d.dims()[0] : 1; M_ = static_cast(dst_d.dims()[isbatched_ + 1]); N_ = static_cast(dst_d.dims()[isbatched_ + 0]); K_ = static_cast(src_d.dims()[isbatched_ + 1]); - if (is_imma_case()) { + if (imma_case_) { w_blocked_ = is_md_col32(weights_d); dst_blocked_ = is_md_col32(dst_d); src_blocked_ = src_d.is_cublaslt_blocked_desc(); @@ -244,7 +237,7 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { // Matmul supports runtime paramters for dimensions and scales. // We need to initialize them in the execute function. - CHECK(init_gemm_parameters(src_d, weights_d, dst_d)); + CHECK(set_gemm_params(src_d, weights_d, dst_d)); auto &sycl_engine = *utils::downcast(engine); @@ -256,37 +249,37 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { auto lt_handle = (cublasLtHandle_t)cublas_handle; CHECK(init_scratchpad_size(lt_handle, src_d, weights_d, dst_d)); - if (with_separate_bias_) { - // Initialise cuDNN descriptors - cudnnDataType_t data_types[NUM_IO]; - int ndims = dst_d.ndims() < 4 ? 
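// The shape-dependent half of set_params() above reduces to reading M, N and
// K from the destination and source dimensions, offset by one when a leading
// batch dimension is present. A stand-alone version with plain dimension
// arrays in place of the memory descriptors:
#include <cstdint>

struct gemm_shape { uint64_t M, N, K, batch; };

inline gemm_shape derive_shape(
        const int64_t *src_dims, const int64_t *dst_dims, bool batched) {
    const int off = batched ? 1 : 0;
    gemm_shape s;
    s.batch = batched ? static_cast<uint64_t>(dst_dims[0]) : 1;
    s.M = static_cast<uint64_t>(dst_dims[off + 1]);
    s.N = static_cast<uint64_t>(dst_dims[off + 0]);
    s.K = static_cast<uint64_t>(src_dims[off + 1]);
    return s;
}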
4 : dst_d.ndims(); - int dims[NUM_IO][DNNL_MAX_NDIMS]; - int strides[NUM_IO][DNNL_MAX_NDIMS]; - - convert_dims_matmul(dst_d.dims(), dims[dst], dst_d.ndims()); - CHECK(convert_data_type(dst_d.md_, &data_types[dst], false)); - convert_dims_matmul( - dst_d.blocking_desc().strides, strides[dst], dst_d.ndims()); - CHECK(create_and_set_tensor_descriptor(&tensor_descs_[dst], - data_types[dst], ndims, dims[dst], strides[dst])); - - // Create bias and destination tensor descriptors - convert_dims_matmul(bias_d.dims(), dims[bias], bias_d.ndims()); - convert_dims_matmul(bias_d.blocking_desc().strides, strides[bias], - bias_d.ndims()); - CHECK(convert_data_type(bias_d.md_, &data_types[bias], false)); - CHECK(create_and_set_tensor_descriptor(&tensor_descs_[bias], - data_types[bias], ndims, dims[bias], strides[bias])); - - CHECK(create_and_set_tensor_descriptor(&tensor_descs_[dst], - data_types[dst], ndims, dims[dst], strides[dst])); - } return status::success; } - status_t init_gemm_parameters(const memory_desc_wrapper src_d, - const memory_desc_wrapper weights_d, - const memory_desc_wrapper dst_d) override { + void check_imma_case(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) { + if (src_d.data_type() == dnnl_s8 && weights_d.data_type() == dnnl_s8 + && (dst_d.data_type() == dnnl_s32 + || dst_d.data_type() == dnnl_s8)) { + // weights blocked in Ab32a + w_blocked_ = is_md_col32(weights_d); + bool weights_supported = weights_d.has_runtime_dims_or_strides() + || w_blocked_ || weights_d.is_plain(); + + // src not blocked + src_blocked_ = src_d.is_cublaslt_blocked_desc(); + bool src_supported = src_d.has_runtime_dims_or_strides() + || src_blocked_ || src_d.is_plain(); + + // dst blocked in Ab32a + dst_blocked_ = is_md_col32(dst_d); + bool dst_supported = dst_d.has_runtime_dims_or_strides() + || dst_blocked_ || dst_d.is_plain(); + + imma_case_ = weights_supported && src_supported && dst_supported; + } + } + + status_t set_gemm_params(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) { // C matrix is the dst trans_c_ = !is_md_col_major(dst_d); // A matrix is the weights @@ -428,8 +421,7 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { return status_t::dnnl_success; } - void init_scratchpad(matmul_pd_t *pd) override { - auto scratchpad = pd->scratchpad_registry().registrar(); + void init_scratchpad(memory_tracking::registrar_t scratchpad) { if (reorder_scratch_size_ > 0) { scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt, reorder_scratch_size_, 1, 256); @@ -460,197 +452,169 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { } } - void execute(cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle, - void *a, void *b, void *c, void *bias, void *algo_scratch, - void *reorder_scratch, void *block_a_scratch, void *block_b_scratch, - void *block_c_scratch, void *scaled_src, void *scaled_wt, - void *src_scale, void *wei_scale, void *dst_scale) { - // read from binary output instead of input if multi scale or s32 scaling - if ((src_scale && acc_type_ == CUDA_R_32I) || multi_wei_scale_) { - b = scaled_src; - } - if ((wei_scale && acc_type_ == CUDA_R_32I) || multi_wei_scale_) { - a = scaled_wt; + bool is_md_col_major(const memory_desc_wrapper &md) { + if (md.is_blocking_desc()) { + const auto &md_strides = &md.blocking_desc().strides[isbatched_]; + return (md_strides[1] == 1 && md.dims()[isbatched_ + 0] > 1); } + return false; + } - cudaStream_t streamId; 
- auto lt_handle = (cublasLtHandle_t)(cublas_handle); - CUBLAS_EXECUTE_FUNC(cublasGetStream, cublas_handle, &streamId); - - if (imma_ampere_case_) { - if (!src_blocked_) { - transform_matrix(lt_handle, b_layout_, b, blocked_b_layout_, - block_b_scratch, !trans_b_, streamId); - b = block_b_scratch; - } - if (!w_blocked_) { - transform_matrix(lt_handle, a_layout_, a, blocked_a_layout_, - block_a_scratch, trans_a_, streamId); - a = block_a_scratch; - } + void maybe_swap(uint64_t &row, uint64_t &col, cublasOperation_t &op, + cublasLtOrder_t order, bool transpose) { + if (transpose) { + std::swap(row, col); + op = cublasOperation_t::CUBLAS_OP_T; + order = CUBLASLT_ORDER_ROW; } + } - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - if (with_bias_epilogue_) { - if (with_relu_epilogue_) { - epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; - } else { - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - } else if (with_relu_epilogue_ && !with_bias_epilogue_) { - epilogue = CUBLASLT_EPILOGUE_RELU; - } + status_t create_matrix_layout(cublasLtMatrixLayout_t &layout, + cublasLtOrder_t order, cublasOperation_t trans, uint64_t row, + uint64_t col, uint64_t ld, const cudaDataType_t data_type, + cublasLtMatmulDescAttributes_t trans_attr, uint64_t stride) { CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + trans_attr, &trans, sizeof(trans)); - float scale = 1.0f; - float host_dst_scale = 1.0f; - if (src_scale && !multi_src_scale_ && acc_type_ != CUDA_R_32I) { - float host_src_scale = 1.0f; - CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_src_scale, - (CUdeviceptr)src_scale, sizeof(float)); - scale *= host_src_scale; - } - if (wei_scale && !multi_wei_scale_ && acc_type_ != CUDA_R_32I) { - float host_wei_scale = 1.0f; - CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_wei_scale, - (CUdeviceptr)wei_scale, sizeof(float)); - scale *= host_wei_scale; - } - if (dst_scale && !multi_dst_scale_ && acc_type_ != CUDA_R_32I) { - CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_dst_scale, - (CUdeviceptr)dst_scale, sizeof(float)); - // only applied here if no post ops used - scale /= host_dst_scale; + CUBLAS_EXECUTE_FUNC( + cublasLtMatrixLayoutCreate, &layout, data_type, row, col, ld); + + CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, + CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)); + + if (batch_count_ != 1) { + CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count_, + sizeof(batch_count_)); + CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride, + sizeof(stride)); } + return status_t::dnnl_success; + } - if (acc_type_ == CUDA_R_16F) { - dnnl::impl::float16_t half_scale = scale; - dnnl::impl::float16_t half_gemm_beta = post_op_sum_; - *static_cast(alpha_) = half_scale; - *static_cast(beta_) = half_gemm_beta; - } else { - *static_cast(alpha_) = scale; - *static_cast(beta_) = post_op_sum_; + size_t bias_scratch_size() { return reorder_scratch_size_; } + bool has_runtime_params() { return has_runtime_params_; } + + void create_non_blocked_layouts() { + auto trans_op = cublasOperation_t::CUBLAS_OP_N; + auto order = CUBLASLT_ORDER_COL; + + auto row = M_; + auto col = K_; + maybe_swap(row, col, trans_op, order, trans_a_); + create_matrix_layout(a_layout_, CUBLASLT_ORDER_COL, trans_op, 
row, col, + row, weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_); + + row = K_; + col = N_; + trans_op = cublasOperation_t::CUBLAS_OP_N; + maybe_swap(row, col, trans_op, order, trans_b_); + create_matrix_layout(b_layout_, CUBLASLT_ORDER_COL, trans_op, row, col, + row, src_type_, CUBLASLT_MATMUL_DESC_TRANSB, stride_b_); + + row = M_; + col = N_; + order = CUBLASLT_ORDER_COL; + maybe_swap(row, col, trans_op, order, trans_c_); + create_matrix_layout(c_layout_, order, cublasOperation_t::CUBLAS_OP_N, + row, col, row, dst_type_, CUBLASLT_MATMUL_DESC_TRANSC, + stride_c_); + } + + void create_blocked_layouts() { + create_matrix_layout(blocked_a_layout_, CUBLASLT_ORDER_COL32, + cublasOperation_t::CUBLAS_OP_N, M_, K_, a_blocked_ld_, + weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_blocked_); + + create_matrix_layout(blocked_b_layout_, CUBLASLT_ORDER_COL32_2R_4R4, + cublasOperation_t::CUBLAS_OP_N, N_, K_, b_blocked_ld_, + src_type_, CUBLASLT_MATMUL_DESC_TRANSB, stride_b_blocked_); + + create_matrix_layout(blocked_c_layout_, CUBLASLT_ORDER_COL32, + cublasOperation_t::CUBLAS_OP_N, M_, N_, c_blocked_ld_, + dst_type_, CUBLASLT_MATMUL_DESC_TRANSC, stride_c_blocked_); + + uint64_t row, col; + if (!w_blocked_) { + row = M_; + col = K_; + if (trans_a_) { std::swap(row, col); } + create_matrix_layout(a_layout_, CUBLASLT_ORDER_COL, + cublasOperation_t::CUBLAS_OP_N, row, col, row, + weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_); } - if (imma_ampere_case_) { - if (!dst_blocked_) { - std::memset(beta_, 0, alpha_beta_size_bytes_); - } - c = reorder_required_ ? reorder_scratch : c; - void *tmp_c = dst_blocked_ ? c : block_c_scratch; - CUBLAS_EXECUTE_FUNC(cublasLtMatmul, lt_handle, operation_desc_, - alpha_, a, blocked_a_layout_, b, blocked_b_layout_, beta_, - tmp_c, blocked_c_layout_, tmp_c, blocked_c_layout_, - &gemm_algo_, algo_scratch, heuristic_results_.workspaceSize, - streamId); + if (!src_blocked_) { + row = K_; + col = N_; + if (trans_b_) { std::swap(row, col); } + create_matrix_layout(b_layout_, CUBLASLT_ORDER_COL, + cublasOperation_t::CUBLAS_OP_N, row, col, row, src_type_, + CUBLASLT_MATMUL_DESC_TRANSB, stride_b_); + } - if (!dst_blocked_) { - transform_matrix(lt_handle, blocked_c_layout_, block_c_scratch, - c_layout_, c, trans_c_, streamId, post_op_sum_); - } - } else { - CUBLAS_EXECUTE_FUNC(cublasLtMatmul, lt_handle, operation_desc_, - alpha_, a, a_layout_, b, b_layout_, beta_, c, c_layout_, c, - c_layout_, &gemm_algo_, algo_scratch, - heuristic_results_.workspaceSize, streamId); + if (!dst_blocked_) { + row = M_; + col = N_; + if (trans_c_) { std::swap(row, col); } + create_matrix_layout(c_layout_, CUBLASLT_ORDER_COL, + cublasOperation_t::CUBLAS_OP_N, row, col, row, dst_type_, + CUBLASLT_MATMUL_DESC_TRANSC, stride_c_); } - } - ~cudnn_matmul_lt_impl_t() { cleanup(); } + // Constraint for Turing/Ampere kernels matmul config needs to be + // A^N B^T + cublasOperation_t b_trans_t = cublasOperation_t::CUBLAS_OP_T; + CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, &b_trans_t, sizeof(b_trans_t)); + } - void rt_cleanup() { + void rt_cleanup() const { if (a_layout_) { CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutDestroy, a_layout_); - a_layout_ = nullptr; } if (b_layout_) { CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutDestroy, b_layout_); - b_layout_ = nullptr; } if (c_layout_) { CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutDestroy, c_layout_); - c_layout_ = nullptr; } if (operation_desc_) { CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescDestroy, 
operation_desc_); - operation_desc_ = nullptr; } if (imma_ampere_case_) { - if (trans_desc_) { CUBLAS_EXECUTE_FUNC( cublasLtMatrixTransformDescDestroy, trans_desc_); - trans_desc_ = nullptr; } if (blocked_a_layout_) { CUBLAS_EXECUTE_FUNC( cublasLtMatrixLayoutDestroy, blocked_a_layout_); - blocked_a_layout_ = nullptr; } if (blocked_b_layout_) { CUBLAS_EXECUTE_FUNC( cublasLtMatrixLayoutDestroy, blocked_b_layout_); - blocked_b_layout_ = nullptr; } if (blocked_c_layout_) { CUBLAS_EXECUTE_FUNC( cublasLtMatrixLayoutDestroy, blocked_c_layout_); - blocked_c_layout_ = nullptr; - } - } - - for (size_t i = 0; i < NUM_IO; i++) { - if (tensor_descs_[i]) { - CUDNN_EXECUTE_FUNC_V( - cudnnDestroyTensorDescriptor, tensor_descs_[i]); - tensor_descs_[i] = nullptr; } } } - void cleanup() override { + void cleanup() const { std::free(alpha_); - alpha_ = nullptr; std::free(beta_); - beta_ = nullptr; if (preference_) { CUBLAS_EXECUTE_FUNC(cublasLtMatmulPreferenceDestroy, preference_); - preference_ = nullptr; } - rt_cleanup(); - - for (size_t i = 0; i < NUM_IO; i++) { - if (tensor_descs_[i]) { - CUDNN_EXECUTE_FUNC_V( - cudnnDestroyTensorDescriptor, tensor_descs_[i]); - tensor_descs_[i] = nullptr; - } - } } - bool is_imma_case() { return imma_case_; } - - size_t algo_scratch_size() const { return algo_scratch_size_; } - size_t block_a_scratch_size() const { return source_size_; } - size_t block_b_scratch_size() const { return weight_size_; } - size_t block_c_scratch_size() const { return dest_size_; } - size_t src_scale_size() const { return src_scale_size_; } - size_t wei_scale_size() const { return wei_scale_size_; } - - bool multi_src_scale() const { return multi_src_scale_; } - bool multi_wei_scale() const { return multi_wei_scale_; } - bool multi_dst_scale() const { return multi_dst_scale_; } - cudaDataType_t scale_type() const { return acc_type_; } - -private: cublasLtMatmulDesc_t operation_desc_; cublasLtMatrixLayout_t a_layout_; @@ -682,8 +646,6 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { uint64_t weight_size_ = 0; uint64_t dest_size_ = 0; - uint64_t M_, N_, K_; - int64_t stride_a_, stride_b_, stride_c_, stride_a_blocked_, stride_b_blocked_, stride_c_blocked_, a_blocked_ld_, b_blocked_ld_, c_blocked_ld_; @@ -692,147 +654,182 @@ struct cudnn_matmul_lt_impl_t : cudnn_matmul_base_impl_t { cublasComputeType_t compute_type_ = CUBLAS_COMPUTE_32F; - cudaDataType_t src_type_, weights_type_, dst_type_; size_t alpha_beta_size_bytes_ = 0; void *alpha_ = nullptr; void *beta_ = nullptr; + size_t algo_scratch_size_ = 0; + cublasLtMatmulAlgo_t gemm_algo_; cublasLtMatmulPreference_t preference_; cublasLtMatmulHeuristicResult_t heuristic_results_; +}; - status_t create_matrix_layout(cublasLtMatrixLayout_t &layout, - cublasLtOrder_t order, cublasOperation_t trans, uint64_t row, - uint64_t col, uint64_t ld, const cudaDataType_t data_type, - cublasLtMatmulDescAttributes_t trans_attr, uint64_t stride) { - CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc_, - trans_attr, &trans, sizeof(trans)); +struct cudnn_matmul_lt_impl_t { - CUBLAS_EXECUTE_FUNC( - cublasLtMatrixLayoutCreate, &layout, data_type, row, col, ld); + void set_non_runtime_params( + const std::shared_ptr &matmul_params) { + matmul_params_ = matmul_params; + } - CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, - CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)); + void execute(cublasHandle_t cublas_handle, + const std::shared_ptr matmul_params, void *a, + void *b, void *c, void *bias, void *algo_scratch, + void 
*reorder_scratch, void *block_a_scratch, void *block_b_scratch, + void *block_c_scratch, void *scaled_src, void *scaled_wt, + void *src_scale, void *wei_scale, void *dst_scale) { - if (batch_count_ != 1) { - CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count_, - sizeof(batch_count_)); - CUBLAS_EXECUTE_FUNC(cublasLtMatrixLayoutSetAttribute, layout, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride, - sizeof(stride)); + // use cached params unless using runtime dimensions + std::shared_ptr params + = matmul_params->has_runtime_params_ ? matmul_params + : matmul_params_; + + auto acc_type = params->acc_type_; + auto multi_wei_scale = params->multi_wei_scale_; + // read from binary output instead of input if multi scale or s32 scaling + if ((src_scale && acc_type == CUDA_R_32I) || multi_wei_scale) { + b = scaled_src; + } + if ((wei_scale && acc_type == CUDA_R_32I) || multi_wei_scale) { + a = scaled_wt; } - return status_t::dnnl_success; - } - void transform_matrix(cublasLtHandle_t handle, - cublasLtMatrixLayout_t in_layout, void *in, - cublasLtMatrixLayout_t out_layout, void *out, bool transpose, - cudaStream_t stream, int beta = 0) { - int alpha = 1; + cudaStream_t streamId; + auto lt_handle = (cublasLtHandle_t)(cublas_handle); + CUBLAS_EXECUTE_FUNC(cublasGetStream, cublas_handle, &streamId); - cublasOperation_t transform_trans - = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_EXECUTE_FUNC(cublasLtMatrixTransformDescSetAttribute, - trans_desc_, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, - &transform_trans, sizeof(transform_trans)); - CUBLAS_EXECUTE_FUNC(cublasLtMatrixTransform, handle, trans_desc_, - &alpha, in, in_layout, &beta, out, out_layout, out, out_layout, - stream); - } + auto b_layout = params->b_layout_; + auto blocked_b_layout = params->blocked_b_layout_; + auto a_layout = params->a_layout_; + auto blocked_a_layout = params->blocked_a_layout_; - void create_non_blocked_layouts() { - auto trans_op = cublasOperation_t::CUBLAS_OP_N; - auto order = CUBLASLT_ORDER_COL; + auto imma_ampere_case = params->imma_ampere_case_; - auto row = M_; - auto col = K_; - maybe_swap(row, col, trans_op, order, trans_a_); - create_matrix_layout(a_layout_, CUBLASLT_ORDER_COL, trans_op, row, col, - row, weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_); + if (imma_ampere_case) { + if (!params->src_blocked_) { + transform_matrix(lt_handle, params, b_layout, b, + blocked_b_layout, block_b_scratch, !params->trans_b_, + streamId); + b = block_b_scratch; + } + if (!params->w_blocked_) { + transform_matrix(lt_handle, params, a_layout, a, + blocked_a_layout, block_a_scratch, !params->trans_a_, + streamId); + a = block_a_scratch; + } + } - row = K_; - col = N_; - trans_op = cublasOperation_t::CUBLAS_OP_N; - maybe_swap(row, col, trans_op, order, trans_b_); - create_matrix_layout(b_layout_, CUBLASLT_ORDER_COL, trans_op, row, col, - row, src_type_, CUBLASLT_MATMUL_DESC_TRANSB, stride_b_); + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + auto with_bias_epilogue = params->with_bias_epilogue_; + auto with_relu_epilogue = params->with_relu_epilogue_; - row = M_; - col = N_; - order = CUBLASLT_ORDER_COL; - maybe_swap(row, col, trans_op, order, trans_c_); - create_matrix_layout(c_layout_, order, cublasOperation_t::CUBLAS_OP_N, - row, col, row, dst_type_, CUBLASLT_MATMUL_DESC_TRANSC, - stride_c_); - } + auto operation_desc = params->operation_desc_; - void create_blocked_layouts() { - create_matrix_layout(blocked_a_layout_, 
CUBLASLT_ORDER_COL32, - cublasOperation_t::CUBLAS_OP_N, M_, K_, a_blocked_ld_, - weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_blocked_); + if (with_bias_epilogue) { + if (with_relu_epilogue) { + epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; + } else { + epilogue = CUBLASLT_EPILOGUE_BIAS; + } + CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); + } else if (with_relu_epilogue && !with_bias_epilogue) { + epilogue = CUBLASLT_EPILOGUE_RELU; + } + CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - create_matrix_layout(blocked_b_layout_, CUBLASLT_ORDER_COL32_2R_4R4, - cublasOperation_t::CUBLAS_OP_N, N_, K_, b_blocked_ld_, - src_type_, CUBLASLT_MATMUL_DESC_TRANSB, stride_b_blocked_); + float scale = 1.0f; + float host_dst_scale = 1.0f; + if (src_scale && !params->multi_src_scale_ && acc_type != CUDA_R_32I) { + float host_src_scale = 1.0f; + CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_src_scale, + (CUdeviceptr)src_scale, sizeof(float)); + scale *= host_src_scale; + } + if (wei_scale && !params->multi_wei_scale_ && acc_type != CUDA_R_32I) { + float host_wei_scale = 1.0f; + CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_wei_scale, + (CUdeviceptr)wei_scale, sizeof(float)); + scale *= host_wei_scale; + } + if (dst_scale && !params->multi_dst_scale_ && acc_type != CUDA_R_32I) { + CUDA_EXECUTE_FUNC(cuMemcpy, (CUdeviceptr)&host_dst_scale, + (CUdeviceptr)dst_scale, sizeof(float)); + // only applied here if no post ops used + scale /= host_dst_scale; + } - create_matrix_layout(blocked_c_layout_, CUBLASLT_ORDER_COL32, - cublasOperation_t::CUBLAS_OP_N, M_, N_, c_blocked_ld_, - dst_type_, CUBLASLT_MATMUL_DESC_TRANSC, stride_c_blocked_); + auto alpha = params->alpha_; + auto beta = params->beta_; + auto post_op_sum = params->post_op_sum_; - uint64_t row, col; - if (!w_blocked_) { - row = M_; - col = K_; - if (trans_a_) { std::swap(row, col); } - create_matrix_layout(a_layout_, CUBLASLT_ORDER_COL, - cublasOperation_t::CUBLAS_OP_N, row, col, row, - weights_type_, CUBLASLT_MATMUL_DESC_TRANSA, stride_a_); + if (acc_type == CUDA_R_16F) { + dnnl::impl::float16_t half_scale = scale; + dnnl::impl::float16_t half_gemm_beta = post_op_sum; + *static_cast(alpha) = half_scale; + *static_cast(beta) = half_gemm_beta; + } else { + *static_cast(alpha) = scale; + *static_cast(beta) = post_op_sum; } - if (!src_blocked_) { - row = K_; - col = N_; - if (trans_b_) { std::swap(row, col); } - create_matrix_layout(b_layout_, CUBLASLT_ORDER_COL, - cublasOperation_t::CUBLAS_OP_N, row, col, row, src_type_, - CUBLASLT_MATMUL_DESC_TRANSB, stride_b_); - } + auto dst_blocked = params->dst_blocked_; + auto c_layout = params->c_layout_; + auto gemm_algo = params->gemm_algo_; + auto heuristic_results = params->heuristic_results_; - if (!dst_blocked_) { - row = M_; - col = N_; - if (trans_c_) { std::swap(row, col); } - create_matrix_layout(c_layout_, CUBLASLT_ORDER_COL, - cublasOperation_t::CUBLAS_OP_N, row, col, row, dst_type_, - CUBLASLT_MATMUL_DESC_TRANSC, stride_c_); + if (imma_ampere_case) { + if (!dst_blocked) { + std::memset(beta, 0, params->alpha_beta_size_bytes_); + } + auto blocked_c_layout = params->blocked_c_layout_; + c = params->reorder_required_ ? reorder_scratch : c; + void *tmp_c = dst_blocked ? 
c : block_c_scratch; + CUBLAS_EXECUTE_FUNC(cublasLtMatmul, lt_handle, operation_desc, + alpha, a, blocked_a_layout, b, blocked_b_layout, beta, + tmp_c, blocked_c_layout, tmp_c, blocked_c_layout, + &gemm_algo, algo_scratch, heuristic_results.workspaceSize, + streamId); + + if (!dst_blocked) { + transform_matrix(lt_handle, params, blocked_c_layout, + block_c_scratch, c_layout, c, params->trans_c_, + streamId, post_op_sum); + } + } else { + CUBLAS_EXECUTE_FUNC(cublasLtMatmul, lt_handle, operation_desc, + alpha, a, a_layout, b, b_layout, beta, c, c_layout, c, + c_layout, &gemm_algo, algo_scratch, + heuristic_results.workspaceSize, streamId); } + } - // Constraint for Turing/Ampere kernels matmul config needs to be - // A^N B^T - cublasOperation_t b_trans_t = cublasOperation_t::CUBLAS_OP_T; - CUBLAS_EXECUTE_FUNC(cublasLtMatmulDescSetAttribute, operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, &b_trans_t, sizeof(b_trans_t)); + ~cudnn_matmul_lt_impl_t() { + if (matmul_params_) { matmul_params_->cleanup(); } } - bool is_md_col_major(const memory_desc_wrapper &md) { - if (md.is_blocking_desc()) { - const auto &md_strides = &md.blocking_desc().strides[isbatched_]; - return (md_strides[1] == 1 && md.dims()[isbatched_ + 0] > 1); - } - return false; +private: + void transform_matrix(cublasLtHandle_t handle, + const std::shared_ptr ¶ms, + cublasLtMatrixLayout_t in_layout, void *in, + cublasLtMatrixLayout_t out_layout, void *out, bool transpose, + cudaStream_t stream, int beta = 0) { + int alpha = 1; + cublasLtMatrixTransformDesc_t trans_desc = params->trans_desc_; + cublasOperation_t transform_trans + = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_EXECUTE_FUNC(cublasLtMatrixTransformDescSetAttribute, trans_desc, + CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &transform_trans, + sizeof(transform_trans)); + CUBLAS_EXECUTE_FUNC(cublasLtMatrixTransform, handle, trans_desc, &alpha, + in, in_layout, &beta, out, out_layout, out, out_layout, stream); } - void maybe_swap(uint64_t &row, uint64_t &col, cublasOperation_t &op, - cublasLtOrder_t order, bool transpose) { - if (transpose) { - std::swap(row, col); - op = cublasOperation_t::CUBLAS_OP_T; - order = CUBLASLT_ORDER_ROW; - } - }; + std::shared_ptr matmul_params_; }; } // namespace nvidia diff --git a/src/gpu/nvidia/cudnn_reorder_lt_impl.hpp b/src/gpu/nvidia/cudnn_reorder_lt_impl.hpp index b9aba18a6ad..5158244487e 100644 --- a/src/gpu/nvidia/cudnn_reorder_lt_impl.hpp +++ b/src/gpu/nvidia/cudnn_reorder_lt_impl.hpp @@ -27,12 +27,6 @@ namespace impl { namespace gpu { namespace nvidia { -template >::type> -T ceildiv(T n, T d) { - return (n + d - 1) / d; -} - struct cublaslt_reorder_t { public: bool trans; diff --git a/src/gpu/nvidia/sycl_cuda_utils.hpp b/src/gpu/nvidia/sycl_cuda_utils.hpp index 09bb3a9bdf2..059a69bfff5 100644 --- a/src/gpu/nvidia/sycl_cuda_utils.hpp +++ b/src/gpu/nvidia/sycl_cuda_utils.hpp @@ -235,6 +235,12 @@ inline bool is_md_col32(const memory_desc_wrapper &md) { return false; } +template >::type> +T ceildiv(T n, T d) { + return (n + d - 1) / d; +} + class cublas_error : virtual public std::runtime_error { protected:
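
The ceildiv helper relocated to sycl_cuda_utils.hpp in the last hunk appears with its template header stripped by the text extraction; only the trailing ">::type>" survives. A plausible self-contained reading, assuming the conventional enable_if/is_integral constraint (the function body is taken verbatim from the hunk), is:

// Hedged reconstruction of the relocated ceildiv utility. The template
// header (enable_if over is_integral) is an assumption; the body matches
// the patch: ceiling division for integral types.
#include <type_traits>

template <typename T,
        typename = typename std::enable_if<std::is_integral<T>::value>::type>
T ceildiv(T n, T d) {
    return (n + d - 1) / d;
}

int main() {
    // ceildiv(7, 2) rounds up to 4.
    return ceildiv(7, 2) == 4 ? 0 : 1;
}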
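The new execute() above selects between a parameter object cached at primitive-creation time via set_non_runtime_params() and one rebuilt per call when has_runtime_params_ is set. The stand-alone sketch below shows only that selection pattern; params_t and its fields are illustrative stand-ins, not the exact types used by the patch.

// Minimal sketch of the "cached params unless runtime dimensions" scheme.
#include <cstdio>
#include <memory>

struct params_t {
    bool has_runtime_params = false;
    int m = 0, n = 0, k = 0; // GEMM sizes known at init or only at execute
};

struct matmul_impl_t {
    // Called once at primitive init when all dimensions are known.
    void set_non_runtime_params(const std::shared_ptr<params_t> &p) {
        cached_params_ = p;
    }

    // Called on every execute(); 'p' is rebuilt per call only in the
    // runtime-dimension case.
    void execute(const std::shared_ptr<params_t> &p) const {
        // Prefer per-call parameters when dimensions are runtime values,
        // otherwise fall back to the cached, init-time parameters.
        const auto &params = p->has_runtime_params ? p : cached_params_;
        std::printf("gemm %dx%dx%d\n", params->m, params->n, params->k);
    }

private:
    std::shared_ptr<params_t> cached_params_;
};

int main() {
    auto pd_params = std::make_shared<params_t>();
    pd_params->m = 128; pd_params->n = 64; pd_params->k = 32;

    matmul_impl_t impl;
    impl.set_non_runtime_params(pd_params); // done at primitive init
    impl.execute(pd_params);                // uses the cached copy
    return 0;
}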
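execute() also folds the single-element src, weight and dst scales into the GEMM alpha/beta instead of launching extra scaling kernels: alpha accumulates src_scale * wei_scale and is divided by dst_scale, beta carries the sum post-op scale, and the fold is skipped for CUDA_R_32I accumulation and for per-channel (multi) scales. The sketch below reproduces just that arithmetic for the float-accumulation case; the real code first copies the scalar scales from device memory with cuMemcpy.

// Sketch of folding per-tensor scales into a single alpha/beta pair.
#include <cstdio>

struct gemm_coeffs_t {
    float alpha;
    float beta;
};

// src_scale/wei_scale/dst_scale are host copies of single-element scale
// tensors (null when absent); post_op_sum is the optional sum post-op scale.
gemm_coeffs_t fold_scales(const float *src_scale, const float *wei_scale,
        const float *dst_scale, float post_op_sum) {
    float scale = 1.0f;
    if (src_scale) scale *= *src_scale;
    if (wei_scale) scale *= *wei_scale;
    // The dst scale is applied as a division, and only when no later
    // post-op consumes the unscaled result (mirrors the comment above).
    if (dst_scale) scale /= *dst_scale;
    return {scale, post_op_sum};
}

int main() {
    const float src = 0.5f, wei = 2.0f, dst = 4.0f;
    const gemm_coeffs_t c = fold_scales(&src, &wei, &dst, 1.0f);
    std::printf("alpha=%g beta=%g\n", c.alpha, c.beta); // alpha=0.25 beta=1
    return 0;
}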