-
Notifications
You must be signed in to change notification settings - Fork 0
/
sm90_gemm_tma_warpspecialized.hpp
451 lines (396 loc) · 18.5 KB
/
sm90_gemm_tma_warpspecialized.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class TileScheduler_
>
class GemmUniversal<
ProblemShape_,
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecialized, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
{
public:
//
// Type Aliases
//
// Problem extents; the assertion below restricts the shape to rank 3 or rank 4.
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
// The cluster shape is taken from the mainloop's dispatch policy.
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Reject mainloops whose architecture tag reports a minimum compute capability below 90.
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
// Tile scheduler: only the default tag (void) or PersistentScheduler is accepted. The concrete
// scheduler type is chosen by detail::TileSchedulerSelector from the tag, arch tag, tile shape
// and cluster shape.
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
"TMA warp-specialized kernel does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
// Kernel level shared memory storage
// operator() reinterprets its smem_buf argument as one instance of this struct. It holds the
// tensor storage of both collectives plus the storage backing their pipelines.
struct SharedStorage {
// Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
union TensorStorage {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
// Unlike the tensor storage this is a struct, not a union: the mainloop and the epilogue-load
// pipeline storage are separate members, each declared with 16-byte alignment.
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
} pipelines;
};
// sizeof(SharedStorage), i.e. the size of the buffer operator() expects behind smem_buf.
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Warp-group counts. NumMmaWarpGroups is declared here but the thread count below is derived
// from size(TiledMma{}) rather than from it.
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 1;
// Threads per block: the threads of the tiled MMA plus one warp group per load warp group.
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// Device side arguments
// Input to to_underlying_arguments(), can_implement(), get_workspace_size() and
// initialize_workspace(). Every member is value-initialized by default.
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
// hw_info and scheduler are part of the argument set, but to_underlying_arguments() does not
// copy them into Params.
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
// Parameter pack consumed by operator() and get_grid_shape(); built by to_underlying_arguments().
struct Params {
GemmUniversalMode mode;
// As stored by to_underlying_arguments(): M and N are exchanged relative to
// Arguments::problem_shape when detail::IF_SWAP_AB<CollectiveMainloop>::value is true.
ProblemShape problem_shape;
MainloopParams mainloop;
EpilogueParams epilogue;
};
//
// Methods
//
// Translate user-facing Arguments into the Params consumed by the kernel.
//   args      - user arguments; mode is copied through unchanged
//   workspace - forwarded untouched to both collectives
// The problem shape stored in Params has M and N exchanged when the mainloop swaps A and B
// (detail::IF_SWAP_AB); both collectives always receive the caller's original, unswapped shape.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
  constexpr bool kSwapAB = detail::IF_SWAP_AB<CollectiveMainloop>::value;

  // Kernel-side view of the problem extents.
  auto kernel_problem_shape = args.problem_shape;
  if constexpr (kSwapAB) {
    get<0>(kernel_problem_shape) = get<1>(args.problem_shape);
    get<1>(kernel_problem_shape) = get<0>(args.problem_shape);
  }

  // Mainloop first, then epilogue: the same order in which the original initializer list ran them.
  MainloopParams mainloop_params =
      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace);
  EpilogueParams epilogue_params =
      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace);

  return Params{args.mode, kernel_problem_shape, mainloop_params, epilogue_params};
}
// Report whether this kernel can run the given arguments.
// Accepted modes are kGemm, or kBatched when ProblemShape is rank-4. If the mode is rejected a
// trace message is emitted and neither collective is consulted; otherwise both collectives are
// always queried and both must agree.
CUTLASS_HOST_DEVICE static
bool
can_implement(Arguments const& args) {
  bool const is_plain_gemm = (args.mode == GemmUniversalMode::kGemm);
  bool const is_rank4_batched =
      (args.mode == GemmUniversalMode::kBatched) && (rank(ProblemShape{}) == 4);

  if (!(is_plain_gemm || is_rank4_batched)) {
    CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
    return false;
  }

  // Evaluate both checks unconditionally (no short-circuit), as the original &= chain did.
  bool const mainloop_ok = CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
  bool const epilogue_ok = CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
  return mainloop_ok && epilogue_ok;
}
// Workspace size required by this kernel: always zero, whatever the arguments are.
static
int
get_workspace_size(Arguments const& args) {
  constexpr int kWorkspaceSize = 0;
  return kWorkspaceSize;
}
// Workspace initialization hook. Nothing is initialized here: args, workspace and stream are
// all ignored and success is reported unconditionally.
static
cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  return cutlass::Status::kSuccess;
}
// Launch grid for the given kernel parameters.
// The problem shape is padded to rank 4 with a trailing Int<1> and handed, together with the
// CTA tile shape and the cluster shape, to the tile scheduler, which produces the dim3.
static dim3
get_grid_shape(Params const& params) {
  auto const padded_problem_shape = append<4>(params.problem_shape, Int<1>{});
  auto const cta_tile = TileShape{};
  auto const cluster = ClusterShape{};
  return TileScheduler::get_tiled_cta_shape_mnl(padded_problem_shape, cta_tile, cluster);
}
// Launch block shape: MaxThreadsPerBlock threads along x, extent 1 along y and z.
static dim3
get_block_shape() {
  dim3 threads_per_block(MaxThreadsPerBlock, 1, 1);
  return threads_per_block;
}
// Device-side kernel body.
//   params   - kernel parameters (see Params)
//   smem_buf - raw buffer that is reinterpreted as SharedStorage
// Flow: construct the mainloop / epilogue-load / epilogue-store pipelines, tile the A and B
// tensors for the output tile selected by blockIdx, then branch on the warp-group role. The
// warp group with canonical index 0 (Producer) issues the mainloop loads and, when the epilogue
// requests it, the epilogue loads; the warp group with canonical index 1 (Consumer) runs the MMA
// mainloop followed by the epilogue store.
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
// When __CUDA_ARCH_FEAT_SM90_ALL is not defined and the MMA atom's M extent is 64, print an
// error and leave the kernel without doing any work.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Role of a warp group, obtained below by casting canonical_warp_group_idx().
enum class WarpGroupRole {
Producer = 0,
Consumer = 1,
};
// Role of a warp inside its warp group, obtained below by casting the warp's index within the
// group. Only MainloopEpilogue is referenced in this function.
enum class ProducerWarpRole {
MainloopEpilogue = 0,
Warp1 = 1,
Warp2 = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
// Only warp 0 of the producer warp group is tagged Producer; every consumer warp-group thread is
// tagged Consumer. Threads matching neither test keep whatever role MainloopPipeline::Params
// default-initializes.
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
// The first thread of each warp group is marked as leader.
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params);
// Epilogue Load pipeline
// Roles are assigned with the same two tests as for the mainloop pipeline above.
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
// Same call that initialized the block_rank_in_cluster local above.
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
// Constructed from its params only; no shared-storage member is passed to it.
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
// Immediately-invoked lambda: the arrive (cluster size > 1) or the __syncthreads() (single-CTA
// cluster) happens right here, and the returned callable is the matching wait step (a no-op in
// the single-CTA case). That callable is invoked further down, after the tensor tiling is set up.
auto cluster_wait_fn = [&] () {
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return [] () { cute::cluster_wait(); };
}
else {
__syncthreads();
return [] () {}; // do nothing
}
} ();
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto K = get<2>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
// Get the appropriate blocks for this thread block -- potential for thread block locality
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TiledMma tiled_mma;
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
// blockIdx.x selects the M tile, blockIdx.y the N tile and blockIdx.z the batch (l) index.
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Slice with m_coord and n_coord
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Get pipeline iterators and increments from tensor shapes
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
auto k_tile_count = size<2>(gA);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
if (warp_group_role == WarpGroupRole::Producer) {
// Only warp 0 of the producer warp group does any work; the other producer warps fall through
// to the end of the function.
if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
collective_mainloop.load(
mainloop_pipeline,
mainloop_pipe_producer_state,
gA, params.mainloop.tma_load_a,
gB, params.mainloop.tma_load_b,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
// Update starting mainloop pipeline state for the pipeline drain
mainloop_pipe_producer_state.advance(k_tile_count);
// Make sure mainloop consumer has been waited upon before issuing epilogue load
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
if (collective_epilogue.is_producer_load_needed()) {
// Ensure warp is converged before issuing epilogue loads
__syncwarp();
epi_load_pipe_producer_state =
collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
tiled_mma,
lane_idx,
shared_storage.tensors.epilogue
);
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
}
}
}
else if (warp_group_role == WarpGroupRole::Consumer) {
// Accumulator fragment for the (BLK_M, BLK_N) part of the tile shape.
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
collective_mainloop.mma(
mainloop_pipeline,
mainloop_pipe_consumer_state,
accumulators,
k_tile_count,
thread_idx,
shared_storage.tensors.mainloop,
params.mainloop
);
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(
mainloop_pipeline,
mainloop_pipe_consumer_state,
k_tile_count
);
// Epilogue and write to gD
// store() returns the advanced load-consumer and store-producer states, which are then handed
// to store_tail().
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline,
epi_load_pipe_consumer_state,
epi_store_pipeline,
epi_store_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
accumulators,
tiled_mma,
warp_group_thread_idx,
shared_storage.tensors.epilogue
);
collective_epilogue.store_tail(
epi_load_pipeline,
epi_load_pipe_consumer_state_next,
epi_store_pipeline,
epi_store_pipe_producer_state_next
);
}
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel