-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathXTensor2.cu
4097 lines (3544 loc) · 134 KB
/
XTensor2.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* NiuTrans - an open-source MT toolkit
* Copyright (C) 2017, Natural Language Processing Lab. All rights reserved.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
*/
/*
* $Id:
* XTensor; XTensor.h
* implementation of tensors used in this work. It is the basis of XMatrix and
* XVector
*
* $Version:
* 0.1.0
*
* $Created by:
* XIAO Tong (email: [email protected]) 2017-07-31
* $Update by:
* LI Yinqiao (email: [email protected]) 2017-11-18 bug fixes
*
*/
#define WORKERSNUM 256
#include "XTensor.h"
#include "XTensor.cuh"
#include "XDevice.h"
#include "XHeap.h"
#ifdef USE_CUDA
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
/************************************************************
* basic kernels
*/
/*
matrix multiplication via cuda version BLAS
c = alpha * op(a) * op(b) + beta * c
where a (na * ma), b (nb * mb) and c (nc * mc) are row-major matrices
(row lengths ma, mb and mc are used as the leading dimensions below)
>> handle - cublas handle
>> a - matrix a (device memory)
>> transposedA - indicates whether a is transposed (X_TRANS) or not (X_NOTRANS)
>> dataTypeA - data type of a
>> b - matrix b (device memory)
>> transposedB - indicates whether b is transposed (X_TRANS) or not (X_NOTRANS)
>> dataTypeB - data type of b
>> c - matrix c (device memory)
>> dataTypeC - data type of c
>> na, ma - number of rows and columns of a
>> nb, mb - number of rows and columns of b
>> nc, mc - number of rows and columns of c
>> alpha - scalar
>> beta - scalar
*/
void CudaBLASMatrixMUL(cublasHandle_t * handle,
                       void * a, MATRIX_TRANS_TYPE transposedA, MATRIX_DATA_TYPE dataTypeA,
                       void * b, MATRIX_TRANS_TYPE transposedB, MATRIX_DATA_TYPE dataTypeB,
                       void * c, MATRIX_DATA_TYPE dataTypeC,
                       int na, int ma, int nb, int mb, int nc, int mc,
                       DTYPE alpha, DTYPE beta)
{
    /*
    matrix-matrix multiplication
    For row-major matrices (as in c/c++), the trick used here is (AB)^T = B^T * A^T:
    cuBLAS is column-major, so it reads every row-major buffer as the transpose of the
    matrix. We pass b before a and let cuBLAS compute c^T (mc * nc, column-major),
    which is exactly c in row-major order. The inner dimension k is ma when a is not
    transposed and na otherwise. The leading dimension of a buffer depends only on
    how it is stored (its row length: ma, mb or mc), not on the transposition.
    */
    if (dataTypeA == X_DOUBLE && dataTypeB == X_DOUBLE && dataTypeC == X_DOUBLE) {
        double alpha2 = (double)alpha;
        double beta2 = (double)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasDgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double*)b, mb, (const double*)a, ma, &beta2, (double*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasDgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double*)b, mb, (const double*)a, ma, &beta2, (double*)c, mc);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasDgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double*)b, mb, (const double*)a, ma, &beta2, (double*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            /* ldb is mb (the row length of b), as in the other branches; nb was wrong here */
            cublasDgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double*)b, mb, (const double*)a, ma, &beta2, (double*)c, mc);
    }
    else if (dataTypeA == X_FLOAT && dataTypeB == X_FLOAT && dataTypeC == X_FLOAT) {
        float alpha2 = (float)alpha;
        float beta2 = (float)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasSgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float*)b, mb, (const float*)a, ma, &beta2, (float*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasSgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, (const float*)a, ma, &beta2, (float*)c, mc);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasSgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float*)b, mb, (const float*)a, ma, &beta2, (float*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasSgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, (const float*)a, ma, &beta2, (float*)c, mc);
    }
    else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
        /* FloatToFloat16 yields the 16-bit pattern, which is reinterpreted as __half */
        unsigned short alpha2 = FloatToFloat16(alpha);
        unsigned short beta2 = FloatToFloat16(beta);
        __half * alpha3 = (__half*)&alpha2;
        __half * beta3 = (__half*)&beta2;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasHgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, alpha3, (const __half*)b, mb, (const __half*)a, ma, beta3, (__half*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasHgemm(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, alpha3, (const __half*)b, mb, (const __half*)a, ma, beta3, (__half*)c, mc);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasHgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, alpha3, (const __half*)b, mb, (const __half*)a, ma, beta3, (__half*)c, mc);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasHgemm(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, alpha3, (const __half*)b, mb, (const __half*)a, ma, beta3, (__half*)c, mc);
    }
    else {
        ShowNiuTransErrors("Unsupported data type!");
    }
}
/*
batched matrix multiplication via cuda version BLAS
c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i] for i = 0 ... count - 1
where every a[i] (na * ma), b[i] (nb * mb) and c[i] (nc * mc) is a row-major matrix
>> handle - cublas handle
>> a - array of pointers to the matrices a[i] (cuBLAS requires this array to be in device memory)
>> transposedA - indicates whether the a[i] are transposed
>> dataTypeA - data type of the a[i]
>> b - array of pointers to the matrices b[i] (device memory)
>> transposedB - indicates whether the b[i] are transposed
>> dataTypeB - data type of the b[i]
>> c - array of pointers to the matrices c[i] (device memory)
>> dataTypeC - data type of the c[i]
>> count - number of products
>> na, ma, nb, mb, nc, mc - numbers of rows and columns of a[i], b[i] and c[i]
>> alpha - scalar
>> beta - scalar
*/
void CudaBLASMatrixMULBatched(cublasHandle_t * handle,
                              const void ** a, MATRIX_TRANS_TYPE transposedA, MATRIX_DATA_TYPE dataTypeA,
                              const void ** b, MATRIX_TRANS_TYPE transposedB, MATRIX_DATA_TYPE dataTypeB,
                              void ** c, MATRIX_DATA_TYPE dataTypeC,
                              int count, int na, int ma, int nb, int mb, int nc, int mc,
                              DTYPE alpha, DTYPE beta)
{
    /*
    matrix-matrix multiplication
    For row-major matrices (as in c/c++), the trick used here is (AB)^T = B^T * A^T:
    b is passed before a and cuBLAS (column-major) computes c^T, i.e., c in row-major
    order. Leading dimensions are the row lengths ma, mb and mc in every case.
    */
    if (dataTypeA == X_DOUBLE && dataTypeB == X_DOUBLE && dataTypeC == X_DOUBLE) {
        double alpha2 = (double)alpha;
        double beta2 = (double)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasDgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double**)b, mb, (const double**)a, ma, &beta2, (double**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasDgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double**)b, mb, (const double**)a, ma, &beta2, (double**)c, mc, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasDgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double**)b, mb, (const double**)a, ma, &beta2, (double**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            /* ldb is mb (the row length of b), as in the other branches; nb was wrong here */
            cublasDgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double**)b, mb, (const double**)a, ma, &beta2, (double**)c, mc, count);
    }
    else if (dataTypeA == X_FLOAT && dataTypeB == X_FLOAT && dataTypeC == X_FLOAT) {
        float alpha2 = (float)alpha;
        float beta2 = (float)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasSgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float**)b, mb, (const float**)a, ma, &beta2, (float**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasSgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float**)b, mb, (const float**)a, ma, &beta2, (float**)c, mc, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasSgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float**)b, mb, (const float**)a, ma, &beta2, (float**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasSgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float**)b, mb, (const float**)a, ma, &beta2, (float**)c, mc, count);
    }
    else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
        /* FloatToFloat16 yields the 16-bit pattern, which is reinterpreted as __half */
        unsigned short alpha2 = FloatToFloat16(alpha);
        unsigned short beta2 = FloatToFloat16(beta);
        __half * alpha3 = (__half*)&alpha2;
        __half * beta3 = (__half*)&beta2;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasHgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, alpha3, (const __half**)b, mb, (const __half**)a, ma, beta3, (__half**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasHgemmBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, alpha3, (const __half**)b, mb, (const __half**)a, ma, beta3, (__half**)c, mc, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasHgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, alpha3, (const __half**)b, mb, (const __half**)a, ma, beta3, (__half**)c, mc, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasHgemmBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, alpha3, (const __half**)b, mb, (const __half**)a, ma, beta3, (__half**)c, mc, count);
    }
    else {
        ShowNiuTransErrors("Unsupported data type!");
    }
}
/*
matrix multiplication in batch and strided mode via cuda version BLAS
c_i = alpha * op(a_i) * op(b_i) + beta * c_i for i = 0 ... count - 1
where a_i = a + i * strideA, b_i = b + i * strideB and c_i = c + i * strideC
(strides are in number of items, as required by cuBLAS)
>> handle - cublas handle
>> a - the first matrix a_0 (row-major, na * ma)
>> transposedA - indicates whether the a_i are transposed
>> dataTypeA - data type of the a_i
>> strideA - offset (in items) between a_i and a_(i+1)
>> b - the first matrix b_0 (row-major, nb * mb)
>> transposedB - indicates whether the b_i are transposed
>> dataTypeB - data type of the b_i
>> strideB - offset (in items) between b_i and b_(i+1)
>> c - the first matrix c_0 (row-major, nc * mc)
>> dataTypeC - data type of the c_i
>> strideC - offset (in items) between c_i and c_(i+1)
>> count - number of products
>> na, ma, nb, mb, nc, mc - numbers of rows and columns of a_i, b_i and c_i
>> alpha - scalar
>> beta - scalar
*/
extern "C"
void CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
                                     const void * a, MATRIX_TRANS_TYPE transposedA, MATRIX_DATA_TYPE dataTypeA, long long int strideA,
                                     const void * b, MATRIX_TRANS_TYPE transposedB, MATRIX_DATA_TYPE dataTypeB, long long int strideB,
                                     void * c, MATRIX_DATA_TYPE dataTypeC, long long int strideC,
                                     int count, int na, int ma, int nb, int mb, int nc, int mc,
                                     DTYPE alpha, DTYPE beta)
{
    /*
    matrix-matrix multiplication
    For row-major matrices (as in c/c++), the trick used here is (AB)^T = B^T * A^T:
    b is passed before a and cuBLAS (column-major) computes c^T, i.e., c in row-major
    order. The inner dimension k is ma when a is not transposed and na otherwise;
    leading dimensions are the row lengths ma, mb and mc in every case.
    */
    if (dataTypeA == X_DOUBLE && dataTypeB == X_DOUBLE && dataTypeC == X_DOUBLE) {
        double alpha2 = (double)alpha;
        double beta2 = (double)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasDgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double*)b, mb, strideB, (const double*)a, ma, strideA, &beta2, (double*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasDgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double*)b, mb, strideB, (const double*)a, ma, strideA, &beta2, (double*)c, mc, strideC, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasDgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const double*)b, mb, strideB, (const double*)a, ma, strideA, &beta2, (double*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            /* ldb is mb (the row length of b), as in the other branches; nb was wrong here */
            cublasDgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const double*)b, mb, strideB, (const double*)a, ma, strideA, &beta2, (double*)c, mc, strideC, count);
    }
    else if (dataTypeA == X_FLOAT && dataTypeB == X_FLOAT && dataTypeC == X_FLOAT) {
        float alpha2 = (float)alpha;
        float beta2 = (float)beta;
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasSgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float*)b, mb, strideB, (const float*)a, ma, strideA, &beta2, (float*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasSgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, strideB, (const float*)a, ma, strideA, &beta2, (float*)c, mc, strideC, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasSgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha2, (const float*)b, mb, strideB, (const float*)a, ma, strideA, &beta2, (float*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasSgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha2, (const float*)b, mb, strideB, (const float*)a, ma, strideA, &beta2, (float*)c, mc, strideC, count);
    }
    else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT16) {
        /* FloatToFloat16 yields the 16-bit pattern, which is reinterpreted as __half */
        unsigned short alpha2 = FloatToFloat16(alpha);
        unsigned short beta2 = FloatToFloat16(beta);
        __half * alpha3 = (__half*)&alpha2;
        __half * beta3 = (__half*)&beta2;
        /* the transposed cases previously all used (N, N, k = ma), ignoring the transposition */
        if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
            cublasHgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (const __half*)alpha3, (const __half*)b, mb, strideB, (const __half*)a, ma, strideA, (const __half*)beta3, (__half*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
            cublasHgemmStridedBatched(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (const __half*)alpha3, (const __half*)b, mb, strideB, (const __half*)a, ma, strideA, (const __half*)beta3, (__half*)c, mc, strideC, count);
        else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
            cublasHgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (const __half*)alpha3, (const __half*)b, mb, strideB, (const __half*)a, ma, strideA, (const __half*)beta3, (__half*)c, mc, strideC, count);
        else if (transposedA == X_TRANS && transposedB == X_TRANS)
            cublasHgemmStridedBatched(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (const __half*)alpha3, (const __half*)b, mb, strideB, (const __half*)a, ma, strideA, (const __half*)beta3, (__half*)c, mc, strideC, count);
    }
    else {
        ShowNiuTransErrors("Unsupported data type!");
    }
}
/*
matrix multiplication via cuda version BLAS for lists of matrices
c_i = alpha * op(a_i) * op(b_i) + beta * c_i for i = 0 ... a->count - 1
The routine is chosen as follows:
1) a single product is computed by CudaBLASMatrixMUL;
2) if neighbouring items of every list pass XTensor::IsIdentical and the items of each
   list are placed in memory with a constant gap, the strided batched routine is used;
3) if they pass XTensor::IsIdentical but are not evenly spaced, the batched routine is
   used with device arrays of pointers taken from the pin buffer of a0->mem;
4) otherwise CudaBLASMatrixMUL is called once per product.
>> handle - cublas handle
>> a - list of the input matrices (XTensor*)
>> transposedA - indicates whether the matrices in a are transposed
>> b - list of the input matrices (XTensor*)
>> transposedB - indicates whether the matrices in b are transposed
>> c - list of the output matrices (XTensor*)
>> count - not used (the number of products is a->count)
>> alpha - scalar
>> beta - scalar
*/
void CudaBLASMatrixMULList(cublasHandle_t * handle,
                           List * a, MATRIX_TRANS_TYPE transposedA,
                           List * b, MATRIX_TRANS_TYPE transposedB,
                           List * c,
                           int count, DTYPE alpha, DTYPE beta)
{
    CheckNiuTransErrors((a && b && c), "Empty input lists!");
    CheckNiuTransErrors((a->count == b->count && a->count == c->count), "Input lists must be of the same size!");

    if (a->count == 0)
        return;

    XTensor * a0 = (XTensor*)a->GetItem(0);
    XTensor * b0 = (XTensor*)b->GetItem(0);
    XTensor * c0 = (XTensor*)c->GetItem(0);

    /* a single product needs no batching; besides, no gap between items can be
       measured, so the strides below would stay at the MAX_INT sentinel */
    if (a->count == 1) {
        CudaBLASMatrixMUL(handle,
                          a0->data, transposedA, a0->dataType,
                          b0->data, transposedB, b0->dataType,
                          c0->data, c0->dataType,
                          a0->dimSize[1], a0->dimSize[0],
                          b0->dimSize[1], b0->dimSize[0],
                          c0->dimSize[1], c0->dimSize[0], alpha, beta);
        return;
    }

    bool isUniform = true;
    bool isStrided = true;
    MTYPEINT strideA = MAX_INT;
    MTYPEINT strideB = MAX_INT;
    MTYPEINT strideC = MAX_INT;

    /* check whether all items look alike and whether they are evenly spaced (gaps in bytes) */
    for (int i = 1; i < a->count; i++) {
        XTensor * aim = (XTensor*)a->GetItem(i - 1);
        XTensor * bim = (XTensor*)b->GetItem(i - 1);
        XTensor * cim = (XTensor*)c->GetItem(i - 1);
        XTensor * ai = (XTensor*)a->GetItem(i);
        XTensor * bi = (XTensor*)b->GetItem(i);
        XTensor * ci = (XTensor*)c->GetItem(i);

        if (!XTensor::IsIdentical(aim, ai) ||
            !XTensor::IsIdentical(bim, bi) ||
            !XTensor::IsIdentical(cim, ci))
        {
            isUniform = false;
            break;
        }

        if (isStrided) {
            MTYPEINT gapA = MTYPEINT(ai->data) - MTYPEINT(aim->data);
            MTYPEINT gapB = MTYPEINT(bi->data) - MTYPEINT(bim->data);
            MTYPEINT gapC = MTYPEINT(ci->data) - MTYPEINT(cim->data);

            if (strideA == MAX_INT)
                strideA = gapA;
            if (strideB == MAX_INT)
                strideB = gapB;
            if (strideC == MAX_INT)
                strideC = gapC;

            if (strideA != gapA || strideB != gapB || strideC != gapC)
                isStrided = false;
        }
    }

    if (isUniform) {
        if (isStrided) {
            /* the gaps are in bytes, cuBLAS expects strides in number of items */
            CudaBLASMatrixMULBatchedStrided(handle,
                                            a0->data, transposedA, a0->dataType, strideA / a0->unitSize,
                                            b0->data, transposedB, b0->dataType, strideB / b0->unitSize,
                                            c0->data, c0->dataType, strideC / c0->unitSize, a->count,
                                            a0->dimSize[1], a0->dimSize[0],
                                            b0->dimSize[1], b0->dimSize[0],
                                            c0->dimSize[1], c0->dimSize[0], alpha, beta);
        }
        else {
            /* the pointer arrays are staged on the host and copied into the memory pool */
            XMem * mem = a0->mem;
            CheckNiuTransErrors((mem != NULL), "No memory pool!");

            DTYPE ** ap = new DTYPE*[a->count];
            DTYPE ** bp = new DTYPE*[b->count];
            DTYPE ** cp = new DTYPE*[c->count];

            for (int i = 0; i < a->count; i++) {
                XTensor * ai = (XTensor*)a->GetItem(i);
                XTensor * bi = (XTensor*)b->GetItem(i);
                XTensor * ci = (XTensor*)c->GetItem(i);
                ap[i] = (DTYPE*)ai->data;
                bp[i] = (DTYPE*)bi->data;
                cp[i] = (DTYPE*)ci->data;
            }

            mem->SetPinBuf();
            DTYPE ** apGPU = (DTYPE**)mem->AllocBuf(mem->devID, sizeof(DTYPE*) * a->count, 256);
            DTYPE ** bpGPU = (DTYPE**)mem->AllocBuf(mem->devID, sizeof(DTYPE*) * a->count, 256);
            DTYPE ** cpGPU = (DTYPE**)mem->AllocBuf(mem->devID, sizeof(DTYPE*) * a->count, 256);
            cudaMemcpy(apGPU, ap, sizeof(DTYPE*) * a->count, cudaMemcpyHostToDevice);
            cudaMemcpy(bpGPU, bp, sizeof(DTYPE*) * b->count, cudaMemcpyHostToDevice);
            cudaMemcpy(cpGPU, cp, sizeof(DTYPE*) * c->count, cudaMemcpyHostToDevice);

            CudaBLASMatrixMULBatched(handle,
                                     (const void**)apGPU, transposedA, a0->dataType,
                                     (const void**)bpGPU, transposedB, b0->dataType,
                                     (void**)cpGPU, c0->dataType, a->count,
                                     a0->dimSize[1], a0->dimSize[0],
                                     b0->dimSize[1], b0->dimSize[0],
                                     c0->dimSize[1], c0->dimSize[0], alpha, beta);

            delete[] ap;
            delete[] bp;
            delete[] cp;

            mem->BackToPinBuf();
        }
    }
    else {
        for (int i = 0; i < a->count; i++) {
            XTensor * ai = (XTensor*)a->GetItem(i);
            XTensor * bi = (XTensor*)b->GetItem(i);
            XTensor * ci = (XTensor*)c->GetItem(i);
            CudaBLASMatrixMUL(handle,
                              ai->data, transposedA, ai->dataType,
                              bi->data, transposedB, bi->dataType,
                              ci->data, ci->dataType,
                              ai->dimSize[1], ai->dimSize[0],
                              bi->dimSize[1], bi->dimSize[0],
                              ci->dimSize[1], ci->dimSize[0], alpha, beta);
        }
    }
}
/*
convert float items into float16 items (CUDA Kernel)
one thread per item: the thread with global index k converts s[k] into t[k]
>> s - source array (float)
>> t - target array (float16)
>> size - number of items
*/
__global__
void KernelFloatToFloat16(float * s, __half * t, int size)
{
    const int k = blockIdx.x * blockDim.x + threadIdx.x;

    /* threads beyond the end of the array have nothing to do */
    if (k >= size)
        return;

    t[k] = __float2half(s[k]);
}
/*
convert float16 items into float items (CUDA Kernel)
one thread per item: the thread with global index k converts s[k] into t[k]
>> s - source array (float16)
>> t - target array (float)
>> size - number of items
*/
__global__
void KernelFloat16ToFloat(__half * s, float * t, int size)
{
    const int k = blockIdx.x * blockDim.x + threadIdx.x;

    /* threads beyond the end of the array have nothing to do */
    if (k >= size)
        return;

    t[k] = __half2float(s[k]);
}
/*
data conversion (cuda code)
supported conversions: float -> float16 and float16 -> float;
nothing is done when the two types are the same, and any other pair is reported as an error
>> devID - device id (must be a GPU, i.e., >= 0)
>> s - source data array
>> typeS - source data type
>> t - target data array
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void CudaConvertDataType(int devID, void * s, MATRIX_DATA_TYPE typeS, void * t, MATRIX_DATA_TYPE typeT, int size)
{
    CheckNiuTransErrors((devID >= 0), "This code must be run on GPUs!");

    /* same type: there is nothing to convert */
    if (typeS == typeT)
        return;

    int grids[3];
    int threadsPerBlock[3];
    GDevs->GetGridAndBlockSize(devID, size, grids, threadsPerBlock);

    const dim3 blockNum(grids[0]);
    const dim3 threadNum(threadsPerBlock[0]);

    const bool floatToHalf = (typeS == X_FLOAT && typeT == X_FLOAT16);
    const bool halfToFloat = (typeS == X_FLOAT16 && typeT == X_FLOAT);

    if (floatToHalf) {
        KernelFloatToFloat16<<<blockNum, threadNum>>>((float*)s, (__half*)t, size);
    }
    else if (halfToFloat) {
        KernelFloat16ToFloat<<<blockNum, threadNum>>>((__half*)s, (float*)t, size);
    }
    else {
        ShowNiuTransErrors("Unsupported data types for conversion!");
    }
}
/************************************************************
* cuda means
*/
/*
copy the values of a source tensor into a target tensor
>> s - source tensor
>> t - target tensor
>> stream - the stream for creating the job pipeline; if NULL the copy is done
            with XMemCopy, otherwise with XMemCopyAsync on stream->stream
<< return - false if s or t is NULL, true otherwise
NOTE: only dense -> dense and sparse -> sparse copies are implemented; the other
      combinations (and non-DTYPE sparse data) report "TODO!"
*/
bool CudaCopyValues(XTensor * s, XTensor * t, XStream * stream)
{
    if (s == NULL || t == NULL)
        return false;

    CheckNiuTransErrors(s->dataType == t->dataType, "Unmatched data type!");
    CheckNiuTransErrors((s->unitSize == t->unitSize), "Incompatible vectors in value copy.");
    /* the source must not be denser than the target (presumably denseRatio bounds how many
       items a tensor can hold). This used to compare s->denseRatio with itself, which is
       always true and so checked nothing. */
    CheckNiuTransErrors((s->denseRatio <= t->denseRatio), "Incompatible vectors in value copy.");

    /* dense -> dense */
    if (!s->isSparse && !t->isSparse) {
        if (stream == NULL)
            XMemCopy(t->data, t->devID, s->data, s->devID, s->unitSize * s->unitNum);
        else
            XMemCopyAsync(t->data, t->devID, s->data, s->devID, s->unitSize * s->unitNum, stream->stream, stream->devID);
    }
    /* dense -> sparse */
    else if (!s->isSparse && t->isSparse &&
             s->dataType == DTYPE_IN_MATRIX &&
             t->dataType == DTYPE_IN_MATRIX)
    {
        ShowNiuTransErrors("TODO!");
    }
    /* sparse -> dense */
    else if (s->isSparse && !t->isSparse &&
             s->dataType == DTYPE_IN_MATRIX &&
             t->dataType == DTYPE_IN_MATRIX)
    {
        ShowNiuTransErrors("TODO!");
    }
    /* sparse -> sparse */
    else if (s->isSparse && t->isSparse &&
             s->dataType == DTYPE_IN_MATRIX &&
             t->dataType == DTYPE_IN_MATRIX)
    {
        /* the buffer is one int followed by num tuples of (int + item);
           presumably the int is the item count and each tuple is (index, value) */
        int num = s->GetNonzeroSize();
        int size = sizeof(int) + num * (s->unitSize + sizeof(int));

        if (stream == NULL)
            XMemCopy(t->data, t->devID, s->data, s->devID, size);
        else
            XMemCopyAsync(t->data, t->devID, s->data, s->devID, size, stream->stream, stream->devID);

        t->unitNumNonZero = num;
    }
    else {
        ShowNiuTransErrors("TODO!");
    }

    return true;
}
/*
flush a list of XTensor to GPU memory
All tensors are packed into one host buffer, a single block of the same size is taken
from GPUMem, and the whole buffer is copied with one cudaMemcpy. Every tensor then
points to its slot in that block (data, devID and mem are updated) and its host copy
(dataHost) is released.
>> mList - list of the tensors (XTensor*); they must not be on a GPU already
>> GPUMem - memory pool for the GPU
*/
void CudaCPUToGPUFlush(List * mList, XMem * GPUMem)
{
    if (mList == NULL || mList->count == 0)
        return;

#ifdef USE_CUDA
    /* pass 1: validate the tensors and sum up the number of bytes they need */
    int totalSize = 0;
    for (int i = 0; i < mList->count; i++) {
        XTensor * tensor = (XTensor*)mList->GetItem(i);
        CheckNiuTransErrors((tensor->devID < 0), "Cannot do gpu-flush on matrices that are already on GPUs.");

        /* sparse layout: an int header plus one (int + item) tuple per non-zero item */
        totalSize += tensor->isSparse ?
                     (int)(sizeof(int) + (sizeof(int) + tensor->unitSize) * tensor->unitNumNonZero) :
                     (int)(tensor->unitSize * tensor->unitNum);
    }

    char * hostBuf = new char[totalSize];
    char * devBuf = (char*)GPUMem->Alloc(GPUMem->devID, totalSize);

    /* pass 2: pack the data on the host and move every tensor to its slot on the GPU */
    int offset = 0;
    for (int i = 0; i < mList->count; i++) {
        XTensor * tensor = (XTensor*)mList->GetItem(i);

        int itemSize = tensor->isSparse ?
                       (int)(sizeof(int) + (sizeof(int) + tensor->unitSize) * tensor->unitNumNonZero) :
                       (int)(tensor->unitSize * tensor->unitNum);

        memcpy(hostBuf + offset, tensor->data, itemSize);

        if (tensor->dataHost != NULL)
            delete[] (char*)tensor->dataHost;
        tensor->dataHost = NULL;

        tensor->data = devBuf + offset;
        tensor->devID = GPUMem->devID;
        tensor->mem = GPUMem;

        offset += itemSize;
    }

    /* one transfer for the whole list */
    cudaMemcpy(devBuf, hostBuf, totalSize, cudaMemcpyHostToDevice);

    delete[] hostBuf;
#endif
}
/*
copy the data of a tensor from GPU memory to CPU memory
Any previous host copy (tensor->dataHost) is released and a new one is allocated.
For a dense tensor whose data pointer is NULL, the host copy is zero-filled.
>> tensor - the tensor (its unit size must be sizeof(DTYPE))
*/
void CudaGPUToCPUFlush(XTensor * tensor)
{
    CheckNiuTransErrors((sizeof(DTYPE) == tensor->unitSize), "Unsupported data type.");

    if (tensor->dataHost != NULL)
        delete[] (char*)tensor->dataHost;

    if (!tensor->isSparse) {
        tensor->dataHost = new char[tensor->unitNum * tensor->unitSize];

        if (tensor->data == NULL)
            memset(tensor->dataHost, 0, tensor->unitNum * tensor->unitSize);
        else
            cudaMemcpy(tensor->dataHost, tensor->data, tensor->unitNum * tensor->unitSize, cudaMemcpyDeviceToHost);
    }
    else {
        /* estimate of the number of non-zero items; it is replaced by the int
           stored at the head of the device buffer */
        int nonzeroNum = int(tensor->unitNum * tensor->denseRatio + 1);
        cudaMemcpy(&nonzeroNum, (DTYPE*)tensor->data, sizeof(int), cudaMemcpyDeviceToHost);

        /* header int + one (int + DTYPE) tuple per non-zero item */
        int tupleSize = sizeof(int) + sizeof(DTYPE);
        int byteNum = sizeof(int) + tupleSize * (nonzeroNum);
        CheckNiuTransErrors((byteNum >= 0), "Illegal data size in the sparse matrix!");

        tensor->dataHost = new char[byteNum];
        cudaMemcpy(tensor->dataHost, tensor->data, byteNum, cudaMemcpyDeviceToHost);
    }
}
/*
set the cell to the ascending order along a given dimension (kernel code)
every item gets its index j along the dimension
2d launch: x runs over the stride * blockNum positions, y runs over the strideNum indices.
The row (threadIdx.y == 0) of each thread block computes the block/offset split of i once
and shares it through shared memory.
NOTE: presumably blockDim.x <= MAX_CUDA_THREAD_NUM_PER_BLOCK (size of the shared arrays) - confirm with the launcher
>> data - the data array
>> stride - how many items we go over when move to the next item along the dimension
>> strideNum - size of the given dimension
>> blockNum - block number
*/
__global__
void KernelSetAscendingOrder(int * data, int stride, int strideNum, int blockNum)
{
    __shared__ int iBlock[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ int iOffset[MAX_CUDA_THREAD_NUM_PER_BLOCK];

    /* index along the "stride" dimension */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* index along the leading dimension */
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    bool iInRange = (i < stride * blockNum);
    bool inRange = iInRange && (j < strideNum);

    /* no early return before the barrier: __syncthreads() must be reached by every
       thread of the block, otherwise the behavior is undefined */
    if (threadIdx.y == 0 && iInRange) {
        iBlock[threadIdx.x] = i / stride;
        iOffset[threadIdx.x] = i % stride;
    }

    __syncthreads();

    if (inRange) {
        int * d = (int*)data + (iBlock[threadIdx.x] * strideNum + j) * stride + iOffset[threadIdx.x];
        *d = j;
    }
}
/*
set the cell to the ascending order along a given dimension,
i.e., every item gets its index along dimension dim
>> a - the tensor (only X_INT is supported)
>> dim - the dimension (0 <= dim < a->order)
*/
extern "C"
void CudaSetAscendingOrder(XTensor * a, int dim)
{
    CheckNiuTransErrors((a->dataType == X_INT), "TODO!");
    /* dim indexes a->dimSize below, so it must name an existing dimension */
    CheckNiuTransErrors((dim >= 0 && dim < a->order), "Illegal dimension!");

    /* stride: product of the sizes of the dimensions before dim
       blockNum: product of the sizes of the dimensions after dim */
    int stride = 1;
    int strideNum = a->dimSize[dim];
    for (int i = 0; i < dim; i++)
        stride *= a->dimSize[i];

    int blockNum = 1;
    for (int i = dim + 1; i < a->order; i++)
        blockNum *= a->dimSize[i];

    int gridSize[3];
    int blockSize[3];

    GDevs->GetGridAndBlockSize2D(a->devID, strideNum, stride * blockNum, MAX_INT, gridSize, blockSize);

    /* x runs over stride * blockNum positions, y over the strideNum indices */
    KernelSetAscendingOrder<<<dim3(gridSize[1], gridSize[0]), dim3(blockSize[1], blockSize[0])>>>
                           ((int*)a->data, stride, strideNum, blockNum);
}
/*
set each entry to its negative value (CUDA Kernel)
one thread per entry
>> d - pointer to the data array
>> size - size of the data array
*/
__global__
void KernelNegate(DTYPE * d, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

    d[idx] = -d[idx];
}
/*
set each entry to its negative value (CUDA Kernel)
This is for float16 computation: native half arithmetic is used on sm_53 and later,
and the value goes through float on older architectures
>> d - pointer to the data array
>> size - size of the data array
*/
__global__
void KernelNegate(__half * d, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
    /* d = 0 - d */
    d[idx] = __hsub(__float2half(0), d[idx]);
#else
    d[idx] = __float2half(-__half2float(d[idx]));
#endif
}
/*
set each entry to its negative value
supports dense DTYPE_IN_MATRIX and X_FLOAT16 tensors; anything else reports "TODO!"
>> a - the tensor
*/
extern "C"
void CudaNegate(XTensor * a)
{
    CheckNiuTransErrors((a->isSparse == false), "TODO!");

    int cudaGrids[3];
    int cudaBlocks[3];
    GDevs->GetGridAndBlockSize(a->devID, a->unitNum, cudaGrids, cudaBlocks);

    const dim3 blocks(cudaGrids[0]);
    const dim3 threads(cudaBlocks[0]);

    if (a->dataType == DTYPE_IN_MATRIX)
        KernelNegate<<<blocks, threads>>>((DTYPE*)a->data, a->unitNum);
    else if (a->dataType == X_FLOAT16)
        KernelNegate<<<blocks, threads>>>((__half*)a->data, a->unitNum);
    else
        ShowNiuTransErrors("TODO!");
}
/*
set all entries to their square root (CUDA Kernel)
one thread per entry
>> d - data array
>> size - size of the data array
*/
__global__
void KernelSqrtV2(DTYPE * d, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

    d[idx] = sqrt(d[idx]);
}
/*
set all entries to their square root (CUDA Kernel)
This is for float16 computation: hsqrt is used on sm_53 and later,
and the value goes through float on older architectures
>> d - data array
>> size - size of the data array
*/
__global__
void KernelSqrtV2(__half * d, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
    d[idx] = hsqrt(d[idx]);
#else
    d[idx] = __float2half(sqrt(__half2float(d[idx])));
#endif
}
/*
get power(d[i], p) for every entry (CUDA Kernel)
one thread per entry
>> d - data array
>> p - power
>> size - size of the data array
*/
__global__
void KernelPower(DTYPE * d, DTYPE p, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

    d[idx] = pow(d[idx], p);
}
/*
get power(d[i], p) for every entry (CUDA Kernel)
This is for float16 computation. cuda_fp16.h provides no half-precision power function,
so the value is computed in float and converted back. Previously the sm_53+ (and host)
branch was commented out entirely, which made the kernel a silent no-op on those devices.
>> d - data array
>> p - power
>> size - size of the data array
*/
__global__
void KernelPower(__half * d, __half p, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        d[i] = __float2half(powf(__half2float(d[i]), __half2float(p)));
}
/*
get the power of the entries: a = a ^ p
p == 0.5 is computed with a square root and p == 1.0 leaves the data unchanged.
For float16 tensors only p == 0.5 and p == 1.0 are supported; other powers (and other
data types) report "TODO!"
>> a - the tensor
>> p - the power
*/
extern "C"
void CudaPower(XTensor * a, DTYPE p)
{
    int cudaGrids[3];
    int cudaBlocks[3];
    GDevs->GetGridAndBlockSize(a->devID, a->unitNum, cudaGrids, cudaBlocks);

    const dim3 blocks(cudaGrids[0]);
    const dim3 threads(cudaBlocks[0]);

    const bool isSqrt = (p == (DTYPE)0.5);
    const bool isIdentity = (p == (DTYPE)1.0);

    if (a->dataType == DTYPE_IN_MATRIX) {
        if (isSqrt)
            KernelSqrtV2<<<blocks, threads>>>((DTYPE*)a->data, a->unitNum);
        else if (!isIdentity)
            KernelPower<<<blocks, threads>>>((DTYPE*)a->data, p, a->unitNum);
    }
    else if (a->dataType == X_FLOAT16) {
        if (isSqrt)
            KernelSqrtV2<<<blocks, threads>>>((__half*)a->data, a->unitNum);
        else if (!isIdentity)
            ShowNiuTransErrors("TODO!");
    }
    else {
        ShowNiuTransErrors("TODO!");
    }
}
/*
scale and shift all matrix entries: p = p * scale + shift (CUDA Kernel)
one thread per entry
>> d - the data array
>> size - the size of d
>> scale - how much we want to scale it
>> shift - how much we want to shift it
*/
__global__
void KernelScaleAndShift(DTYPE * d, int size, DTYPE scale, DTYPE shift)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

    d[idx] = d[idx] * scale + shift;
}
/*
scale and shift all matrix entries: p = p * scale + shift (CUDA Kernel)
This is for float16 computation: native half arithmetic on sm_53 and later,
float arithmetic on older architectures
>> d - the data array
>> size - the size of d
>> scale - how much we want to scale it
>> shift - how much we want to shift it
*/
__global__
void KernelScaleAndShift(__half * d, int size, __half scale, __half shift)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= size)
        return;

