-
Notifications
You must be signed in to change notification settings - Fork 0
/
fmha_grouped.h
1042 lines (906 loc) · 36.6 KB
/
fmha_grouped.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped FMHA kernel
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
#include "gemm/mma_accum_lambda_iterator.h"
#include "epilogue/epilogue_rescale_output.h"
namespace {
/// Float "max" built from integer atomics on the value's bit pattern:
/// non-negative inputs go through a signed atomicMax, negative inputs through
/// an unsigned atomicMin. Whatever the atomic hands back is reinterpreted as a
/// float and returned.
/// Technique source: https://stackoverflow.com/a/51549250
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int* signed_view = reinterpret_cast<int*>(addr);
    const int observed = atomicMax(signed_view, __float_as_int(value));
    return __int_as_float(observed);
  }
  unsigned int* unsigned_view = reinterpret_cast<unsigned int*>(addr);
  const unsigned int observed = atomicMin(unsigned_view, __float_as_uint(value));
  return __uint_as_float(observed);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename MM0_, ///! Structure for computing P = Q @ K
typename MM1_, ///! Structure for computing O = P @ V
typename scalar_t_,
typename accum_t_,
typename output_t_,
typename output_accum_t_,
bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct FMHAGrouped {
public:
// Re-export of the template parameters under shorter names.
using MM0 = MM0_;
using MM1 = MM1_;
using scalar_t = scalar_t_;
using accum_t = accum_t_;
using output_t = output_t_;
using output_accum_t = output_accum_t_;

static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;

// A separate accumulator buffer is only needed when the output is not kept in
// registers AND the accumulation type differs from the final output type.
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
    !cutlass::platform::is_same<output_accum_t, output_t>::value;

// Parameters to satisfy BaseGrouped
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = accum_t;
using LayoutA = typename MM0::LayoutA;
// LayoutB / LayoutC must name layouts, not element types (the same nested
// names are used for LayoutK / LayoutO below).
using LayoutB = typename MM0::LayoutB;
using LayoutC = typename MM1::LayoutC;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static int const kAlignmentA = MM0::kAlignmentA;
static int const kAlignmentB = MM0::kAlignmentB;
static int const kAlignmentC = 1;
using Mma = typename MM1::Mma;
using EpilogueOutputOp = typename MM1::EpilogueOutputOp;
using ThreadblockSwizzle = void;
using Operator = typename MM1::Operator;
using WarpShape = typename MM1::WarpShape;
using InstructionShape = typename MM1::InstructionShape;

// Attention-specific names for the operand element types.
using ElementQ = scalar_t;
using ElementK = scalar_t;
using ElementP = accum_t;
using ElementV = scalar_t;
using ElementO = output_t;
using ElementOAccum = output_accum_t;
using ElementAccumulator = accum_t;

// Attention-specific names for the operand layouts.
using LayoutQ = typename MM0::LayoutA;
using LayoutK = typename MM0::LayoutB;
using LayoutP = typename MM0::LayoutC;
using LayoutV = typename MM1::LayoutB;
using LayoutO = typename MM1::LayoutC;

// V is preloaded only for compute capability >= 80 with 16-bit V elements.
static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 &&
                               cutlass::sizeof_bits<ElementV>::value == 16);

static int const kAlignmentQ = MM0::kAlignmentA;
static int const kAlignmentK = MM0::kAlignmentB;
static int const kAlignmentV = 1;

using ThreadblockShape = typename MM0::ThreadblockShape;

// Tile extents along the query (M) and key (N) dimensions of the first matmul.
static int const kQueriesPerBlock = ThreadblockShape::kM;
static int const kKeysPerBlock = ThreadblockShape::kN;

static constexpr bool kSupportsDropout = false;
static constexpr bool kSupportsBias = false;

/// Warp count (concept: GemmShape)
using WarpCount = typename MM1::WarpCount;
static int const kThreadsPerWarp = 32;
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;

static constexpr int kNumWarpsPerBlock =
    kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);

