-
Notifications
You must be signed in to change notification settings - Fork 0
/
implicit_gemm_convolution_fusion.h
268 lines (205 loc) · 9.81 KB
/
implicit_gemm_convolution_fusion.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Template for a device-level operator that fuses the activation's scale+bias+ReLU with Implicit GEMM Convolution
*/
#pragma once
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/conv/convolution.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Host-side driver for a kernel-level implicit GEMM convolution with a fused
/// scale+bias applied to operand A.
///
/// Typical use: can_implement() -> get_workspace_size() -> initialize() -> run(),
/// optionally followed by update() + run() to re-point the same configuration at
/// new tensors.
template<typename ImplicitGemmFusionKernel_>
class ImplicitGemmConvolutionFusion {
public:

  //
  // Types and constants forwarded from the kernel-level operator
  //

  using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_;

  // Operand element types and layouts
  using ElementA = typename ImplicitGemmFusionKernel::ElementA;
  using LayoutA  = typename ImplicitGemmFusionKernel::LayoutA;
  using ElementB = typename ImplicitGemmFusionKernel::ElementB;
  using LayoutB  = typename ImplicitGemmFusionKernel::LayoutB;
  // NOTE(review): scale/bias aliases are intentionally left disabled, as in the original.
  // using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias;
  // using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias;
  using ElementC = typename ImplicitGemmFusionKernel::ElementC;
  using LayoutC  = typename ImplicitGemmFusionKernel::LayoutC;

  // Numeric types for accumulation and the epilogue
  using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator;
  using ElementCompute     = typename ImplicitGemmFusionKernel::ElementCompute;

  // Architecture, tiling and epilogue configuration
  using OperatorClass      = typename ImplicitGemmFusionKernel::OperatorClass;
  using ArchTag            = typename ImplicitGemmFusionKernel::ArchTag;
  using ThreadblockShape   = typename ImplicitGemmFusionKernel::ThreadblockShape;
  using WarpShape          = typename ImplicitGemmFusionKernel::WarpShape;
  using InstructionShape   = typename ImplicitGemmFusionKernel::InstructionShape;
  using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle;
  using EpilogueOutputOp   = typename ImplicitGemmFusionKernel::EpilogueOutputOp;

  static int const kStages  = ImplicitGemmFusionKernel::kStages;
  static int const kConvDim = ImplicitGemmFusionKernel::kConvDim;

  // Math operators used at the warp and instruction level
  using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator;
  using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator;
  using MathOperator    = typename ImplicitGemmFusionKernel::MathOperator;

  static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm;

  /// Warps per threadblock: how many warp tiles fit in the threadblock tile
  /// along M, N and K. run() launches 32 * kWarpCount threads per block.
  static int const kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
    (ThreadblockShape::kN / WarpShape::kN) *
    (ThreadblockShape::kK / WarpShape::kK);

  /// Argument structure (defined by the kernel-level operator)
  using Arguments = typename ImplicitGemmFusionKernel::Arguments;

private:

  /// Kernel parameters, built by initialize() and patched by update().
  typename ImplicitGemmFusionKernel::Params params_;

  /// Size in bytes of the kernel's SharedStorage, passed as the dynamic
  /// shared-memory size at launch.
  static int shared_storage_bytes() {
    return int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));
  }

