-
Notifications
You must be signed in to change notification settings - Fork 0
/
functional.h
627 lines (538 loc) · 15.1 KB
/
functional.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Define basic numeric operators
This is inspired by the Standard Library's <functional> header.
*/
/*
Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain
existing integrations of CUTLASS require C++11 host compilers.
Until this requirement can be lifted, certain headers with this annotation are required
to remain consistent with C++11 syntax.
C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "cutlass/tfloat32.h"
#include "cutlass/bfloat16.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
#ifdef _MSC_VER
// Provides support for alternate operators such as 'and', 'or', ...
#include <iso646.h>
#endif // _MSC_VER
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Absolute value functor.
///
/// The primary template forwards to an unqualified `abs`, so it relies on name lookup finding
/// an `abs` overload suitable for `T`.
template <typename T>
struct absolute_value_op {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs) const {
    return abs(lhs);
  }
};

/// Specialization for `float`: uses `fabs`, a floating-point absolute value.
template <>
struct absolute_value_op<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float lhs) const { return fabs(lhs); }
};

/// Specialization for `double`: like the `float` case, `fabs` guarantees a floating-point
/// absolute value. Unqualified `abs(double)` may otherwise resolve to an unsuitable overload
/// (e.g. the C library's `int abs(int)` on the host, which truncates the operand).
template <>
struct absolute_value_op<double> {
  CUTLASS_HOST_DEVICE
  double operator()(double lhs) const { return fabs(lhs); }
};
/// Addition functor: yields `lhs + rhs`, computed with `operator+=` on a by-value copy of `lhs`.
template <typename T>
struct plus {
  CUTLASS_HOST_DEVICE
  T operator()(T sum, T const &addend) const {
    sum += addend;
    return sum;
  }
};

/// Subtraction functor: yields `lhs - rhs`, computed with `operator-=` on a by-value copy of `lhs`.
template <typename T>
struct minus {
  CUTLASS_HOST_DEVICE
  T operator()(T difference, T const &subtrahend) const {
    difference -= subtrahend;
    return difference;
  }
};

/// Multiplication functor: yields `lhs * rhs`, computed with `operator*=` on a by-value copy of `lhs`.
template <typename T>
struct multiplies {
  CUTLASS_HOST_DEVICE
  T operator()(T product, T const &factor) const {
    product *= factor;
    return product;
  }
};
/// Multiplies its operand by a scaling factor fixed at construction.
///
/// Note: the constructor takes `float`, so the stored factor carries at most `float`
/// precision even when `T` is wider (e.g. `double`).
template <typename T>
struct scale {
  T const scaling_factor_;

  CUTLASS_HOST_DEVICE
  scale(float scaling_factor) : scaling_factor_(scaling_factor) {
  }

  /// Returns `rhs * scaling_factor_`. Annotated CUTLASS_HOST_DEVICE like every other functor
  /// in this header, so it can be invoked from device code as well as host code.
  CUTLASS_HOST_DEVICE
  T operator()(T const &rhs) const {
    T result = rhs * scaling_factor_;
    return result;
  }
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
// Explicit specializations of plus / minus / multiplies for the CUDA fp16 types.
// They call the fp16 intrinsics directly, so they remain usable even when the
// built-in operators are disabled via __CUDA_NO_HALF_OPERATORS__ or
// __CUDA_NO_HALF2_OPERATORS__.

/// Packed addition of two half2 values.
template<>
struct plus<__half2> {
  CUTLASS_HOST_DEVICE
  __half2 operator()(__half2 a, __half2 const &b) const {
    return __hadd2(a, b);
  }
};

/// Scalar half addition.
template<>
struct plus<__half> {
  CUTLASS_HOST_DEVICE
  __half operator()(__half a, __half const &b) const {
    return __hadd(a, b);
  }
};

/// Packed subtraction of two half2 values.
template<>
struct minus<__half2> {
  CUTLASS_HOST_DEVICE
  __half2 operator()(__half2 a, __half2 const &b) const {
    return __hsub2(a, b);
  }
};

/// Scalar half subtraction.
template<>
struct minus<__half> {
  CUTLASS_HOST_DEVICE
  __half operator()(__half a, __half const &b) const {
    return __hsub(a, b);
  }
};

/// Packed multiplication of two half2 values.
template<>
struct multiplies<__half2> {
  CUTLASS_HOST_DEVICE
  __half2 operator()(__half2 a, __half2 const &b) const {
    return __hmul2(a, b);
  }
};

/// Scalar half multiplication.
template<>
struct multiplies<__half> {
  CUTLASS_HOST_DEVICE
  __half operator()(__half a, __half const &b) const {
    return __hmul(a, b);
  }
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
/// Squares with optional conversion
template <typename T, typename Output = T>
struct square {
CUTLASS_HOST_DEVICE
Output operator()(T lhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs);
return mul_op(y, y);
}
};
/// Returns the magnitude squared of an element.
template <typename T, typename Output = T>
struct magnitude_squared {
CUTLASS_HOST_DEVICE
Output operator()(T lhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs);
return mul_op(y, y);
}
};
/// Computes the square of a difference with optional conversion
template <typename T, typename Output = T>
struct square_difference {
CUTLASS_HOST_DEVICE
Output operator()(T lhs, T rhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs) - Output(rhs);
return mul_op(y, y);
}
};
/// Computes the square of a difference with optional conversion
template <typename T, typename Output = T>
struct magnitude_squared_difference {
CUTLASS_HOST_DEVICE
Output operator()(T lhs, T rhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs) - Output(rhs);
return mul_op(y, y);
}
};
/// Division functor: yields `lhs / rhs`, computed with `operator/=` on a by-value copy of `lhs`.
template <typename T>
struct divides {
  CUTLASS_HOST_DEVICE
  T operator()(T quotient, T const &divisor) const {
    quotient /= divisor;
    return quotient;
  }
};

/// Unary negation functor: yields `-x`.
template <typename T>
struct negate {
  CUTLASS_HOST_DEVICE
  T operator()(T value) const {
    return -value;
  }
};
/// Predicate: true when `lhs >= rhs`.
template <typename T>
struct greater_equal {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &a, T const &b) const { return a >= b; }
};

/// Predicate: true when `lhs > rhs`.
template <typename T>
struct greater {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &a, T const &b) const { return a > b; }
};

/// Predicate: true when `lhs <= rhs`.
template <typename T>
struct less_equal {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &a, T const &b) const { return a <= b; }
};

/// Predicate: true when `lhs < rhs`.
template <typename T>
struct less {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &a, T const &b) const { return a < b; }
};
template <typename T, bool PropogateNaN = false>
struct maximum {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (lhs < rhs ? rhs : lhs);
}
};
// Maximum with nan propogation
// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
template <typename T>
struct maximum<T, true> {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
#if defined(__CUDA_ARCH__)
return lhs > rhs or isnan(lhs) ? lhs : rhs;
#else
return lhs > rhs or std::isnan(lhs) ? lhs : rhs;
#endif
}
};
template <>
struct maximum<float, false> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fmaxf(lhs, rhs);
}
};
template <>
struct maximum<float, true> {
CUTLASS_HOST_DEVICE
float operator()(float const lhs, float const rhs) const {
float res;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs));
#elif defined(__CUDA_ARCH__)
res = lhs > rhs or isnan(lhs) ? lhs : rhs;
#else
res = lhs > rhs or std::isnan(lhs) ? lhs : rhs;
#endif
return res;
}
};
template <typename T>
using maximum_with_nan_propogation = maximum<T, true>;
template <typename T, bool PropogateNaN = false>
struct minimum{
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
return (rhs < lhs ? rhs : lhs);
}
};
template <typename T>
struct minimum<T, true> {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
#if defined(__CUDA_ARCH__)
return lhs < rhs or isnan(lhs) ? lhs : rhs;
#else
return lhs < rhs or std::isnan(lhs) ? lhs : rhs;
#endif
}
};
template <>
struct minimum<float, false> {
CUTLASS_HOST_DEVICE
float operator()(float const &lhs, float const &rhs) const {
return fminf(lhs, rhs);
}
};
template <typename T, bool PropogateNaN = false>
struct maximum_absolute_value {
CUTLASS_HOST_DEVICE
float operator()(T const &lhs, T const &rhs) const {
absolute_value_op<T> abs_op;
maximum<T, PropogateNaN> max_op;
return max_op(abs_op(lhs), abs_op(rhs));
}
};
// assumes the left operand is already an absolute value
template <typename T, bool PropogateNaN = false>
struct maximum_absolute_value_reduction {
CUTLASS_HOST_DEVICE
float operator()(T const &lhs, T const &rhs) const {
absolute_value_op<T> abs_op;
maximum<T, PropogateNaN> max_op;
return max_op(lhs, abs_op(rhs));
}
};
/// Multiply-add: converts `a` and `b` to `C`, then returns `C(a) * C(b) + c`.
/// Whether this is emitted as a single fused instruction is left to the compiler.
template <typename A, typename B = A, typename C = A>
struct multiply_add {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    C const product = C(a) * C(b);
    return product + c;
  }
};
/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add_relu0 {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
maximum<C> mx;
return mx(C(a) * C(b) + c, C(0));
}
};
/// Bitwise AND of `a` and `b`, then add `c`: returns `(a & b) + c`.
template <typename T>
struct and_add {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b, T const &c) const {
    T const masked = a & b;
    return masked + c;
  }
};

/// Bitwise XOR of `a` and `b`, then add `c`: returns `(a ^ b) + c`.
template <typename T>
struct xor_add {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b, T const &c) const {
    T const toggled = a ^ b;
    return toggled + c;
  }
};
/// Conjugate functor. The primary template returns its operand unchanged (the conjugate of
/// a real value is itself); presumably complex types specialize this elsewhere — not visible here.
template <typename T>
struct conjugate {
  CUTLASS_HOST_DEVICE
  T operator()(T const &value) const {
    T const unchanged = value;
    return unchanged;
  }
};
/// Returns its first argument and ignores any further arguments.
///
/// The trailing arguments are taken as a template parameter pack. The previous form
/// `T const &...` declared a C-style ellipsis, not a pack. Passing non-trivial class types
/// through an ellipsis is only conditionally supported, and clang rejects it.
/// A pack accepts arguments of any type with no such restriction.
template <typename T>
struct first {
  template <typename... Rest>
  CUTLASS_HOST_DEVICE
  T operator()(T const &value, Rest const &...) const {
    return value;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Logical AND: returns `T(1)` when both operands are truthy, otherwise a value-initialized `T`.
template <typename T>
struct logical_and {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    if (a && b) {
      return T(1);
    }
    return T();
  }
};

/// Logical OR: returns `T(1)` when either operand is truthy, otherwise a value-initialized `T`.
template <typename T>
struct logical_or {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    if (a || b) {
      return T(1);
    }
    return T();
  }
};

/// Logical NOT: returns `!operand` converted to `T`.
template <typename T>
struct logical_not {
  CUTLASS_HOST_DEVICE
  T operator()(T const &operand) const {
    return T(!operand);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Bitwise AND of two operands.
template <typename T>
struct bit_and {
  CUTLASS_HOST_DEVICE
  T operator()(T const &x, T const &y) const { return x & y; }
};

/// Bitwise OR of two operands.
template <typename T>
struct bit_or {
  CUTLASS_HOST_DEVICE
  T operator()(T const &x, T const &y) const { return x | y; }
};

/// Bitwise complement of one operand (result converted back to `T`).
template <typename T>
struct bit_not {
  CUTLASS_HOST_DEVICE
  T operator()(T const &x) const { return ~x; }
};

/// Bitwise XOR of two operands.
template <typename T>
struct bit_xor {
  CUTLASS_HOST_DEVICE
  T operator()(T const &x, T const &y) const { return x ^ y; }
};
//////////////////////////////////////////////////////////////////////////////////////////////////
/// Atomic reductions

/// Atomically adds `data` to `*ptr` using CUDA's `atomicAdd`.
/// In host compilation (no `__CUDA_ARCH__`) the call is a no-op.
/// `operator()` is `const` (it keeps no state), so it can be invoked through a const
/// functor, consistent with atomic_maximum.
template <typename T>
struct atomic_add
{
  CUTLASS_DEVICE
  void operator()(T *ptr, const T &data) const
  {
#if defined(__CUDA_ARCH__)
    atomicAdd(ptr, data);
#else
    // Host path: nothing to do; silence unused-parameter warnings like the specializations do.
    CUTLASS_UNUSED(ptr);
    CUTLASS_UNUSED(data);
#endif
  }
};
/// Atomic add for `double`.
/// - Host compilation: no-op.
/// - sm_60+: native `atomicAdd(double*, double)`.
/// - Older architectures: emulated with an `atomicCAS` loop on the 64-bit bit pattern.
/// `operator()` is `const` for consistency with the other atomic functors.
template<>
struct atomic_add<double>
{
  CUTLASS_DEVICE
  void operator()(double *ptr, const double &data) const
  {
#if !defined(__CUDA_ARCH__)
    CUTLASS_UNUSED(ptr);
    CUTLASS_UNUSED(data);
#elif (__CUDA_ARCH__ >= 600)
    atomicAdd(ptr, data);
#else
    // Use CAS loop
    unsigned long long int* ptr_int = reinterpret_cast<unsigned long long int*>(ptr);
    unsigned long long int old_int = *ptr_int;
    unsigned long long int assumed_int;
    do {
      // Compute the new sum from the last observed value, then try to publish it; retry if
      // another thread changed *ptr in between (atomicCAS returns a different old value).
      double update = data + __longlong_as_double(old_int);
      assumed_int = old_int;
      old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update));
    } while (assumed_int != old_int);
#endif // (__CUDA_ARCH__ >= 600)
  }
};
/// Atomic add for packed `half2` values.
/// - Host compilation or architectures below sm_60: no-op.
/// - sm_60+: issues a `red.gpu.global.add.noftz.f16x2` reduction on `ptr`.
/// `operator()` is `const` for consistency with the other atomic functors.
template<>
struct atomic_add<half2>
{
  CUTLASS_DEVICE
  void operator()(half2 *ptr, const half2 &data) const
  {
#if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600))
    CUTLASS_UNUSED(ptr);
    CUTLASS_UNUSED(data);
#else
    // Vector-2 atomic reduction requires .target sm_60 or higher
    // The two packed halves are passed to the PTX instruction as one 32-bit register.
    uint32_t word = reinterpret_cast<const uint32_t&>(data);
    asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word));
#endif // (__CUDA_ARCH__ >= 600)
  }
};
/// Deprecated alias for atomic_add<T>; any use triggers the deprecation diagnostic
/// directing callers to atomic_add.
template <typename T>
using red [[deprecated("use atomic_add instead")]] = atomic_add<T>;
/// Atomic maximum: in device code, forwards to CUDA's `atomicMax(ptr, value)` and returns its
/// result. In host compilation it reports CUTLASS_NOT_IMPLEMENTED and returns 0.
template <typename T>
struct atomic_maximum {
  CUTLASS_DEVICE
  T operator()(T *ptr, T value) const {
#if !defined(__CUDA_ARCH__)
    CUTLASS_UNUSED(ptr);
    CUTLASS_UNUSED(value);
    CUTLASS_NOT_IMPLEMENTED();
    return 0;
#else
    return atomicMax(ptr, value);
#endif
  }
};
/// Atomic maximum for `float`, built from integer atomics on the float's bit pattern.
/// In host compilation it reports CUTLASS_NOT_IMPLEMENTED and returns 0.
template <>
struct atomic_maximum<float> {
  CUTLASS_DEVICE
  float operator()(float *ptr, float value) const {
#if defined(__CUDA_ARCH__)
    // Negative values: among IEEE floats with the sign bit set, a larger float has a smaller
    // unsigned bit pattern (and any non-negative float's pattern is smaller still), so an
    // unsigned atomicMin yields the float maximum.
    if (signbit(value)) {
      return __uint_as_float(atomicMin(reinterpret_cast<unsigned int*>(ptr), __float_as_uint(value)));
    }
    // Non-negative values: reading the bit patterns as signed ints preserves float ordering
    // against both non-negative and negative stored values, so a signed atomicMax works.
    return __int_as_float(atomicMax(reinterpret_cast<int*>(ptr), __float_as_int(value)));
#else
    CUTLASS_UNUSED(ptr);
    CUTLASS_UNUSED(value);
    CUTLASS_NOT_IMPLEMENTED();
    return 0;
#endif
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for nvcuda::wmma::fragment<Use, m, n, k, T, Layout>
//
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Element-wise addition of two WMMA fragments. Each fragment's storage is viewed as a flat
/// array of `element_type`, and plus<element_type> is applied to every element.
template<typename Use, int m, int n, int k, typename T, typename Layout>
struct plus<nvcuda::wmma::fragment<Use, m, n, k, T, Layout>>
{
  using Fragment = nvcuda::wmma::fragment<Use, m, n, k, T, Layout>;
  using ElementType = typename Fragment::element_type;

  CUTLASS_HOST_DEVICE
  Fragment operator()(Fragment const &lhs, Fragment const &rhs) const
  {
    plus<ElementType> add;
    Fragment sum;

    ElementType *out = reinterpret_cast<ElementType*>(&sum);
    ElementType const *a = reinterpret_cast<ElementType const*>(&lhs);
    ElementType const *b = reinterpret_cast<ElementType const*>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < Fragment::num_elements; ++idx) {
      out[idx] = add(a[idx], b[idx]);
    }
    return sum;
  }
};
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////