using ProblemVisitor = FMHAGroupedProblemVisitor<
                          ThreadblockShape,
                          kGroupScheduleMode,
                          kThreadCount,
                          kThreadCount>;
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord *problem_sizes0;
GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementV ** ptr_V;
ElementO ** ptr_O;
ElementOAccum ** ptr_O_accum;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;
// Scale
ElementAccumulator scale;
// Whether causal masking is to be performed
bool causal;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_O_accum(nullptr),
ldq(nullptr),
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
scale(0),
causal(false),
host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord *problem_sizes0,
GemmCoord *problem_sizes1,
int problem_count,
int threadblock_count,
ElementQ ** ptr_Q,
ElementK ** ptr_K,
ElementP ** ptr_P,
ElementV ** ptr_V,
ElementO ** ptr_O,
ElementOAccum ** ptr_O_accum,
typename LayoutQ::Stride::LongIndex *ldq,
typename LayoutK::Stride::LongIndex *ldk,
typename LayoutP::Stride::LongIndex *ldp,
typename LayoutV::Stride::LongIndex *ldv,
typename LayoutO::Stride::LongIndex *ldo,
bool causal,
ElementAccumulator scale,
GemmCoord *host_problem_sizes=nullptr
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_P(ptr_P),
ptr_V(ptr_V),
ptr_O(ptr_O),
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O),
ldq(ldq),
ldk(ldk),
ldv(ldv),
ldo(ldo),
causal(causal),
scale(scale),
host_problem_sizes(host_problem_sizes)
{
}
bool __host__ check_supported() {
CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ);
CHECK_ALIGNED_PTR(ptr_K, kAlignmentK);
CHECK_ALIGNED_PTR(ptr_V, kAlignmentV);
XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned");
return true;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementV ** ptr_V;
ElementO ** ptr_O;
ElementOAccum ** ptr_O_accum;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;
ElementAccumulator scale;
bool causal;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_O_accum(nullptr),
ldq(nullptr),
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
causal(false),
scale(0)
{ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args,
void *workspace = nullptr,
int tile_count = 0):
problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
ptr_V(args.ptr_V),
ptr_O(args.ptr_O),
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O),
ldq(args.ldq),
ldk(args.ldk),
ldv(args.ldv),
ldo(args.ldo),
causal(args.causal),
scale(args.scale)
{
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr,
int tile_count = 0) {
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0,
args.problem_sizes1,
args.problem_count,
workspace, tile_count);
threadblock_count = args.threadblock_count;
ptr_Q = args.ptr_Q;
ptr_K = args.ptr_K;
ptr_P = args.ptr_P;
ptr_V = args.ptr_V;
ptr_O = args.ptr_O;
ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O;
ldq = args.ldq;
ldk = args.ldk;
ldv = args.ldv;
ldo = args.ldo;
causal = args.causal;
scale = args.scale;
}
};
// Shared storage - depends on kernel params
// Per-query-row softmax bookkeeping; every array except addition_storage has
// kQueriesPerBlock entries. Both SharedStorage variants inherit from this.
struct ScalingCoefs {
// Compared against `mi` in iterative_softmax to detect a changed row maximum.
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
// Initialised to 0 per tile; multiplied by the rescale factor when the maximum changes.
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
// Per-row maximum, updated through atomicMaxFloat in iterative_softmax.
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
// Per-row factor (initialised to 1) applied to previously accumulated output.
cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
// Passed to iterative_softmax; its use lies outside the visible part of this file.
cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};
// Shared-storage variant chosen when kKeepOutputInRF is true (see the
// SharedStorage alias). The epilogue storage is a member of the same union as
// the MM0 storage and the post-MM0 storage, so the three overlay each other.
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::Mma::SharedStorage mm1;
};
// Anonymous union: mm0, after_mm0 and epilogue share the same bytes.
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
// Accessor with the same name in both variants, so operator() does not need
// to know where the epilogue storage lives.
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
// Shared-storage variant chosen when kKeepOutputInRF is false (see the
// SharedStorage alias). Here the epilogue storage is a field of
// SharedStorageAfterMM0, i.e. it sits next to `si` and `mm1` rather than
// overlaying them; only mm0 and after_mm0 share bytes.
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
// Anonymous union: mm0 and after_mm0 share the same bytes.
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
// Accessor with the same name in both variants, so operator() does not need
// to know where the epilogue storage lives.
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
// Compile-time choice of shared-storage layout: SharedStorageEpilogueAtEnd when
// kKeepOutputInRF is true, SharedStorageEpilogueInLoop otherwise.
using SharedStorage = typename cutlass::platform::conditional<
kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
private:
// Per-tile geometry helpers. Every member is a stateless function of the
// threadblock index and the extents of problem 0.
struct TileParams {

