-
Notifications
You must be signed in to change notification settings - Fork 0
/
gemm_universal_adapter.h
546 lines (446 loc) · 20.8 KB
/
gemm_universal_adapter.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/detail/layout.hpp"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
// 2.x
#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
// 3.x
#include "cutlass/gemm/kernel/gemm_universal.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::device {
////////////////////////////////////////////////////////////////////////////////
/*!
  GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel
  of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal.

  It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs
  to create it from the host-facing arguments. For power users, the 3.x APIs also
  expose static methods that bypass the stateful methods or the args->params lowering.

  Kernel types implementing either the 2.x or the 3.x API are supported. This is done
  by specializing GemmUniversalAdapter on the kernel API type, so the behaviour of
  GemmUniversalAdapter may differ between the two specializations.

  \tparam GemmKernel_  Kernel type being adapted.
  \tparam Enable       SFINAE hook; the specializations in this file select themselves
                       through it with cute::enable_if_t on
                       gemm::detail::IsCutlass3GemmKernel<GemmKernel_>. Leave as void.
*/
template <class GemmKernel_, class Enable = void>
class GemmUniversalAdapter;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/*!
  CUTLASS 3.x specialization of GemmUniversalAdapter.

  Selected when gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value is true.
  The handle owns one `GemmKernel::Params` object, built from host-facing
  `Arguments` by initialize() / update(), and launches device_kernel<GemmKernel>
  with it. Static overloads (get_grid_shape(Params), run(Params&)) allow callers
  to manage their own Params instead of using the stored one.
*/
template <class GemmKernel_>
class GemmUniversalAdapter<
  GemmKernel_,
  cute::enable_if_t<gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value>>
{
public:
  using GemmKernel = GemmKernel_;
  using TileShape = typename GemmKernel::TileShape;
  using ElementA = typename GemmKernel::ElementA;
  using ElementB = typename GemmKernel::ElementB;
  using ElementC = typename GemmKernel::ElementC;
  using ElementD = typename GemmKernel::ElementD;
  using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC;
  using DispatchPolicy = typename GemmKernel::DispatchPolicy;
  using CollectiveMainloop = typename GemmKernel::CollectiveMainloop;
  using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue;

  // Map back to 2.x type as best as possible
  using LayoutA = gemm::detail::StrideToLayoutTagA_t<typename GemmKernel::StrideA>;
  using LayoutB = gemm::detail::StrideToLayoutTagB_t<typename GemmKernel::StrideB>;
  using LayoutC = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideC>;
  using LayoutD = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;

  // NOTE: 3.0 kernels do not support complex transforms for now ...
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0
  using MathOperator = cutlass::arch::OpMultiplyAdd;

  // All tensorop operations have atom shape's M >= 8
  using OperatorClass = cute::conditional_t<
      cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}) >= 8,
      cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>;

  using ArchTag = typename GemmKernel::ArchTag;

  // NOTE: Assume identity swizzle for now
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

  // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape
  using ThreadblockShape = cutlass::gemm::GemmShape<
      cute::size<0>(TileShape{}),
      cute::size<1>(TileShape{}),
      cute::size<2>(TileShape{})>;

  using ClusterShape = cutlass::gemm::GemmShape<
      cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
      cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
      cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>;

  // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape
  using InstructionShape = cutlass::gemm::GemmShape<
      cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
      cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
      cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>;

  // Legacy: provide a correct warp count, but no reliable warp shape
  static int const kThreadCount = GemmKernel::MaxThreadsPerBlock;

  // Warp shape is not a primary API type in 3.x
  // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK
  // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K
  // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads
  static constexpr int WarpsInMma = cute::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32);
  static constexpr int WarpsInMmaM = 4;
  static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
  using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
  using WarpShape = cutlass::gemm::GemmShape<
      cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM,
      cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN,
      cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>;

  static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages;

  // Inspect TiledCopy for A and B to compute the alignment size
  static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
      typename CollectiveMainloop::GmemTiledCopyA, ElementA>();
  static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
      typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
  static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
      typename CollectiveEpilogue::GmemTiledCopyC, ElementC>();
  static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
      typename CollectiveEpilogue::GmemTiledCopyD, ElementD>();

  using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp;

  // Split-K preserves splits that are 128b aligned
  static int constexpr kSplitKAlignment = cute::max(
      128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);

  /// Argument structure: User API
  using Arguments = typename GemmKernel::Arguments;
  /// Argument structure: Kernel API
  using Params = typename GemmKernel::Params;

