-
Notifications
You must be signed in to change notification settings - Fork 0
/
gather_scatter_fusion.cu
545 lines (451 loc) · 20.5 KB
/
gather_scatter_fusion.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// This example fuses gather before GEMM and scatter after GEMM into the same
// GEMM kernel. Gather and scatter operation is controled by an index vector
// to select rows or columns from A, B, C or D matrices.
//
// Suppose, all matrices are column major. The pseudo code of the fused kernel
// in this example is essentially
//
// for (int i = 0; i < problem_size.m(); ++i) {
// for (int j = 0; j < options.index_size; ++j) {
// int b_c_d_col = tensor_indices.at({j, 0});
//
// for (int k = 0; k < options.index_size; ++k) {
// tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
// }
// }
//
// Note that the index vector contains unique random integers with max to be N - 1
//
// The gather/scatter operation works best when we can still keep the biggest
// alignment. For example, when the matrix is row major, we select rows. When
// the matrix is column major, we select columns.
//
// Not all the combination of gather and scatter are legal. For example, if A is
// row major and C/D is column major, we cannot gather A and scatter C/D at the
// same time.
//
// Also, we don't check the index value is legal and index array point is valid
// for the sake of the performance.
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <math.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <random>
#include <numeric>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::gemm::GemmCoord problem_size;
int index_size;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({248, 1024, 1024}),
index_size(240),
reference_check(true),
iterations(20) { }
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("index_size", index_size);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "36_gather_scatter_fusion example\n\n"
<< " This example uses the CUTLASS Library to fuse gather/scatter into GEMM\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --index_size=<int> size of N dimension index\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/36_gather_scatter_fusion/36_gather_scatter_fusion --m=1024 --n=512 --k=1024 \\\n"
<< " --index_size=128\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product();
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices.
// Column Major for Matrix A, B and C.
//
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 4
// 16, 8, 8 -> Turing
// 16, 8, 16 -> Ampere
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// Define the epilogue operation as LinearCombination. This is approximately equal to
//
// d_ij = alpha * sum_k(a_ik * b_kj) + c_ij
//
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This becomes
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 5;
// Ampere -> 4/5
// Turing -> 2
using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
8, /*alignmentA*/
8, /*alignmentB*/
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
false, /*GatherA*/
true, /*GatherB*/
true /*ScatterD*/
>;
int run(Options &options) {
// ================================================================================
// Initialization setup
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size = options.problem_size;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
options.index_size,
problem_size.k());
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
cutlass::HostTensor<int, LayoutOutput> tensor_indices(
{options.index_size, 1}); // <- Create scatter indices with dimensions val_len x 1
// <- Fill tensor_b_indices on host with unique random integers
std::vector<int> to_fill(problem_size.n()) ; // vector with ints.
std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n()
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
std::random_device make_seed;
std::mt19937 source_of_randomness(make_seed());
std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness);
}
memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_indices.sync_device();
tensor_c.sync_device();
tensor_d_scattered.sync_device();
// Initialize alpha/beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
// Split K dimension into 1 partitions
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size_real, // <- problem size of matrix multiplication
split_k_slices, // <- k-dimension split factor
{alpha, beta}, // <- alpha, beta
tensor_a.device_data(), // <- reference to matrix A on device
tensor_b.device_data(), // <- reference to matrix B on device
tensor_c.device_data(), // <- reference to matrix C on device
tensor_d_scattered.device_data(), // <- reference to matrix D on device
tensor_a.layout().capacity(problem_size.mk()),
tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.n())),
tensor_c.layout().capacity(problem_size.mn()),
tensor_d_scattered.layout().capacity(problem_size.mn()),
tensor_a.layout().stride(),
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d_scattered.layout().stride(),
nullptr, // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// CPU reference calculation
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_ref(problem_size.mn());
cutlass::reference::host::TensorFill(
tensor_d_ref.host_view()); // <- Fill matrix D on host with zeros
status = gemm_op();
cudaDeviceSynchronize();
CUTLASS_CHECK(status);
if (options.reference_check) {
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < options.index_size; ++j) {
int b_c_d_col = tensor_indices.at({j, 0});
for (int k = 0; k < problem_size.k(); ++k) {
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
}
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
}
}
// Copy output data from CUTLASS and reference kernel to host for comparison
tensor_d_scattered.sync_host();
bool passed = cutlass::reference::host::TensorEquals(
tensor_d_scattered.host_view(),
tensor_d_ref.host_view());
if (!passed) {
std::cout << "Failed!\n";
std::stringstream fname;
fname << "error_gather_GEMM_scatter_fusion.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A =\n" << tensor_a.host_view()
<< "\nB =\n" << tensor_b.host_view()
<< "\nindices =\n" << tensor_indices.host_view()
<< "\nC =\n" << tensor_c.host_view()
<< "\n\nReference =\n" << tensor_d_ref.host_view()
<< "\nComputed =\n" << tensor_d_scattered.host_view();
return -1;
} else {
std::cout << "Passed!\n";
}
}
// Result structure
Result result;
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return -1;
}
}
// Record an event at the start of a series of GEMMs
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return -1;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return -1;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
std::cout << "Runtime: " << result.runtime_ms << " ms\n";
std::cout << " GFLOPs: " << result.gflops << "\n";
return 0;
}
int main(int argc, const char ** argv) {
bool notSupported = false;
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << "\n";
return 0;
}
if (!options.valid()) {
std::cerr << "Invalid problem." << "\n";
return -1;
}
return run(options);
}