  /// Index of the first query row owned by threadblock `threadblock_idx`.
  CUTLASS_HOST_DEVICE
  static int query_start(int threadblock_idx) {
    return kQueriesPerBlock * threadblock_idx;
  }

  /// True when the tile's first query row lies inside the M extent of
  /// problem 0, i.e. the threadblock has at least one query to process.
  CUTLASS_HOST_DEVICE
  static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) {
    const int first_row = query_start(threadblock_idx);
    return problem_size0.m() > first_row;
  }

  /// Number of query rows from the tile's first row to the end of problem 0
  /// (not clamped to kQueriesPerBlock).
  CUTLASS_HOST_DEVICE
  static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) {
    const int first_row = query_start(threadblock_idx);
    return problem_size0.m() - first_row;
  }

  /// Number of keys this tile visits: the N extent of problem 0, limited under
  /// causal masking to the index one past the tile's last query row.
  CUTLASS_HOST_DEVICE
  static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) {
    int key_count = problem_size0.n();
    if (!causal) {
      return key_count;
    }
    return cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), key_count);
  }
};
public:
//
// Methods
//
/// Device-side constructor; the kernel object carries no state to set up.
CUTLASS_DEVICE FMHAGrouped() {}
/// Reports whether the kernel can handle `problem_size`.
/// No check is performed at present: the result is always Status::kSuccess.
static Status can_implement([[maybe_unused]] cutlass::gemm::GemmCoord const & problem_size) {
  return Status::kSuccess;
}

