forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CrossKernel.cu
92 lines (77 loc) · 3.22 KB
/
CrossKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Cross.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
namespace at::native {
// Computes one 3-component cross product per loop index in [0, numel).
//
// For index i, offset_calculator.get(i) supplies three element offsets that
// are added to `out`, `x1` and `x2` respectively. From each resulting base
// pointer, the three vector components are addressed at 0, 1 and 2 times the
// corresponding stride (`ostride`, `x1stride`, `x2stride`).
//
// Template parameters:
//   T          - element type of all three buffers.
//   OffsetCalc - type providing get(i), whose result is indexed with [0..2].
//   StrideType - integer type used for the component strides; the caller must
//                ensure that 2 * stride is representable in it.
template <typename T, typename OffsetCalc, typename StrideType>
__global__ void cross_kernel(
    int numel, T* out, const T* x1, const T* x2, OffsetCalc offset_calculator,
    StrideType ostride, StrideType x1stride, StrideType x2stride) {
  CUDA_KERNEL_LOOP(i, numel) {
    const auto offsets = offset_calculator.get(i);
    T* const dst = out + offsets[0];
    const T* const lhs = x1 + offsets[1];
    const T* const rhs = x2 + offsets[2];

    // Read every operand before any store, so the result does not depend on
    // whether the output buffer shares memory with an input.
    const T lhs0 = lhs[0 * x1stride];
    const T lhs1 = lhs[1 * x1stride];
    const T lhs2 = lhs[2 * x1stride];
    const T rhs0 = rhs[0 * x2stride];
    const T rhs1 = rhs[1 * x2stride];
    const T rhs2 = rhs[2 * x2stride];

    // Standard cross-product component formulas.
    const T cross0 = (lhs1 * rhs2 - lhs2 * rhs1);
    const T cross1 = (lhs2 * rhs0 - lhs0 * rhs2);
    const T cross2 = (lhs0 * rhs1 - lhs1 * rhs0);

    dst[0 * ostride] = cross0;
    dst[1 * ostride] = cross1;
    dst[2 * ostride] = cross2;
  }
}
// Launches cross_kernel over every element of `iter` on the current CUDA
// stream, dispatching on iter.common_dtype().
//
// Parameters:
//   iter     - iterator with one output (operand 0) and two inputs
//              (operands 1 and 2); its numel() must be positive and fit in
//              int32 (checked only in debug builds).
//   ostride  - element stride of the output along the cross dimension.
//   x1stride - element stride of the first input along the cross dimension.
//   x2stride - element stride of the second input along the cross dimension.
//
// The kernel is instantiated with `int` strides when twice each stride fits
// in `int`, and with `int64_t` strides otherwise.
void launch_cross_kernel(const TensorIteratorBase& iter, int64_t ostride,
                         int64_t x1stride, int64_t x2stride) {
  const auto N = iter.numel();
  auto offset_calculator = make_element_offset_calculator<3>(iter);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(N > 0 && N <= std::numeric_limits<int32_t>::max());

  // One thread per element, rounded up to a whole number of blocks.
  int64_t grid = (N + num_threads() - 1) / num_threads();
  auto stream = at::cuda::getCurrentCUDAStream();

  // cross_kernel indexes up to 2 * stride, so that product is what must fit
  // in `int` for the narrow instantiation to be usable.
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const bool strides_fit_int =
      ostride * 2 <= int_max && x1stride * 2 <= int_max && x2stride * 2 <= int_max;

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "cross_cuda", [&] {
    auto out = static_cast<scalar_t*>(iter.data_ptr(0));
    auto x1 = static_cast<const scalar_t*>(iter.data_ptr(1));
    auto x2 = static_cast<const scalar_t*>(iter.data_ptr(2));

    if (strides_fit_int) {
      // Narrow path: strides are passed to the kernel as `int`.
      cross_kernel<<<grid, num_threads(), 0, stream>>>(
          N, out, x1, x2, offset_calculator,
          static_cast<int>(ostride),
          static_cast<int>(x1stride),
          static_cast<int>(x2stride));
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // Wide path: strides stay `int64_t`.
      cross_kernel<<<grid, num_threads(), 0, stream>>>(
          N, out, x1, x2, offset_calculator, ostride, x1stride, x2stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}
// Computes the cross product of `x1` and `x2` along dimension `dim` and
// writes it into `result`.
//
// The strides of all three tensors along `dim` are captured up front and
// passed to the launcher separately, because the iterator is built with `dim`
// squashed (declare_static_shape(..., /*squash_dims=*/dim)). `result` is not
// resized (resize_outputs(false)).
void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_t dim) {
  const int64_t ostride = result.stride(dim);
  const int64_t x1stride = x1.stride(dim);
  const int64_t x2stride = x2.stride(dim);

  auto tensor_iter = TensorIteratorConfig()
                         .add_output(result)
                         .add_input(x1)
                         .add_input(x2)
                         .resize_outputs(false)
                         .declare_static_shape(result.sizes(), /*squash_dims=*/dim)
                         .build();

  // Nothing to compute for an empty iteration space.
  if (tensor_iter.numel() == 0) {
    return;
  }

  // Too large for 32-bit indexing: launch once per 32-bit-indexable piece.
  if (!tensor_iter.can_use_32bit_indexing()) {
    for (auto&& piece : tensor_iter.with_32bit_indexing()) {
      launch_cross_kernel(piece, ostride, x1stride, x2stride);
    }
    return;
  }

  launch_cross_kernel(tensor_iter, ostride, x1stride, x2stride);
}
// Associates cross_impl with the cross_stub dispatch entry.
// NOTE(review): cross_stub is presumably declared in ATen/native/Cross.h, and
// this registration presumably applies to the CUDA backend because it lives in
// a .cu translation unit; confirm against the REGISTER_DISPATCH definition.
REGISTER_DISPATCH(cross_stub, &cross_impl);
} // namespace at::native