private:

  /// Kernel API parameters object; written by initialize() and update(),
  /// read by the run() / operator() overloads that take no Params argument.
  Params params_;

public:

  /// Access the Params structure currently held by this handle.
  Params const& params() const {
    return params_;
  }

  /// Determines whether the GEMM can execute the given problem.
  /// Returns Status::kSuccess when GemmKernel::can_implement(args) is true,
  /// Status::kInvalid otherwise.
  static Status
  can_implement(Arguments const& args) {
    if (GemmKernel::can_implement(args)) {
      return Status::kSuccess;
    }
    else {
      return Status::kInvalid;
    }
  }

  /// Gets the workspace size in bytes.
  /// The result is GemmKernel::get_workspace_size(args) plus, in
  /// kGemmSplitKParallel mode, sizeof(int) bytes per element of the
  /// (TileShape M) x (TileShape N) tile.
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
      workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{}));
    }

    workspace_bytes += GemmKernel::get_workspace_size(args);

    // Trace after the kernel's own requirement is added so the logged value
    // is the value actually returned.
    CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);

    return workspace_bytes;
  }

  /// Computes the grid shape by first lowering args (and the optional
  /// workspace pointer) to a temporary Params object.
  static dim3
  get_grid_shape(Arguments const& args, void* workspace = nullptr) {
    auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace);
    return GemmKernel::get_grid_shape(tmp_params);
  }

  /// Computes the grid shape from an already constructed Params object.
  static dim3
  get_grid_shape(Params const& params) {
    return GemmKernel::get_grid_shape(params);
  }

  /// Computes the maximum number of active blocks per multiprocessor.
  /// The argument is accepted for interface compatibility and is ignored.
  /// Returns -1 if either CUDA runtime call fails.
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = GemmKernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    // (kernels needing 48 KiB or more must opt in before the occupancy query)
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
          device_kernel<GemmKernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          " cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        device_kernel<GemmKernel>,
        GemmKernel::MaxThreadsPerBlock,
        smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
        " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
        << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes GEMM state from arguments.
  /// Steps: initialize the workspace via GemmKernel::initialize_workspace,
  /// lower args to the stored Params, then opt in to the kernel's dynamic
  /// shared-memory size when it is 48 KiB or more.
  /// Returns the workspace-initialization status if that step fails,
  /// Status::kErrorInternal if cudaFuncSetAttribute fails, else Status::kSuccess.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = GemmKernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }

    // Initialize the Params structure
    params_ = GemmKernel::to_underlying_arguments(args, workspace);

    // account for dynamic smem capacity if needed
    int smem_size = GemmKernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<GemmKernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }
    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  /// Re-lowers args into the stored Params. Returns Status::kErrorWorkspaceNull
  /// when get_workspace_size(args) is non-zero but no workspace pointer is given.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_ = GemmKernel::to_underlying_arguments(args, workspace);
    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be constructed by calling GemmKernel::to_underlying_arguments().
  ///
  /// params is taken by non-const reference because its address is handed to the
  /// cluster launch path as the kernel argument.
  /// NOTE: this function does not itself call cudaFuncSetAttribute; in this class
  /// only initialize() and maximum_active_blocks() raise the dynamic
  /// shared-memory limit for the kernel.
  /// Returns Status::kSuccess when the launch reports no error, else Status::kErrorInternal.
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversal::run()");
    dim3 const block = GemmKernel::get_block_shape();
    dim3 const grid = get_grid_shape(params);

    // dynamic shared memory requested for the launch
    int smem_size = GemmKernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr(GemmKernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
                   cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
                   cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<GemmKernel>;
      // Array of pointers to the kernel arguments: the single Params object.
      void* kernel_params[] = {&params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
    }

    // Launch-configuration errors are only visible through cudaGetLastError().
    cudaError_t result = cudaGetLastError();
    if (cudaSuccess == result && Status::kSuccess == launch_result) {
      return Status::kSuccess;
    }
    else {
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  /// Returns the initialize() status if it fails; otherwise the status of the launch.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 2.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/*!
  CUTLASS 2.x specialization of GemmUniversalAdapter.

  Selected when gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value is false.
  Every operation is forwarded to a GemmUniversalBase<GemmKernel> instance. When
  the kernel's LayoutC is RowMajor (kInternalTranspose), the caller's arguments
  are first replaced by args.transposed_problem().
*/
template <typename GemmKernel_>
class GemmUniversalAdapter<
  GemmKernel_,
  cute::enable_if_t<not gemm::detail::IsCutlass3GemmKernel<GemmKernel_>::value>>
{
public:

  using GemmKernel = GemmKernel_;

  /// True when the kernel's output layout is row-major, in which case the
  /// arguments are transposed before being handed to the underlying operator.
  static bool const kInternalTranspose =
    cute::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;
  using WarpShape = typename GemmKernel::WarpShape;
  using InstructionShape = typename GemmKernel::InstructionShape;

  // Warp-level operator, the arch-level (instruction) operator it wraps, and the math operator tag
  using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;

  // Operator class and arch tag are extracted bottom-up from the warp-level
  // operator and re-exported for the device-level GEMM template
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  // Type, layout, and complex transform deliberately exchanged with B
  // (MapArguments is given kInternalTranspose to decide this)
  using MapArguments = kernel::detail::MapArguments<
    typename GemmKernel::ElementA,
    typename GemmKernel::LayoutA,
    GemmKernel::kTransformA,
    GemmKernel::kAlignmentA,
    typename GemmKernel::ElementB,
    typename GemmKernel::LayoutB,
    GemmKernel::kTransformB,
    GemmKernel::kAlignmentB,
    typename GemmKernel::LayoutC,
    kInternalTranspose
  >;

  // Operand A as seen by the user of this adapter
  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static int const kAlignmentA = MapArguments::kAlignmentA;

  // Operand B as seen by the user of this adapter
  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;
  static int const kAlignmentB = MapArguments::kAlignmentB;

  // Operand C: element type and alignment come from the kernel, layout from MapArguments
  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename MapArguments::LayoutC;
  static int const kAlignmentC = GemmKernel::kAlignmentC;

  // C and D same type for 2.x kernel
  using ElementD = ElementC;
  using LayoutD = LayoutC;

  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementD, LayoutD>;

  static int const kStages = GemmKernel::Mma::kStages;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;

  /// Operator that performs the actual work; its Arguments type is re-exported.
  using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  /// Stateful operator every non-static member forwards to.
  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalAdapter() { }

  /// Helper producing the arguments the underlying GEMM operator expects:
  /// the transposed problem when kInternalTranspose is set, otherwise a copy of args.
  static Arguments to_underlying_arguments(Arguments const &args) {
    if (!kInternalTranspose) {
      return args;
    }
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    Arguments const lowered = to_underlying_arguments(args);
    return UnderlyingOperator::can_implement(lowered);
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    Arguments const lowered = to_underlying_arguments(args);
    return UnderlyingOperator::get_workspace_size(lowered);
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    Arguments const lowered = to_underlying_arguments(args);
    return UnderlyingOperator::get_grid_shape(lowered);
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments (after the optional transpose).
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    Arguments const lowered = to_underlying_arguments(args);
    return underlying_operator_.initialize(lowered, workspace, stream);
  }

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args) {
    Arguments const lowered = to_underlying_arguments(args);
    return underlying_operator_.update(lowered);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Initializes state from the supplied arguments, then runs the kernel.
  /// If initialization fails, its status is returned and nothing is launched.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status const init_status = initialize(args, workspace, stream);
    if (init_status != Status::kSuccess) {
      return init_status;
    }
    return run(stream);
  }
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::device
////////////////////////////////////////////////////////////////////////////////