public:

  /// Default-constructs the operator; call initialize() before run().
  ImplicitGemmConvolutionFusion() { }

  /// Reports whether this operator can execute the problem in `args`.
  ///
  /// Returns the first non-success status from the operand A / operand B
  /// iterators' can_implement(). Returns kErrorInvalidProblem when the launch
  /// grid's y or z extent exceeds the uint16_t maximum. Otherwise returns kSuccess.
  static Status can_implement(Arguments const &args) {

    // Each operand iterator gets a veto over the problem size.
    Status const status_a = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size);
    if (status_a != Status::kSuccess) {
      return status_a;
    }

    Status const status_b = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size);
    if (status_b != Status::kSuccess) {
      return status_b;
    }

    // Map the convolution onto GEMM tiles, then onto a CUDA launch grid.
    ThreadblockSwizzle swizzle;
    auto tiled_shape = swizzle.get_tiled_shape(
      cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.problem_size.split_k_slices);
    dim3 const launch_grid = swizzle.get_grid_shape(tiled_shape);

    // The grid's y and z extents must fit in 16 bits.
    unsigned const grid_yz_limit = std::numeric_limits<uint16_t>::max();
    if (launch_grid.y > grid_yz_limit || launch_grid.z > grid_yz_limit) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  /// Number of workspace bytes the caller must provide for `args`:
  ///  - parallel split-K: sizeof(ElementAccumulator) * (output tensor size) *
  ///    (number of k partitions). Each partition writes its partial result
  ///    here, and a separate reduction operator must combine them.
  ///  - serial split-K with split_k_slices > 1: one int per (m, n) threadblock
  ///    tile, used as semaphores that serialize writes of the reduced output.
  ///  - any other configuration: 0.
  static size_t get_workspace_size(Arguments const &args) {

    ThreadblockSwizzle swizzle;
    cutlass::gemm::GemmCoord tiled_shape = swizzle.get_tiled_shape(
      cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.problem_size.split_k_slices);

    bool const parallel_split_k = (args.split_k_mode == SplitKMode::kParallel);
    bool const serial_split_k =
      (args.split_k_mode == SplitKMode::kSerial) && (args.problem_size.split_k_slices > 1);

    if (parallel_split_k) {
      size_t const output_elements =
        size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size));
      return sizeof(ElementAccumulator) * output_elements * size_t(tiled_shape.k());
    }

    if (serial_split_k) {
      return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return 0;
  }

  /// Prepares the operator to run `args`.
  ///
  /// If split_k_slices > 1, `workspace` must be non-null (otherwise returns
  /// kErrorWorkspaceNull). Its first get_workspace_size(args) bytes are then
  /// zero-filled asynchronously on `stream`.
  ///
  /// Next, the kernel parameters are built from `args` and `workspace`. If the
  /// kernel's shared storage is 48 KB or more, the kernel is opted in to that
  /// much dynamic shared memory.
  ///
  /// CUDA runtime failures are reported as kErrorInternal.
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    if (args.problem_size.split_k_slices > 1) {
      if (workspace == nullptr) {
        return Status::kErrorWorkspaceNull;
      }
      size_t const workspace_bytes = get_workspace_size(args);
      if (cudaMemsetAsync(workspace, 0, workspace_bytes, stream) != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    params_ = typename ImplicitGemmFusionKernel::Params(args, static_cast<int *>(workspace));

    // Dynamic shared memory of 48 KB or more requires an explicit per-kernel opt-in.
    int const smem_bytes = shared_storage_bytes();
    if (smem_bytes < (48 << 10)) {
      return Status::kSuccess;
    }

    cudaError_t const attr_result = cudaFuncSetAttribute(
      cutlass::Kernel<ImplicitGemmFusionKernel>,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      smem_bytes);

    return (attr_result == cudaSuccess) ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Re-points already-initialized parameters at the tensors, epilogue
  /// parameters and workspace given in `args` / `workspace`. The rest of the
  /// parameter object is not rebuilt, so the problem shape must be unchanged.
  /// Always returns kSuccess.
  Status update(Arguments const &args, void *workspace = nullptr) {

    // Operand tensors
    params_.ptr_A = args.ref_A.data();
    params_.ptr_B = args.ref_B.data();

    // Scale and bias tensors associated with operand A
    params_.ptr_scale = args.ref_A_scale.data();
    params_.ptr_bias  = args.ref_A_bias.data();

    // Source and destination tensors, plus epilogue parameters
    params_.ptr_C     = args.ref_C.data();
    params_.ptr_D     = args.ref_D.data();
    params_.output_op = args.output_op;

    // Workspace reinterpreted as the semaphore array
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Launches the kernel on `stream` using the parameters prepared by
  /// initialize() / update().
  ///
  /// The grid comes from the swizzle applied to params_.grid_tiled_shape.
  /// Each block has 32 * kWarpCount threads. A non-success code from
  /// cudaGetLastError() after the launch is reported as kErrorInternal.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle swizzle;
    dim3 const grid_dim = swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 const block_dim(32 * kWarpCount, 1, 1);

    cutlass::Kernel<ImplicitGemmFusionKernel><<<grid_dim, block_dim, shared_storage_bytes(), stream>>>(params_);

    return (cudaGetLastError() != cudaSuccess) ? Status::kErrorInternal : Status::kSuccess;
  }

  /// Shorthand for run(stream).
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Calls initialize(args, workspace, stream). If that succeeds, also calls
  /// run(stream). Returns the first non-success status encountered.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status const init_status = initialize(args, workspace, stream);
    if (init_status != Status::kSuccess) {
      return init_status;
    }
    return run(stream);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////