forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LinearAlgebra.cpp
1742 lines (1536 loc) · 64.2 KB
/
LinearAlgebra.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/TensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/LegacyTHFunctionsCPU.h>
#include <ATen/core/grad_mode.h>
#include <functional>
#include <numeric>
#include <vector>
#include <limits>
#include <ATen/NamedTensorUtils.h>
namespace at {
namespace native {
// Helper function for det methods.
// For pivoted LU factorization A = P * L * U. Since we always have det(L) = 1,
// det(P) = \pm 1, this method returns a 3-tuple:
// (det(P), diag(U), info),
// where info helps us identify singular matrices.
static inline std::tuple<Tensor, Tensor> _lu_det_P_diag_U(const Tensor& self) {
Tensor pivs, lu, infos;
std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "Invalid argument passed to lu");
auto n = self.size(-1);
auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs).sum(-1, /*keepdim=*/false, /*dtype=*/self.scalar_type()).fmod_(2);
// NB: the `.contiguous()` call is added due to the bug in `.prod()` as reported in
// issue #https://github.com/pytorch/pytorch/issues/34061
auto u_diagonal = lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).contiguous();
return std::tuple<Tensor, Tensor>(num_exchanges.mul_(-2).add_(1), u_diagonal);
}
// torch.linalg.det: thin alias that forwards to torch.det.
Tensor linalg_det(const Tensor& self) {
  return at::det(self);
}
// Determinant of a (batch of) square floating-point or complex matrices,
// computed as det(P) * prod(diag(U)) from a pivoted LU factorization.
Tensor det(const Tensor& self) {
  squareCheckInputs(self);
  const auto input_type = self.scalar_type();
  TORCH_CHECK((at::isFloatingType(input_type) || at::isComplexType(input_type)),
              "Expected a floating point tensor as input");

  Tensor perm_sign, u_diag;
  std::tie(perm_sign, u_diag) = _lu_det_P_diag_U(self);
  // A zero on diag(U) (singular matrix) already drives the product to zero,
  // so no special case is required.
  Tensor u_product = u_diag.prod(-1);
  return u_product.mul_(perm_sign);
}
// log(det(A)) for a (batch of) square matrices.
// Produces -inf where U is singular (det == 0) and NaN where det < 0.
// NOTE(review): the dtype check admits complex inputs, but the `< 0`
// comparisons below are presumably unsupported for complex dtypes — confirm.
Tensor logdet(const Tensor& self) {
  squareCheckInputs(self);
  const auto input_type = self.scalar_type();
  TORCH_CHECK((at::isFloatingType(input_type) || at::isComplexType(input_type)),
              "Expected a floating point tensor as input");

  Tensor perm_sign, u_diag;
  std::tie(perm_sign, u_diag) = _lu_det_P_diag_U(self);
  // Must be computed before abs_() below overwrites u_diag in place.
  Tensor sign_of_det = u_diag.sign().prod(-1).mul_(perm_sign);

  // sum(log|U_ii|) is log|det|; it is already -inf when some U_ii == 0,
  // so only negative determinants need to be patched to NaN.
  Tensor result = u_diag.abs_().log_().sum(-1);
  const bool is_batched = self.dim() > 2;
  if (is_batched) {
    result.index_put_((sign_of_det < 0).nonzero_numpy(), at::full({}, NAN, self.options()));
  } else if (sign_of_det.item<double>() < 0) {
    result.fill_(NAN);
  }
  return result;
}
std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
squareCheckInputs(self);
TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
"Expected a floating point tensor as input");
Tensor det_P, diag_U;
std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
auto det_sign = diag_U.sign().prod(-1).mul_(det_P);
// abslogdet_val is -inf if U is singular, in which case diag_U.abs_().log_().sum(-1) will return -inf.
// U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
// Since abslogdet_val cannot take nan, no special case handling is required.
auto abslogdet_val = diag_U.abs_().log_().sum(-1);
return std::make_tuple(det_sign, abslogdet_val);
}
// Moore-Penrose pseudo-inverse computed from the SVD of `self`.
// Singular values not strictly greater than `rcond * S[..., 0]` are treated
// as zero (their reciprocal is replaced by 0).
Tensor pinverse(const Tensor& self, double rcond) {
  const auto input_type = self.scalar_type();
  const bool has_valid_dtype = at::isFloatingType(input_type) || at::isComplexType(input_type);
  TORCH_CHECK(has_valid_dtype && self.dim() >= 2,
              "pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions "
              "of floating types");

  if (self.numel() == 0) {
    // Match NumPy: an empty result whose last two dims are swapped.
    std::vector<int64_t> transposed_sizes = self.sizes().vec();
    const int64_t last = self.dim() - 1;
    std::swap(transposed_sizes[last], transposed_sizes[last - 1]);
    return at::empty(transposed_sizes, self.options());
  }

  Tensor U, S, V;
  std::tie(U, S, V) = self.svd();
  // The first singular value is used as the maximum (presumably relying on
  // svd() returning S in descending order).
  Tensor largest_sv = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
  Tensor keep_mask = S > rcond * largest_sv;
  Tensor S_pinv = at::where(keep_mask, S.reciprocal(), at::zeros({}, S.options())).to(self.dtype());

  // V.conj() @ diag(S_pinv) @ U^H, with the diagonal product done as a
  // column scaling of V.conj().
  // NOTE(review): the conj() on V presumably compensates for how svd()
  // returns V for complex inputs at this revision — re-check if that changes.
  Tensor scaled_V = V.conj() * S_pinv.unsqueeze(-2);
  return at::matmul(scaled_V, U.transpose(-2, -1).conj());
}
// Returns the values matrix_rank thresholds: |eigenvalues| via symeig when
// `symmetric` is true, otherwise the singular values (U/V not computed).
static inline Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) {
  if (symmetric) {
    return std::get<0>(self.symeig(/*eigenvectors=*/false)).abs();
  }
  return std::get<1>(self.svd(/*some=*/true, /*compute_uv=*/false));
}
// Rank of a 2-D floating/complex matrix: the number of singular values
// (|eigenvalues| when `symmetric`) strictly greater than the absolute `tol`.
Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) {
  const auto input_type = self.scalar_type();
  const bool is_float_or_complex = at::isFloatingType(input_type) || at::isComplexType(input_type);
  TORCH_CHECK(is_float_or_complex && self.dim() == 2,
              "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor "
              "of floating types");
  Tensor above_tol = _matrix_rank_helper(self, symmetric) > tol;
  return above_tol.sum();
}
// Rank of a 2-D floating/complex matrix with a relative tolerance:
// eps(dtype) * max(rows, cols) * max(spectrum).
Tensor matrix_rank(const Tensor& self, bool symmetric) {
  const auto input_type = self.scalar_type();
  const bool is_float_or_complex = at::isFloatingType(input_type) || at::isComplexType(input_type);
  TORCH_CHECK(is_float_or_complex && self.dim() == 2,
              "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor "
              "of floating types");
  Tensor spectrum = _matrix_rank_helper(self, symmetric);
  const int64_t larger_dim = std::max(self.size(0), self.size(1));
  const double rel_tol = _get_epsilon(self.scalar_type()) * larger_dim;
  Tensor threshold = spectrum.max().mul_(rel_tol);
  return (spectrum > threshold).sum();
}
// Raises unless `t` is 1-D; `arg` and `fn` are only used in the error message.
static void check_1d(const Tensor& t, const char* arg, const char* fn) {
  const auto ndim = t.dim();
  TORCH_CHECK(ndim == 1, fn, ": Expected 1-D argument ", arg, ", but got ", ndim, "-D");
}
// Deprecated: returns beta * self + alpha * outer(vec1, vec2).
// When beta is zero, `self` is not used at all (its values do not propagate
// and it does not take part in broadcasting); the scaled outer product is
// returned directly.
Tensor addr(const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
  TORCH_WARN(
    "torch.addr is deprecated and may be removed in a future PyTorch release. "
    "This function can be implemented using torch.outer as "
    "alpha * torch.outer(vec1, vec2) + beta * input when beta is not zero, "
    "alpha * torch.outer(vec1, vec2) when beta is zero.");
  Tensor outer_result = at::outer(vec1, vec2) * alpha;
  // Compare as complex: Scalar::to<double>() throws for a complex beta with a
  // non-zero imaginary part, which would reject valid complex inputs.
  if (beta.toComplexDouble() == 0.0) {
    return outer_result;
  }
  return outer_result + (self * beta);
}
// In-place addr: `self` is both the addend and the destination.
Tensor& addr_(Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
  Tensor& destination = self;
  return at::addr_out(destination, self, vec1, vec2, beta, alpha);
}
// out= addr: computes addr(...) into a temporary, verifies its dtype can be
// safely cast to `result`'s dtype, then resizes `result` and copies into it.
Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
  const Tensor computed = at::addr(self, vec1, vec2, beta, alpha);
  const auto computed_dtype = computed.scalar_type();
  const auto out_dtype = result.scalar_type();
  TORCH_CHECK(canCast(computed_dtype, out_dtype),
              "result type ", computed_dtype,
              " can't be cast to the desired output type ", out_dtype);
  at::native::resize_output(result, computed.sizes().vec());
  result.copy_(computed);
  return result;
}
// torch.ger (out= variant): deprecated alias for torch.outer. Emits a
// deprecation warning, then forwards to outer_out.
Tensor& ger_out(Tensor &result, const Tensor& self, const Tensor& vec2) {
  TORCH_WARN("torch.ger is deprecated and will be removed in a future PyTorch release. "
             "Use torch.outer instead.");
  Tensor& written = at::outer_out(result, self, vec2);
  return written;
}
// torch.ger: alias for torch.outer.
// NOTE(review): unlike ger_out, this overload emits no deprecation warning — confirm intended.
Tensor ger(const Tensor& self, const Tensor& vec2) {
  return at::outer(self, vec2);
}
// out= outer product: result[i][j] = self[i] * vec2[j] for 1-D inputs.
// Implemented by viewing `self` as an (n x 1) column and broadcasting mul.
Tensor& outer_out(Tensor &result, const Tensor& self, const Tensor& vec2) {
  check_1d(self, "self", "outer");
  check_1d(vec2, "vec2", "outer");
  const Tensor self_as_column = self.reshape({self.size(0), 1});
  at::mul_out(result, self_as_column, vec2);
  return result;
}
// Outer product of 1-D tensors: an (n x m) tensor with out[i][j] = self[i] * vec2[j],
// produced by broadcasting an (n x 1) view of self against vec2.
Tensor outer(const Tensor& self, const Tensor& vec2) {
  check_1d(self, "self", "outer");
  check_1d(vec2, "vec2", "outer");
  const Tensor self_as_column = self.reshape({self.size(0), 1});
  return self_as_column * vec2;
}
// Core CPU kernel behind addmm / addbmm / mm:
//   result = beta * self + alpha * (m1 @ m2), with all three operands 2-D.
// cpublas::gemm takes column-major operands (cf. the "FORTRAN contiguous" copy
// of c below), so strides are inspected to decide, for result (c), m1 (a) and
// m2 (b), whether each can be handed over as-is, as a transpose, or only after
// copying into a compatible layout.
// m1 and m2 are taken by value because they may be swapped below.
static void addmm_impl_cpu_(
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, Scalar beta, Scalar alpha) {
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
// Array access is faster than .size(n) and .stride(n)
const auto self_sizes = self.sizes();
auto m1_strides = m1.strides();
auto m1_sizes = m1.sizes();
auto m2_strides = m2.strides();
auto m2_sizes = m2.sizes();
TORCH_CHECK(
m1_sizes[1] == m2_sizes[0], "mat1 and mat2 shapes cannot be multiplied (",
m1_sizes[0], "x", m1_sizes[1], " and ", m2_sizes[0], "x", m2_sizes[1], ")");
TORCH_CHECK(
self_sizes[0] == m1_sizes[0] && self_sizes[1] == m2_sizes[1],
"input shape is incompatible with matrix multiplication (",
m1_sizes[0], "x", m1_sizes[1], " @ ", m2_sizes[0], "x", m2_sizes[1], " != ",
self_sizes[0], "x", self_sizes[1], ")");
native::resize_(result, self_sizes);
const auto result_strides = result.strides();
const auto result_sizes = result.sizes();
// Empty output: nothing to compute.
if (result.numel() == 0) {
return;
}
// Seed result with self so gemm's beta * C term sees it. Skipped when
// beta == 0, or when result already is self (the in-place case).
if (beta.toComplexDouble() != 0.0 && !self.is_same(result)) {
result.copy_(self);
}
bool transpose_c = false;
Tensor c;
// Cast result as matrix c (the gemm output).
// Case 1: result is already column-major -> use it directly.
if (result_strides[0] == 1 &&
(result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) {
transpose_c = false;
c = result;
} else if (result_strides[1] == 1 &&
(result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) {
// Case 2: result is row-major, i.e. result^T is column-major. Since
// (m1 @ m2)^T = m2^T @ m1^T, swap the operands and flag transpose_c.
std::swap(m1, m2);
std::swap(m1_sizes, m2_sizes);
std::swap(m1_strides, m2_strides);
transpose_c = true;
c = result;
} else {
// Case 3: neither layout fits; compute into a column-major temporary that
// is copied back into result after gemm.
transpose_c = false;
// make c FORTRAN contiguous
c = result.transpose(0, 1).contiguous().transpose_(0, 1);
}
// gemm dimensions, expressed in the (possibly swapped) operand frame.
const int64_t m = result_sizes[transpose_c ? 1 : 0];
const int64_t n = result_sizes[transpose_c ? 0 : 1];
const int64_t k = m1_sizes[transpose_c ? 0 : 1];
// Cast m1 as matrix a
bool transpose_a = false;
Tensor a;
/* Need lda >= max(1, (transpose_a ? k : m)) */
if (m1_strides[transpose_c ? 1 : 0] == 1 &&
m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) {
transpose_a = false;
a = m1;
} else if (m1_strides[transpose_c ? 0 : 1] == 1 &&
m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) {
transpose_a = true;
a = m1;
} else {
// Fallback: a C-contiguous copy; its transpose flag is then !transpose_c.
transpose_a = !transpose_c;
a = m1.clone(at::MemoryFormat::Contiguous);
}
// Cast m2 as matrix b
bool transpose_b = false;
Tensor b;
/* Need ldm2_ >= max(1, (transpose_m2 == 'n' ? k : n)) */
if (m2_strides[transpose_c ? 1 : 0] == 1 &&
m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) {
transpose_b = false;
b = m2;
} else if (m2_strides[transpose_c ? 0 : 1] == 1 &&
m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) {
transpose_b = true;
b = m2;
} else {
// Fallback: a C-contiguous copy; its transpose flag is then !transpose_c.
transpose_b = !transpose_c;
b = m2.clone(at::MemoryFormat::Contiguous);
}
// Leading dimensions are the non-unit strides of the chosen layouts.
const int64_t lda = a.strides()[(transpose_a == transpose_c) ? 1 : 0];
const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
const int64_t ldc = c.strides()[transpose_c ? 0 : 1];
// Apply BLAS routine
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
at::native::cpublas::gemm(
transpose_a ? cpublas::Transpose : cpublas::NoTranspose,
transpose_b ? cpublas::Transpose : cpublas::NoTranspose,
m, n, k,
alpha.to<scalar_t>(),
a.data_ptr<scalar_t>(), lda,
b.data_ptr<scalar_t>(), ldb,
beta.to<scalar_t>(),
c.data_ptr<scalar_t>(), ldc);
});
// Case 3 above computed into a temporary; copy it back.
if (!c.is_same(result)) {
result.copy_(c);
}
}
// Computes result = beta * self + alpha * sum_b (batch1[b] @ batch2[b]).
// Shapes enforced below: batch1 (B, n, k), batch2 (B, k, p), self (n, p).
// The first batch applies `beta` to the seeded result; later batches
// accumulate with beta = 1.
static void addbmm_impl_cpu_(
    Tensor &result, const Tensor &self, const Tensor &batch1, const Tensor &batch2, Scalar beta, Scalar alpha) {
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
  TORCH_CHECK(batch1.size(0) == batch2.size(0),
      "batch1 and batch2 must have same number of batches, got ",
      batch1.size(0), " and ", batch2.size(0));
  TORCH_CHECK(batch1.size(2) == batch2.size(1),
      "Incompatible matrix sizes for bmm (",
      batch1.size(1), "x", batch1.size(2), " and ",
      batch2.size(1), "x", batch2.size(2), ")");

  const int64_t dim1 = batch1.size(1);
  const int64_t dim2 = batch2.size(2);
  TORCH_CHECK(self.size(0) == dim1 && self.size(1) == dim2,
      "self tensor does not match matmul output shape");

  result.resize_as_(self);

  // Compare as complex so a complex-valued beta is accepted; Scalar::to<double>()
  // throws for a complex Scalar with a non-zero imaginary part.
  const bool beta_is_zero = beta.toComplexDouble() == 0.0;
  if (!beta_is_zero && !self.is_same(result)) {
    result.copy_(self);
  }

  const int64_t num_batches = batch1.size(0);

  // Empty sum over batches: the product term is zero, so the answer is
  // beta * self, or exactly zero when beta == 0 (self ignored). Without this,
  // result would be left as unscaled self, or with whatever contents
  // resize_as_ left behind when beta == 0.
  if (num_batches == 0) {
    if (beta_is_zero) {
      result.zero_();
    } else {
      result.mul_(beta);
    }
    return;
  }

  for (int64_t batch = 0; batch < num_batches; ++batch) {
    addmm_impl_cpu_(result, result, batch1[batch], batch2[batch], beta, alpha);
    beta = 1; // accumulate output once
  }
}
// out= addbmm: broadcasts self to (batch1.size(1), batch2.size(2)), runs the
// kernel with named-tensor inference disabled, then propagates names.
Tensor& addbmm_cpu_out(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  const Tensor expanded_self =
      std::get<0>(expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out"));
  {
    at::NoNamesGuard no_names;
    addbmm_impl_cpu_(result, expanded_self, batch1, batch2, beta, alpha);
  }
  at::namedinference::propagate_names_for_addmm(result, batch1, batch2, self);
  return result;
}
// In-place addbmm: `self` is both the addend and the destination.
Tensor &addbmm_cpu_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  Tensor& destination = self;
  return addbmm_cpu_out(destination, self, batch1, batch2, beta, alpha);
}
// Functional addbmm: allocates an empty output (sized by the kernel) and
// delegates to the out= variant.
Tensor addbmm_cpu(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  Tensor output = at::empty({0}, self.options());
  addbmm_cpu_out(output, self, batch1, batch2, beta, alpha);
  return output;
}
// out= addmm: validates that mat1/mat2 are 2-D, broadcasts self to the
// (mat1 rows x mat2 cols) product shape, runs the kernel without names,
// then propagates names.
Tensor& addmm_cpu_out(Tensor &result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  const int64_t out_rows = mat1.sizes()[0];
  const int64_t out_cols = mat2.sizes()[1];
  const Tensor expanded_self = std::get<0>(expand_size(self, {out_rows, out_cols}, "addmm_out"));
  {
    at::NoNamesGuard no_names;
    addmm_impl_cpu_(result, expanded_self, mat1, mat2, beta, alpha);
  }
  at::namedinference::propagate_names_for_addmm(result, mat1, mat2, self);
  return result;
}
// Functional addmm: allocates an empty output (resized by the kernel) and
// delegates to the out= variant.
Tensor addmm_cpu(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
  Tensor output = at::empty({0}, self.options());
  addmm_cpu_out(output, self, mat1, mat2, beta, alpha);
  return output;
}
// In-place addmm: `self` is both the addend and the destination.
Tensor &addmm_cpu_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
  Tensor& destination = self;
  return addmm_cpu_out(destination, self, mat1, mat2, beta, alpha);
}
// out= matrix product of two 2-D tensors, computed as addmm with beta = 0.
Tensor& mm_cpu_out(Tensor & result, const Tensor & self, const Tensor & mat2) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  const int64_t rows = self.sizes()[0];
  const int64_t cols = mat2.sizes()[1];
  // result doubles as the addend `self` of addmm; with beta = 0 it contributes
  // nothing, but it must already have the output shape.
  native::resize_(result, {rows, cols});
  return addmm_cpu_out(result, result, self, mat2, /*beta=*/0, /*alpha=*/1);
}
// Functional matrix product of two 2-D tensors: allocates the (rows x cols)
// output and fills it via addmm with beta = 0.
Tensor mm_cpu(const Tensor & self, const Tensor & mat2) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  const int64_t rows = self.sizes()[0];
  const int64_t cols = mat2.sizes()[1];
  Tensor product = at::empty({rows, cols}, self.options());
  addmm_cpu_out(product, product, self, mat2, /*beta=*/0, /*alpha=*/1);
  return product;
}
// Naive batched matrix multiply used for small problems, parallelized over
// the batch dimension. For each batch b:
//   is_bmm:   result[b] = self[b] @ mat2[b]
//   !is_bmm:  result[b] = beta * result[b] + alpha * (self[b] @ mat2[b])
// `self` is the left operand of the product (expected (bs, is, ks)), `mat2`
// the right one (expected (bs, ks, js)) and `result` is (bs, is, js).
template <typename scalar_t, bool is_bmm>
inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
  int64_t bs = result.size(0);
  int64_t is = result.size(1);
  int64_t js = result.size(2);
  int64_t ks = self.size(2);

  scalar_t alpha = alpha_.to<scalar_t>();
  scalar_t beta = beta_.to<scalar_t>();
  // With beta == 0 the previous contents of result must be ignored;
  // `r *= beta` would otherwise keep NaN/Inf already stored there.
  const bool beta_is_zero = (beta == scalar_t(0));

  auto r0 = result.accessor<scalar_t, 3>();
  auto s0 = self.accessor<scalar_t, 3>();
  auto m0 = mat2.accessor<scalar_t, 3>();

  // Each parallel chunk should cover roughly GRAIN_SIZE multiply-adds and at
  // least one batch. (std::min would collapse the grain to 0 or 1, splitting
  // tiny problems into one task per batch.) Guard the divisor against 0.
  const int64_t work_per_batch = std::max<int64_t>(is * js * ks, 1);
  const int64_t grain_size = std::max<int64_t>(internal::GRAIN_SIZE / work_per_batch, 1);
  parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
    for (int64_t b = b_begin; b < b_end; b++) {
      auto r1 = r0[b];
      auto s1 = s0[b];
      auto m1 = m0[b];
      for (int64_t i = 0; i < is; i++) {
        auto r2 = r1[i];
        auto s2 = s1[i];
        for (int64_t j = 0; j < js; j++) {
          scalar_t &r = r2[j];
          if (is_bmm) {
            r = 0;
            for (int64_t k = 0; k < ks; k++) {
              r += s2[k] * m1[k][j];
            }
          } else {
            if (beta_is_zero) {
              r = 0;
            } else {
              r *= beta;
            }
            for (int64_t k = 0; k < ks; k++) {
              r += alpha * s2[k] * m1[k][j];
            }
          }
        }
      }
    }
  });
}
// Shared implementation of bmm_out (is_bmm_out == true) and baddbmm_
// (is_bmm_out == false). Dispatch strategy:
// - Small problems (contraction_size * res_rows * res_cols < 400): naive
//   kernel parallelized over the batch dimension (baddbmm_cpu_kernel).
// - Otherwise, if MKL is available, the dtype is floating/complex and all
//   operands have BLAS-compatible layouts: MKL batched gemm (_baddbmm_mkl_).
// - Otherwise: one mm / addmm per batch entry.
// The threshold of 400 has not been thoroughly benchmarked; it likely depends
// on the CPU and on MKL vs. non-MKL builds, but is a reasonable starting point.
static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha, bool is_bmm_out) {
  // is_bmm_out: true for bmm_out, false for baddbmm_
  // self_or_result is "result" for bmm_out (resized below) and "self" for
  // baddbmm_ (updated in place; its shape is validated below).
  CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm");
  // This label appears in checkSize error messages, so it must match the role.
  TensorArg self_arg(self_or_result, is_bmm_out ? "result" : "self", 0);
  TensorArg b1_arg(batch1, "batch1", 1);
  TensorArg b2_arg(batch2, "batch2", 2);
  checkBackend(c, {self_or_result, batch1, batch2}, Backend::CPU);
  checkDim(c, b1_arg, 3);
  checkDim(c, b2_arg, 3);

  int64_t bs = batch1.size(0);
  checkSize(c, b2_arg, 0, bs);
  int64_t contraction_size = batch1.size(2);
  int64_t res_rows = batch1.size(1);
  int64_t res_cols = batch2.size(2);
  checkSize(c, b2_arg, 1, contraction_size);

  if (is_bmm_out) {
    self_or_result.resize_({bs, res_rows, res_cols});
  } else {
    checkSize(c, self_arg, 0, bs);
    checkSize(c, self_arg, 1, res_rows);
    checkSize(c, self_arg, 2, res_cols);
  }

  // handle pathological cases that blas may not like
  if (self_or_result.numel() == 0) {
    return self_or_result;
  } else if (contraction_size == 0) {
    // Empty contraction: the product term is zero. When beta == 0 the old
    // contents must be discarded, not multiplied (0 * NaN would stay NaN).
    if (is_bmm_out || beta.toComplexDouble() == 0.0) {
      return self_or_result.zero_();
    } else {
      return self_or_result.mul_(beta);
    }
  }

  // True when each matrix in the batch is row-major or column-major with a
  // leading stride large enough for BLAS.
  auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
    return (t.stride(2) == 1 && t.stride(1) >= t.size(2))
            || (t.stride(1) == 1 && t.stride(2) >= t.size(1));
  };

  if (contraction_size * res_rows * res_cols < 400) {
    if (is_bmm_out) {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "bmm", [&] {
        baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
      });
    } else {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "baddbmm", [&] {
        baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
      });
    }
  } else if (at::hasMKL() && (at::native::is_floating_point(self_or_result) ||
                              at::native::is_complex(self_or_result))
             && batch_items_contiguous_or_transposed(batch1)
             && batch_items_contiguous_or_transposed(batch2)
             && self_or_result.is_contiguous()) {
    at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
  } else { // split along batch dimension
    if (is_bmm_out) {
      for (int64_t b = 0; b < bs; b++) {
        auto r = self_or_result.select(0, b);
        native::mm_cpu_out(r, batch1.select(0, b), batch2.select(0, b));
      }
    } else {
      for (int64_t b = 0; b < bs; b++) {
        self_or_result.select(0, b).addmm_(batch1.select(0, b), batch2.select(0, b), beta, alpha);
      }
    }
  }
  return self_or_result;
}
// Functional baddbmm: allocates an empty output and delegates to the out= variant.
Tensor baddbmm_cpu(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  Tensor output = at::empty({0}, self.options());
  at::native::baddbmm_out_cpu(output, self, batch1, batch2, beta, alpha);
  return output;
}
// out= baddbmm: broadcasts self_ to the (batches, rows, cols) product shape,
// copies it into result, then runs the in-place kernel on result.
Tensor& baddbmm_out_cpu(Tensor &result, const Tensor& self_, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  const int64_t batches = batch1.size(0);
  const int64_t rows = batch1.size(1);
  const int64_t cols = batch2.size(2);
  Tensor expanded_self = std::get<0>(expand_size(self_, {batches, rows, cols}, "baddbmm"));
  result.resize_(expanded_self.sizes());
  result.copy_(expanded_self);
  return at::native::baddbmm__cpu(result, batch1, batch2, beta, alpha);
}
// In-place baddbmm: self = beta * self + alpha * (batch1 @ batch2), per batch.
Tensor& baddbmm__cpu(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  const bool is_bmm_out = false;
  return bmm_out_or_baddbmm_(self, batch1, batch2, beta, alpha, is_bmm_out);
}
// Functional bmm: allocates an empty output (resized by the kernel) and
// delegates to the out= variant.
Tensor bmm_cpu(const Tensor& self, const Tensor& mat2) {
  Tensor output = at::empty({0}, self.options());
  at::native::bmm_out_cpu(output, self, mat2);
  return output;
}
// out= bmm: runs the shared kernel with beta = 0 and alpha = 1 while names are
// suppressed, then attaches the output names inferred for bmm.
Tensor& bmm_out_cpu(Tensor &result, const Tensor& batch1, const Tensor& batch2) {
  {
    NoNamesGuard guard;
    bmm_out_or_baddbmm_(result, batch1, batch2,
                        /*beta=*/Scalar(0.0), /*alpha=*/Scalar(1.0),
                        /*is_bmm_out=*/true);
  }
  auto out_names = namedinference::compute_bmm_outnames(result, batch1, batch2);
  namedinference::propagate_names_if_nonempty(result, out_names);
  return result;
}
// out= dot: writes the scalar product of self and tensor into a 0-dim `result`.
// The dtype is validated before resize_output so that a rejected call leaves
// `result` unmodified.
Tensor& dot_out(Tensor& result, const Tensor& self, const Tensor& tensor) {
  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
              "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());
  at::native::resize_output(result, {});
  return result.fill_(self.dot(tensor));
}
// out= vdot: writes self.vdot(other) into a 0-dim `result`.
// The dtype is validated before resize_output so that a rejected call leaves
// `result` unmodified.
Tensor& vdot_out(Tensor& result, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
              "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());
  at::native::resize_output(result, {});
  return result.fill_(self.vdot(other));
}
/*
Matrix product of two Tensors.
The behavior depends on the dimensionality of the Tensors as follows:
- If both Tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional,
  a 1 is prepended to its dimension for the purpose of the matrix multiply.
  After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional,
  the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is
  N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first
  argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the
  batched matrix multiply and removed after. If the second argument is 1-dimensional, a
  1 is appended to its dimension for the purpose of the batched matrix multiply and removed after.
  The non-matrix (i.e. batch) dimensions are broadcasted (and thus
  must be broadcastable). For example, if tensor1 is a (j x 1 x n x m) Tensor
  and tensor2 is a (k x m x p) Tensor, the returned tensor will be an (j x k x n x p) Tensor.

When `out_opt` holds a tensor, the product is written into it (on the
reshaping paths via out.set_()) and that tensor is returned. Names are
suppressed here; the public matmul / matmul_out wrappers handle them.
*/
Tensor matmul(
    c10::optional<Tensor> out_opt,
    const Tensor& tensor1,
    const Tensor& tensor2) {
  NoNamesGuard guard;
  auto dim_tensor1 = tensor1.dim();
  auto dim_tensor2 = tensor2.dim();
  auto has_out = out_opt.has_value();
  Tensor out = out_opt.value_or(Tensor());

  if (dim_tensor1 == 1 && dim_tensor2 == 1) {
    return has_out ? at::native::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
  } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
    return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
  } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
    return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
  } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
    return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
  } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
    // optimization: use mm instead of bmm by folding tensor1's batch into
    // its leading matrix dimension.
    Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
    auto size1 = tensor1.sizes();
    auto size2 = t2.sizes();
    std::vector<int64_t> output_size;
    output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
    if (dim_tensor2 > 1) {
      output_size.push_back(size2[dim_tensor2 - 1]);
    }
    // fold the batch into the first dimension
    Tensor t1 = tensor1.contiguous().view({-1, size1[size1.size() - 1]});
    Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
                            : at::_unsafe_view(t1.mm(t2), output_size);
    return has_out ? out.set_(output) : output;
  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
    // optimization: transpose the inner dimensions of the arguments, call
    // matmul on the swapped arguments, then transpose the inner dimensions
    // of the result.
    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
    const int64_t m = tensor1.size(-1);
    const int64_t p = tensor2.size(-1);
    const Tensor t2_T = tensor2.transpose(-1, -2);
    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
    const Tensor res_T = matmul(out_opt, t2_T, t1_T);
    if (dim_tensor1 == 2) {
      Tensor res = res_T.transpose(-1, -2).contiguous();
      return has_out ? out.set_(res) : res;
    }
    else {
      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
      shape.push_back(p);
      Tensor res = res_T.reshape(shape).contiguous();
      return has_out ? out.set_(res) : res;
    }
  } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
    // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
    // we track m1 vs m2 separately even though they must match for nicer error messages
    int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
    int64_t m1 = tensor1.size(-1);
    IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
    int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
    int64_t p = tensor2.size(-1);
    IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));

    // expand the batch portion (i.e. cut off matrix dimensions and expand rest)
    std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);

    std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
    tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});

    std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
    tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});

    // std::accumulate computes in the type of its initial value; an `int`
    // initial value would truncate batch products larger than 2^31 - 1.
    const int64_t expand_batch_product = std::accumulate(
        expand_batch_portion.begin(), expand_batch_portion.end(),
        int64_t{1}, std::multiplies<int64_t>());

    std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
    tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});

    std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
    tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});

    // flatten expanded batches
    Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
    Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);

    // reshape batches back into result
    std::vector<int64_t> output_shape(expand_batch_portion);
    if (dim_tensor1 > 1) {
      output_shape.push_back(n);
    }
    if (dim_tensor2 > 1) {
      output_shape.push_back(p);
    }

    Tensor output = has_out ? at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape)
                            : at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
    return has_out ? out.set_(output) : output;
  }

  AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
           dim_tensor1, "D and ", dim_tensor2, "D");
}
// Functional matmul with named-tensor support: output names are inferred
// from the inputs up front, the unnamed computation runs, then names are attached.
Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
  auto out_names = namedinference::compute_matmul_outnames(tensor1, tensor2);
  Tensor product = at::native::matmul(c10::nullopt, tensor1, tensor2);
  namedinference::propagate_names_if_nonempty(product, out_names);
  return product;
}
// out= matmul: computes into `result` through the optional-out overload, then
// attaches the names inferred from the inputs.
Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2) {
  auto out_names = namedinference::compute_matmul_outnames(tensor1, tensor2);
  at::native::matmul(c10::optional<Tensor>(result), tensor1, tensor2);
  namedinference::propagate_names_if_nonempty(result, out_names);
  return result;
}
// helper methods for matrix_exp
namespace {
// Fixed-size ROW x COL 2-D array of scalar_t: ROW nested std::arrays,
// each holding COL elements (index as arr[row][col]).
template <typename scalar_t, int ROW, int COL>
using array2d = std::array<std::array<scalar_t, COL>, ROW>;
// Number of Taylor-expansion degrees considered for matrix_exp;
// the degrees are 1, 2, 4, 8, 12 and 18.
constexpr int total_n_degs = 6;
// Matrix 1-norm over the last two dims: sum |a_ij| down each column (dim -2),
// then take the maximum over columns (dim -1).
Tensor operator_1_norm(const Tensor& tensor) {
  Tensor column_abs_sums = tensor.abs().sum(-2);
  Tensor max_values, max_indices;
  std::tie(max_values, max_indices) = column_abs_sums.max(-1);
  return max_values;
}
// Allocates a contiguous buffer of shape [n_copies, a.size(0), a.size(1), a.size(2)]
// with `a`'s dtype/device options. Contents are uninitialized unless `is_zero`.
// Only dims 0..2 of `a` are read — presumably `a` is a 3-D batch of matrices; confirm.
Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) {
  const auto buffer_options = a.options().memory_format(at::MemoryFormat::Contiguous);
  Tensor buffer = at::empty({n_copies, a.size(0), a.size(1), a.size(2)}, buffer_options);
  if (!is_zero) {
    return buffer;
  }
  buffer.zero_();
  return buffer;
}
// Writes the first `num_matrices` entries of the power list
//   l := {I, A, A^2, A^3, A^6}
// into the leading slots of `buffer`, so that buffer[k, ...] = l[k].
// Slots 0 (I) and 1 (A) are always written; slots 2..4 are written only when
// `num_matrices` is large enough to include them. Entries past l[4] are
// never produced.
void _fill_matrix_powers(Tensor& buffer, const Tensor& a, int num_matrices) {
  // Shape of a's diagonal(s): every dimension of `a` except the last one.
  auto diag_shape = a.sizes().vec();
  diag_shape.pop_back();

  // Slot 0: identity matrices, built by embedding a diagonal of ones.
  buffer.select(0, 0).copy_(
    at::diag_embed(at::ones({1}, buffer.options()).expand(diag_shape))
  );

  // Slot 1: A itself.
  buffer.select(0, 1).copy_(a);

  // Stores buffer[lhs] @ buffer[rhs] into buffer[dst].
  auto mul_into = [&buffer](int dst, int lhs, int rhs) {
    at::native::matmul(
      buffer.select(0, dst),
      buffer.select(0, lhs),
      buffer.select(0, rhs)
    );
  };

  if (num_matrices > 2) {
    mul_into(2, 1, 1);  // A^2 = A * A
  }
  if (num_matrices > 3) {
    mul_into(3, 1, 2);  // A^3 = A * A^2
  }
  if (num_matrices > 4) {
    mul_into(4, 3, 3);  // A^6 = A^3 * A^3
  }
}
// Returns `mem` transferred to the device of `in` when `in` lives on a CUDA
// device; for any other device type `mem` is returned unchanged.
inline Tensor _move_memory_if_cuda_input(
  const Tensor& mem,
  const Tensor& in
) {
  const bool in_is_cuda = in.device().type() == at::kCUDA;
  if (!in_is_cuda) {
    return mem;
  }
  return mem.to(at::device_of(in).value());
}
// Converts a 1D list of coefficients into a tensor of shape [1, blob.size()]
// that is placed on in.device() when `in` is a CUDA tensor (on the CPU
// otherwise), so it can be passed directly to _compute_linear_combination.
// The elements are interpreted with in.dtype(), so scalar_t must match the
// dtype of `in`.
// The returned tensor owns its data; it does not alias `blob`.
template <typename scalar_t>
inline Tensor _blob_to_Tensor(
  std::initializer_list<scalar_t> blob,
  const Tensor& in
) {
  // from_blob takes a non-const pointer; nothing here writes through it, so
  // const_cast states that explicitly instead of a C-style cast.
  // Blob is a 1D array, hence the extra leading dimension that lets the
  // result be used directly as _compute_linear_combination's coefficients.
  auto view = at::from_blob(
      const_cast<scalar_t*>(blob.begin()), blob.size(), in.dtype()
    ).unsqueeze(0);
  // from_blob does not copy: `view` aliases the initializer_list's backing
  // array, a temporary that is destroyed at the end of the caller's
  // full-expression. Returning it as-is (the CPU path) would hand out a
  // dangling tensor to any caller that keeps the result longer. Cloning gives
  // the result its own storage; for CUDA inputs the device transfer copies
  // again, which is negligible for a handful of scalars.
  return _move_memory_if_cuda_input(view.clone(), in);
}
// Degree-1 Taylor polynomial: T1(A) = I + A.
Tensor compute_T1(const Tensor& A) {
  // Stack {I, A} in a two-slot buffer and add the slots together.
  constexpr int n_terms = 2;
  auto terms = _allocate_buffer(A, n_terms);
  _fill_matrix_powers(terms, A, n_terms);
  return terms.sum(0);
}
// Degree-2 Taylor polynomial: T2(A) = I + A + A^2 / 2.
Tensor compute_T2(const Tensor& A) {
  // Slots {I, A, A^2}; halving the A^2 slot in place lets a plain sum over
  // the slots produce the polynomial.
  constexpr int n_terms = 3;
  auto terms = _allocate_buffer(A, n_terms);
  _fill_matrix_powers(terms, A, n_terms);
  auto a_squared = terms.select(0, 2);
  a_squared.div_(2.0);
  return terms.sum(0);
}
// Degree-4 Taylor polynomial, evaluated with one product beyond A^2:
//   T4(A) = I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
template <typename scalar_t>
Tensor compute_T4(const Tensor& A) {
  // Slots: 0 = I, 1 = A, 2 = A^2, 3 = scratch for the degree-4 product.
  auto powers = _allocate_buffer(A, 4);
  _fill_matrix_powers(powers, A, 3);
  auto a_squared = powers.select(0, 2);
  auto tail_product = powers.select(0, 3);

  // tail_product <- A^2 * (I / 2 + A / 6 + A^2 / 24).
  // Each coefficient tensor is built inside the full-expression that consumes
  // it, since _blob_to_Tensor may alias its temporary initializer list.
  at::native::matmul(
    tail_product,
    a_squared,
    at::native::_compute_linear_combination(
      powers.narrow(0, 0, 3),
      _blob_to_Tensor<scalar_t>({1 / 2.0, 1 / 6.0, 1 / 24.0}, A)
    )
  );

  // Weights {1, 1, 0, 1}: I + A + tail_product (the bare A^2 slot is skipped).
  return at::native::_compute_linear_combination(
    powers, _blob_to_Tensor<scalar_t>({1.0, 1.0, 0.0, 1.0}, A)
  );
}
// Degree-8 member of the Taylor family used by matrix_exp, evaluated as
//   A4 = A2 * (x1 * A + x2 * A2)
//   A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
//   result = I + A + y2 * A2 + A8
// where A2 = A^2. Only three matmul calls are made: A^2 inside
// _fill_matrix_powers, then A4 and A8 below.
// NOTE(review): the constants x1..x7, y2 appear to follow the optimized
// degree-8 evaluation scheme of Bader, Blanes & Casas (2019) — confirm
// against the paper before changing any of them.
template <typename scalar_t>
Tensor compute_T8(const Tensor& A) {
// sqrt(177) ~= 13.304; the remaining constants are derived from it.
constexpr scalar_t sqrt_177 = 0.1330413469565007072504e+2;
constexpr scalar_t x3 = 2. / 3.;
constexpr scalar_t x1 = x3 * ((1. + sqrt_177) / 88.);
constexpr scalar_t x2 = x3 * ((1. + sqrt_177) / 352.);
constexpr scalar_t x4 = (-271. + 29. * sqrt_177) / (315. * x3);
constexpr scalar_t x5 = (-11. + 11. * sqrt_177) / (1260. * x3);
constexpr scalar_t x6 = (-99. + 11. * sqrt_177) / (5040. * x3);
constexpr scalar_t x7 = (89. - sqrt_177) / (5040. * x3);
constexpr scalar_t y2 = (857. - 58. * sqrt_177) / 630.;
// Buffer slots: 0 = I, 1 = A, 2 = A2, 3 = A4, 4 = A8. Slots 3 and 4 are
// uninitialized until the matmuls below write them.
auto As = _allocate_buffer(A, 5);
// 3 for {I, A, A^2}
_fill_matrix_powers(As, A, 3);
// A4 = A2 * (x1 * A + x2 * A2)
at::native::matmul(
// output for A4
As.select(0, 3),
// As.select(0, 2) = A^2
As.select(0, 2),
at::native::_compute_linear_combination(
// extract {A, A^2} from As
As.narrow(0, 1, 2),
// coefficient tensor is consumed within this full-expression
_blob_to_Tensor<scalar_t>({x1, x2}, A)
)
);
// A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
at::native::matmul(
// output for A8
As.select(0, 4),
// x3 * A2 + A4
at::native::_compute_linear_combination(
// slots 2..3 = {A2, A4}
As.narrow(0, 2, 2),
_blob_to_Tensor<scalar_t>({x3, 1.0}, A)
),
// x4 * I + x5 * A + x6 * A2 + x7 * A4 over slots 0..3
at::native::_compute_linear_combination(
As.narrow(0, 0, 4),
_blob_to_Tensor<scalar_t>({x4, x5, x6, x7}, A)
)
);
// return I + A + y2 * A2 + A8;
// (weight 0 on slot 3 drops the intermediate A4)
return at::native::_compute_linear_combination(
As,
_blob_to_Tensor<scalar_t>({1.0, 1.0, y2, 0.0, 1.0}, A)
);
}
// Degree-12 member of the Taylor family used by matrix_exp, evaluated as
//   Bi = sum_j b[i][j] * P_j   with P = {I, A, A^2, A^3}, i = 0..3
//   A6 = B2 + B3 * B3
//   result = B0 + (B1 + A6) * A6
// (the Bi formula assumes _compute_linear_combination maps coefficients
// [m, n] and inputs [n, ...] to outputs [m, ...] — TODO confirm).
// Four matmul calls are made in total: A^2 and A^3 inside
// _fill_matrix_powers, then the two products below.
template <typename scalar_t>
Tensor compute_T12(const Tensor& A) {
constexpr int num_prods = 4;
// Row i holds the weights of {I, A, A^2, A^3} in Bi.
// NOTE(review): coefficients appear to come from the optimized T12 scheme of
// Bader, Blanes & Casas (2019); the tiny entries 9.0198e-16 and
// -2.0861320e-13 are kept verbatim — confirm against the paper before
// changing them.
array2d<scalar_t, num_prods, num_prods> b = {{
{
9.0198e-16,
0.46932117595418237389,
-0.20099424927047284052,
-0.04623946134063071740
},
{
5.31597895759871264183,
1.19926790417132231573,
0.01179296240992997031,
0.01108844528519167989
},
{
0.18188869982170434744,
0.05502798439925399070,
0.09351590770535414968,
0.00610700528898058230
},
{
-2.0861320e-13,
-0.13181061013830184015,
-0.02027855540589259079,
-0.00675951846863086359
}
}};
// gather coefficients `b` from above into a tensor,
// and move them to device `device_of(A)`
// from_blob does not copy: on the CPU `bs` aliases the local array `b`, so it
// must be consumed before this function returns (it is, by the linear
// combination below). Sizes {4, 4} with strides {4, 1} read `b` row by row,
// which relies on the nested std::array being laid out contiguously.
auto bs = at::from_blob(
reinterpret_cast<void*>(&b),
{num_prods, num_prods},
{num_prods, 1},
A.dtype()
);
bs = _move_memory_if_cuda_input(bs, A);
// As = {I, A, A^2, A^3}
auto As = _allocate_buffer(A, num_prods);
_fill_matrix_powers(As, A, num_prods);
// Bs[i] = Bi for i = 0..3
auto Bs = at::native::_compute_linear_combination(As, bs);
// compute A6
// Bs[2] becomes A6 = B2 + B3 * B3. Slot 0 of As (the identity) is no longer
// read after Bs is formed, so it is reused as the product's output storage.
Bs.select(0, 2).add_(at::native::matmul(
// tmp buffer for this matrix product
As.select(0, 0),
Bs.select(0, 3),
Bs.select(0, 3)
));
// B0 + (B1 + A6) * A6. Note add_ updates Bs[1] in place, and the returned
// tensor is Bs[0] updated in place (a view into Bs).
return Bs.select(0,0).add_(at::native::matmul(
// tmp buffer for this matrix product
As.select(0, 0),
Bs.select(0, 1).add_(Bs.select(0, 2)),
Bs.select(0, 2)
));
}
template <typename scalar_t>
Tensor compute_T18(const Tensor& A) {
constexpr int num_prods = 5;
array2d<scalar_t, num_prods, num_prods> b = {{
{
0.,
-1.00365581030144618291e-01,
-8.02924648241156932449e-03,
-8.92138498045729985177e-04,
0.
},
{
0.,
3.97849749499645077844e-01,
1.36783778460411720168e+00,
4.98289622525382669416e-01,
-6.37898194594723280150e-04