/// Reports whether the kernel can handle `args`.
/// No check is performed at present: the result is always Status::kSuccess.
static Status can_implement([[maybe_unused]] Arguments const &args) {
  return Status::kSuccess;
}
/// Thread index within the block (threadIdx.x), narrowed to int16_t.
static CUTLASS_DEVICE int16_t thread_id() {
  return static_cast<int16_t>(threadIdx.x);
}
/// Warp index within the block: threadIdx.x divided by kThreadsPerWarp,
/// narrowed to int8_t.
static CUTLASS_DEVICE int8_t warp_id() {
  return static_cast<int8_t>(threadIdx.x / kThreadsPerWarp);
}
/// Lane index within the warp: threadIdx.x modulo kThreadsPerWarp,
/// narrowed to int8_t.
static CUTLASS_DEVICE int8_t lane_id() {
  return static_cast<int8_t>(threadIdx.x % kThreadsPerWarp);
}
/// Executes one GEMM
///
/// Kernel body. The threadblock keeps asking the problem visitor for tiles
/// until none are left. For each tile it walks the keys in chunks of
/// kKeysPerBlock and, per chunk:
///   1. multiplies the query tile by the key chunk (MM0) into registers,
///   2. applies the causal mask on the last chunk when requested,
///   3. runs iterative_softmax on the register accumulator,
///   4. stores the result to shared memory and multiplies it by V (MM1).
/// The output is written by an epilogue either inside the key loop
/// (kKeepOutputInRF == false) or once after it (kKeepOutputInRF == true).
///
/// \param params          kernel parameters (pointer/stride arrays are indexed by problem index)
/// \param shared_storage  shared-memory workspace for this threadblock
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
  auto& m_prime = shared_storage.m_prime;
  auto& s_prime = shared_storage.s_prime;
  [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
  auto& mi = shared_storage.mi;
  auto& out_rescale = shared_storage.out_rescale;

  ProblemVisitor problem_visitor(
    params.problem_visitor,
    shared_storage.problem_visitor,
    blockIdx.x);

  // Outer 'persistent' loop to iterate over tiles
  while (problem_visitor.next_tile()) {

    GemmCoord problem_size0 = problem_visitor.problem_size0();
    GemmCoord problem_size1 = problem_visitor.problem_size1();
    const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

    // Tiles that start beyond the last query row have nothing to do.
    if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
      problem_visitor.advance(gridDim.x);
      continue;
    }

    const int32_t problem_idx = problem_visitor.problem_index();

    // Reset the per-row softmax state for this tile (one thread per row).
    if (thread_id() < kQueriesPerBlock) {
      s_prime[thread_id()] = ElementAccumulator(0);
      out_rescale[thread_id()] = accum_t(1.0);
      m_prime[thread_id()] =
          -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
      mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
    }

    // Both output pointers are offset to this tile's first query row using ldo.
    ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
    ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];

    const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);

    // Factory for an iterator over the final output, starting at column `col`.
    auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
      using OutputTileIterator = typename MM1::OutputTileIterator;
      return OutputTileIterator(
          typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
          ptr_O,
          typename OutputTileIterator::TensorCoord{
              num_queries, problem_size1.n()},
          thread_id(),
          {0, col});
    };

    // Factory for an iterator over the accumulation buffer, starting at column `col`.
    auto createOutputAccumIter = [&](int col) ->
      typename MM1::OutputTileIteratorAccum {
        using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
        return OutputTileIteratorAccum(
            typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
            ptr_O_accum,
            typename OutputTileIteratorAccum::TensorCoord{
                num_queries, problem_size1.n()},
            thread_id(),
            {0, col});
      };

    typename MM1::Mma::FragmentC accum_o;
    accum_o.clear();

    const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);

    for (int32_t iter_key_start = 0; iter_key_start < num_keys;
         iter_key_start += kKeysPerBlock) {
      // Extents of the two matmuls for this key chunk (clamped at the edges).
      int32_t problem_size_0_m =
          cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
      int32_t problem_size_0_n = cutlass::fast_min(
          (int32_t)kKeysPerBlock, num_keys - iter_key_start);
      int32_t const& problem_size_0_k = problem_size0.k();
      int32_t const& problem_size_1_n = problem_size1.n();
      int32_t const& problem_size_1_k = problem_size_0_n;

      // Runs the MM1 prologue for V at column block `blockN`.
      auto prologueV = [&](int blockN) {
        typename MM1::Mma::IteratorB iterator_V(
            typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
            params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
            {problem_size_1_k, problem_size_1_n},
            thread_id(),
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

        MM1::Mma::prologue(
            shared_storage.after_mm0.mm1,
            iterator_V,
            thread_id(),
            problem_size_1_k);
      };

      __syncthreads(); // Need to have shared memory initialized, and `m_prime`
                       // updated from end of prev iter

      //
      // MATMUL: Q.K_t
      //
      // Computes the block-matrix product of:
      // (a) query[query_start:query_end, :]
      // with
      // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
      // and stores that into `shared_storage.si`
      //

      ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];

      // Construct iterators to A and B operands
      typename MM0::IteratorA iterator_A(
        typename MM0::IteratorA::Params(
            typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
        ptr_Q,
        {problem_size_0_m, problem_size_0_k},
        thread_id(),
        {0, 0});

      typename MM0::IteratorB iterator_B(
          typename MM0::IteratorB::Params(
              typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
          params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
          {problem_size_0_k, problem_size_0_n},
          thread_id(),
          {0, 0});

      // Construct thread-scoped matrix multiply
      typename MM0::Mma mma(
          shared_storage.mm0, thread_id(), warp_id(), lane_id());

      typename MM0::Mma::FragmentC accum;

      accum.clear();

      auto gemm_k_iterations =
          (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
      __syncthreads();

      // Either start loading V for MM1 now, or drain outstanding async copies.
      if (kPreloadV) {
        prologueV(0);
      } else {
        MM1::Mma::drain_cp_asyncs();
      }

      typename MM0::Mma::Operator::IteratorC::TensorCoord
        iteratorC_tile_offset = {
            (warp_id() % MM0::Mma::WarpCount::kM),
            (warp_id() / MM0::Mma::WarpCount::kM)
          };

      // Mask out last if causal
      if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
        auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
            lane_id(), warp_id(), iteratorC_tile_offset);
        int32_t last_col;
        MM0::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) {
              // Last key column (relative to this chunk) visible to row accum_m.
              last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
            },
            [&](int accum_m, int accum_n, int idx) {
              if (accum_n > last_col) {
                accum[idx] =
                    -cutlass::platform::numeric_limits<accum_t>::infinity();
              }
            },
            [&](int accum_m) {});
      }
      // Disabled variant that dispatched kIsFirst / kFullColumns at compile time:
      // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
      //       DISPATCH_BOOL(
      //           num_keys - iter_key_start >= kKeysPerBlock,
      //           kFullColumns,
      //           ([&] {
      //             // Update `mi` from accum stored in registers
      //             // Also does accum[i] <- exp(accum[i] - mi)
      //             iterative_softmax<
      //                 typename MM0::Mma::Operator::IteratorC,
      //                 kFullColumns,
      //                 kIsFirst>(
      //                 accum_o,
      //                 accum,
      //                 mi,
      //                 m_prime,
      //                 s_prime,
      //                 lane_id(),
      //                 thread_id(),
      //                 warp_id(),
      //                 num_keys - iter_key_start,
      //                 iteratorC_tile_offset,
      //                 kSupportsBias ? 1.0f : params.scale);
      //           }));
      //     }));

      // Update `mi` from accum stored in registers
      // Also does accum[i] <- exp(accum[i] - mi)
      iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
          accum_o,
          accum,
          mi,
          m_prime,
          s_prime,
          out_rescale,
          shared_storage.addition_storage,
          lane_id(),
          thread_id(),
          warp_id(),
          num_keys - iter_key_start,
          iter_key_start == 0,
          iteratorC_tile_offset,
          kSupportsBias ? 1.0f : params.scale);

      // Output results to shared-memory
      int warp_idx_mn_0 = warp_id() %
          (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
      auto output_tile_coords = cutlass::MatrixCoord{
          warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
          warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

      MM0::B2bGemm::accumToSmem(
          shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);

      __syncthreads();

      //
      // MATMUL: Attn . V
      // Run the matmul `attn @ V` for a block of attn and V.
      // `attn` is read from shared memory (in `shared_storage_si`)
      // `V` is read from global memory (with iterator_B)
      //

      // With the output kept in registers a single pass is made; otherwise the
      // N dimension of GEMM1 is covered in MM1::ThreadblockShape::kN-wide blocks.
      const int64_t nBlockN = kKeepOutputInRF ? 1
                                              : ceil_div(
                                                    (int64_t)problem_size_1_n,
                                                    int64_t(MM1::ThreadblockShape::kN));

      // Iterate over the N dimension of GEMM1
      for (int blockN = 0; blockN < nBlockN; ++blockN) {
        int gemm_k_iterations =
            (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add and store it in accum
        // (in registers)
        if (!kPreloadV) {
          __syncthreads(); // we share shmem between mma and epilogue
        }

        typename MM1::Mma::IteratorB iterator_V(
          typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
          params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
          {problem_size_1_k, problem_size_1_n},
          thread_id(),
          cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

        typename MM1::Mma mma_pv(
          // operand A: Pij_dropped in shared memory
          shared_storage.after_mm0.si.accum_ref(),
          // operand B: shared memory staging area for Vj, which is loaded
          // from global memory
          shared_storage.after_mm0.mm1.operand_B_ref(),
          (int)thread_id(),
          (int)warp_id(),
          (int)lane_id());

        mma_pv.set_prologue_done(kPreloadV);
        if (!kKeepOutputInRF) {
          // Each column block starts from zero; the epilogue below folds in
          // the previously written partial result.
          accum_o.clear();
        }
        mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
        __syncthreads();

        if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
          prologueV(blockN + 1);
        }

        if (!kKeepOutputInRF) {
          MM1::Mma::drain_cp_asyncs();
          DISPATCH_BOOL(
              iter_key_start == 0, kIsFirst, ([&] {
                DISPATCH_BOOL(
                    (iter_key_start + kKeysPerBlock) >= num_keys,
                    kIsLast,
                    ([&] {
                      using DefaultEpilogue = typename MM1::DefaultEpilogue;
                      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
                      using ElementCompute = typename DefaultOp::ElementCompute;
                      // On the last key chunk the result is written as output_t,
                      // otherwise as output_accum_t.
                      using EpilogueOutputOp = typename cutlass::epilogue::
                          thread::MemoryEfficientAttentionNormalize<
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  output_t,
                                  output_accum_t>::type,
                              output_accum_t,
                              DefaultOp::kCount,
                              typename DefaultOp::ElementAccumulator,
                              output_accum_t,
                              kIsFirst,
                              kIsLast,
                              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                      using Epilogue = typename cutlass::epilogue::threadblock::
                          EpiloguePipelined<
                              typename DefaultEpilogue::Shape,
                              typename MM1::Mma::Operator,
                              DefaultEpilogue::kPartitionsK,
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  typename MM1::OutputTileIterator,
                                  typename MM1::OutputTileIteratorAccum>::type,
                              typename DefaultEpilogue::
                                  AccumulatorFragmentIterator,
                              typename DefaultEpilogue::WarpTileIterator,
                              typename DefaultEpilogue::SharedLoadIterator,
                              EpilogueOutputOp,
                              typename DefaultEpilogue::Padding,
                              DefaultEpilogue::kFragmentsPerIteration,
                              true, // IterationsUnroll
                              typename MM1::OutputTileIteratorAccum // Read
                                                                    // iterator
                              >;

                      int col = blockN * MM1::Mma::Shape::kN;
                      auto source_iter = createOutputAccumIter(col);
                      auto dest_iter = gemm_kernel_utils::call_conditional<
                          kIsLast,
                          decltype(createOutputIter),
                          decltype(createOutputAccumIter)>::
                          apply(createOutputIter, createOutputAccumIter, col);
                      EpilogueOutputOp rescale(s_prime, out_rescale);
                      Epilogue epilogue(
                          shared_storage.epilogue_shared_storage(),
                          thread_id(),
                          warp_id(),
                          lane_id());
                      epilogue(rescale, dest_iter, accum_o, source_iter);
                    }));
              }));
          // Already inside the !kKeepOutputInRF branch, so sync unconditionally.
          __syncthreads();
        }
      }
      __syncthreads(); // we modify `m_prime` after
    }

    if (kKeepOutputInRF) {
      // Single epilogue after all key chunks: accum_o already holds the full
      // (rescaled) accumulation, so this pass is both first and last.
      const bool kIsFirst = true;
      const bool kIsLast = true;
      using DefaultEpilogue = typename MM1::DefaultEpilogue;
      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
      using ElementCompute = typename DefaultOp::ElementCompute;
      using EpilogueOutputOp =
          typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
              output_t, // output
              output_accum_t, // source
              DefaultOp::kCount,
              typename DefaultOp::ElementAccumulator, // accum
              output_accum_t, // compute
              kIsFirst,
              kIsLast,
              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
      using Epilogue =
          typename cutlass::epilogue::threadblock::EpiloguePipelined<
              typename DefaultEpilogue::Shape,
              typename MM1::Mma::Operator,
              DefaultEpilogue::kPartitionsK,
              typename MM1::OutputTileIterator, // destination
              typename DefaultEpilogue::AccumulatorFragmentIterator,
              typename DefaultEpilogue::WarpTileIterator,
              typename DefaultEpilogue::SharedLoadIterator,
              EpilogueOutputOp,
              typename DefaultEpilogue::Padding,
              DefaultEpilogue::kFragmentsPerIteration,
              true, // IterationsUnroll
              typename MM1::OutputTileIteratorAccum // source tile
              >;
      auto dest_iter = createOutputIter(0);
      EpilogueOutputOp rescale(s_prime, out_rescale);
      Epilogue epilogue(
          shared_storage.epilogue_shared_storage(),
          thread_id(),
          warp_id(),
          lane_id());
      MM1::Mma::drain_cp_asyncs();
      epilogue(rescale, dest_iter, accum_o);
    }

    // Next tile
    problem_visitor.advance(gridDim.x);
    __syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
  }
}
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
*/
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
WarpIteratorC,
accum_t,
kThreadsPerWarp>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max);
});
}
// Make sure we all share the update values for `mi`
__syncthreads();
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled, it's possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,