#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
    d[idx] = __hadd(__hmul(d[idx], scale), shift);
#else
    d[idx] = __float2half(__half2float(d[idx]) * __half2float(scale) + __half2float(shift));
#endif
}
/*
scale and shift all matrix entries
p = p * scale + shift
supports dense DTYPE_IN_MATRIX and X_FLOAT16 tensors; sparse tensors and other data
types report "TODO!" (sparse tensors used to be left unchanged without any notice)
>> a - the matrix
>> scale - the scaler factor
>> shift - the shift factor
*/
void CudaScaleAndShift(XTensor * a, DTYPE scale, DTYPE shift)
{
    /* sparse matrix: not implemented yet - fail loudly, as CudaNegate does */
    CheckNiuTransErrors((a->isSparse == false), "TODO!");

    /* dense matrix */
    int gridSize[3];
    int blockSize[3];

    GDevs->GetGridAndBlockSize(a->devID, a->unitNum, gridSize, blockSize);

    dim3 blocks(gridSize[0]);
    dim3 threads(blockSize[0]);

    if (a->dataType == DTYPE_IN_MATRIX) {
        KernelScaleAndShift<<<blocks, threads>>>((DTYPE*)a->data, a->unitNum, scale, shift);
    }
    else if (a->dataType == X_FLOAT16) {
        /* FloatToFloat16 yields the 16-bit pattern, which is reinterpreted as __half */
        unsigned short scale2 = FloatToFloat16(scale);
        unsigned short shift2 = FloatToFloat16(shift);
        __half * scaleft16p = (__half*)&scale2;
        __half * shiftft16p = (__half*)&shift2;
        KernelScaleAndShift<<<blocks, threads>>>((__half*)a->data, a->unitNum, *scaleft16p, *shiftft16p);
    }
    else {
        ShowNiuTransErrors("TODO!");
    }
}
/*
copy a number of blocks to target positions
NOTE that this version makes more use of the 2d threads in cuda: the x
dimension of the launch walks over the entries of a block (each thread copies
up to miniBlockSize consecutive entries) and the y dimension walks over blocks.
Block j of source is copied to block targetBlocks[j] of target.
Every entry copy is bounds-checked against blockSize, so any miniBlockSize
works (values <= 1 mean one entry per thread) and a block size that is not a
multiple of miniBlockSize never reads or writes past the end of a block.
>> source - data array (head of the blocks) to copy from
>> blockSize - size of block (number of DTYPE entries)
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy (read by device code, so it must be device-accessible)
*/
template<int miniBlockSize>
__global__
void KernelCopyBlocks(DTYPE * source, int blockSize, int blockNum, DTYPE * target, int * targetBlocks)
{
    /* number of consecutive entries copied by one thread */
    const int step = miniBlockSize > 1 ? miniBlockSize : 1;

    /* index of the first entry (inside the block) handled by this thread */
    int i = (blockDim.x * blockIdx.x + threadIdx.x) * step;

    /* block index */
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    if (j >= blockNum)
        return;

    /* target position; offsets are computed in 64 bits because
       blockSize * index may exceed INT_MAX for large tensors */
    int k = targetBlocks[j];
    const DTYPE * s = source + (long long)blockSize * j;
    DTYPE * t = target + (long long)blockSize * k;

    /* copy up to step entries without running past the end of the block */
    #pragma unroll
    for (int m = 0; m < step; m++) {
        if (i + m < blockSize)
            t[i + m] = s[i + m];
    }
}
/*
copy a number of blocks to target positions (cuda version)
when the block length (in entries) is a multiple of 4 every thread copies
4 entries, otherwise every thread copies a single entry
>> source - data array (head of the blocks) to copy from
>> blockSize - size of block in bytes (must be a multiple of sizeof(DTYPE))
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy (on the device)
>> myMem - memory pool; its devID (which must be >= 0) selects the GPU launch configuration
*/
void CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
{
    CheckNiuTransErrors((myMem != NULL), "No memory pool!");
    CheckNiuTransErrors((myMem->devID >= 0), "Wrong device to run!");
    CheckNiuTransErrors((blockSize % sizeof(DTYPE) == 0), "Unsupported block size!");

    /* block length in entries rather than bytes */
    int entryNum = blockSize / sizeof(DTYPE);
    bool fourPerThread = (entryNum % 4 == 0);

    int grids[3];
    int threads[3];
    GDevs->GetGridAndBlockSize2D(myMem->devID, fourPerThread ? entryNum / 4 : entryNum,
                                 blockNum, MAX_INT, grids, threads);

    dim3 gridDims(grids[0], grids[1]);
    dim3 blockDims(threads[0], threads[1]);

    if (fourPerThread)
        KernelCopyBlocks<4><<<gridDims, blockDims>>>((DTYPE*)source, entryNum, blockNum, (DTYPE*)target, targetBlocks);
    else
        KernelCopyBlocks<1><<<gridDims, blockDims>>>((DTYPE*)source, entryNum, blockNum, (DTYPE*)target, targetBlocks);
}
/*
copy a number of blocks from source positions to target positions
the x dimension of the launch indexes entries inside a block and the y
dimension indexes the block pairs: block sourceBlocks[j] of source is copied
to block targetBlocks[j] of target
>> source - data array (head of the blocks) to copy from
>> blockSize - size of block (number of DTYPE entries)
>> sourceBlocks - source positions of the copy (read by device code, so it must be device-accessible)
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy (read by device code, so it must be device-accessible)
*/
__global__
void KernelCopyBlocksSelected(DTYPE * source, int blockSize, int * sourceBlocks, int blockNum, DTYPE * target, int * targetBlocks)
{
    /* entry index in the block */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    /* index of the (source, target) block pair */
    int j = blockDim.y * blockIdx.y + threadIdx.y;

    if (j >= blockNum)
        return;

    /* source and target block positions */
    int srcIndex = sourceBlocks[j];
    int tgtIndex = targetBlocks[j];

    /* 64-bit offsets: blockSize * index may exceed INT_MAX for large tensors */
    DTYPE * s = source + (long long)blockSize * srcIndex;
    DTYPE * t = target + (long long)blockSize * tgtIndex;

    if (i < blockSize)
        t[i] = s[i];
}
/*
copy a number of blocks from source positions to target positions (cuda version)
the two index arrays are first copied into temporary buffers taken from the
memory pool on device myMem->devID, the copy kernel is launched, and the two
buffers are released again
>> source - data array (head of the blocks) to copy from
>> blockSize - size of block in bytes (must be a multiple of sizeof(DTYPE))
>> sourceBlocks - source positions of the copy; read from device id -1
                  (presumably host memory -- TODO confirm)
>> blockNum - number of blocks
>> target - target data array
>> targetBlocks - target positions of the copy; read from device id -1 like sourceBlocks
>> myMem - memory pool (devID must be >= 0) that provides the temporary index buffers
*/
void CudaCopyBlocksSelected(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem)
{
    CheckNiuTransErrors((myMem != NULL), "No memory pool!");
    CheckNiuTransErrors((myMem->devID >= 0), "Wrong device to run!");
    CheckNiuTransErrors((blockSize % sizeof(DTYPE) == 0), "Unsupported block size!");

    const int devID = myMem->devID;
    const size_t indexBytes = blockNum * sizeof(int);
    const int entryNum = blockSize / sizeof(DTYPE);

    /* stage both index arrays in pool buffers on the GPU */
    int * srcIndexOnDev = (int*)myMem->AllocBuf(devID, indexBytes);
    int * tgtIndexOnDev = (int*)myMem->AllocBuf(devID, indexBytes);
    XMemCopy(srcIndexOnDev, devID, sourceBlocks, -1, indexBytes);
    XMemCopy(tgtIndexOnDev, devID, targetBlocks, -1, indexBytes);

    int grids[3];
    int threads[3];
    GDevs->GetGridAndBlockSize2D(devID, entryNum, blockNum, MAX_INT, grids, threads);

    dim3 gridDims(grids[0], grids[1]);
    dim3 blockDims(threads[0], threads[1]);
    KernelCopyBlocksSelected<<<gridDims, blockDims>>>
        ((DTYPE*)source, entryNum, srcIndexOnDev, blockNum, (DTYPE*)target, tgtIndexOnDev);

    /* hand both temporary buffers back to the pool */
    myMem->ReleaseBuf(devID, indexBytes);
    myMem->ReleaseBuf(devID, indexBytes);
}
/*
set target data block index for the data movement in split (device code)
>> blockIndex - block index