forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CUDAJitLoops.cuh
297 lines (257 loc) · 10.6 KB
/
CUDAJitLoops.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
#pragma once
#include <ATen/jit_macros.h>
// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/C++17.h>
#include <initializer_list>
#include <type_traits>
#include <tuple>
#include <mutex>
namespace at {
namespace native {
// Implementation detail of tuple_to_array: expands the index pack `I...` to
// take the address of every element of `t`, in element order, and returns
// them as a std::array<void*, sizeof...(I)>.
template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> /*seq*/) {
  // With an empty pack `t` is never named below; silence unused-parameter warnings.
  static_cast<void>(t);
  return std::array<void*, sizeof...(I)>{static_cast<void*>(&std::get<I>(t))...};
}
// Converts a tuple into a std::array<void*, N> holding the address of each of
// its elements (in element order), suitable for passing as trailing arguments
// to a CUDA kernel launch.
// NOTE: `extra_args` is taken by reference and the returned pointers point
// into it, so they are only valid while that tuple is alive.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  // index_sequence_for<Args...> is make_index_sequence<sizeof...(Args)>.
  return tuple_to_array_helper(extra_args, std::index_sequence_for<Args...>{});
}
// Lazily-populated jitted kernels for the vectorized (contiguous, no-casting)
// path. launch_jitted_vectorized_kernel picks one slot per call based on the
// vectorization width the operands allow.
struct JittedVecKernelCache {
  // Kernels vectorized up to 1, 2 and 4 elements, respectively.
  at::cuda::jit::NvrtcFunction vec1, vec2, vec4;
};
// All jitted kernel variants for one op; jitted_gpu_kernel_impl keeps one
// instance per CUDA device. Each slot corresponds to one dispatch case in
// jitted_gpu_kernel_generic.
struct JittedKernelVariantCache {
  // Case 1: no dynamic casting, contiguous (vectorized kernels).
  JittedVecKernelCache vec;
  // Case 2: no dynamic casting, non-contiguous.
  // Case 3: dynamic casting, contiguous.
  // Case 4: dynamic casting, non-contiguous.
  at::cuda::jit::NvrtcFunction noncontiguous, dynamic_contiguous, dynamic_noncontiguous;
};
// Builds the kernel-parameter pointer list for a jitted launch: the fixed
// arguments in `args` first, followed by every pointer in `extra_args`.
inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> packed(args.size() + extra_args.size());
  // std::copy returns one past the last written slot, i.e. where extra_args begin.
  void** tail = std::copy(args.begin(), args.end(), packed.data());
  std::copy(extra_args.begin(), extra_args.end(), tail);
  return packed;
}
// Launches a jitted, unrolled (non-vectorized) pointwise kernel.
//
// The NVRTC kernel is compiled on first use and stored in `fn_cache`. The
// compile uses double-checked locking: an unlocked check of
// `fn_cache.function`, then a re-check under `jiterator_mutex`.
// The kernel is generated with dynamic casting whenever the loader or the
// storer type is anything other than LoadWithoutCast / StoreWithoutCast.
//
// Preconditions: 0 < N <= INT32_MAX (asserted).
// The kernel receives, in this order:
//   N, data, ic, oc, l, s, *scalar_val, then the pointers in `extra_args`.
template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Ceil-divide in int64_t; the bound on N asserted above keeps the block
  // count well within uint32_t.
  const int64_t work_per_block = block_work_size();
  const uint32_t num_blocks = (N + work_per_block - 1) / work_per_block;

  if (!fn_cache.function) {
    std::lock_guard<std::mutex> guard(jiterator_mutex);
    // Re-check under the lock: another thread may have compiled it meanwhile.
    if (!fn_cache.function) {
      constexpr bool needs_casting =
          !(std::is_same<loader_t, memory::LoadWithoutCast>::value &&
            std::is_same<storer_t, memory::StoreWithoutCast>::value);
      const auto source = at::cuda::jit::generate_code(
          desc, contiguous, needs_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(source, desc.name);
    }
  }

  auto kernel_args = pack_kernel_args(
      {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(
      fn_cache, kernel_args.data(), {num_blocks, 1u, 1u}, {num_threads(), 1u, 1u});
}
// Launches a jitted pointwise kernel for the contiguous, no-casting case.
//
// at::cuda::jit::can_vectorize_up_to reports the widest vectorization the
// operands allow (4, 2 or 1 elements); that width selects the slot in
// `fn_cache`. Each width is compiled on first use with double-checked locking
// on `jiterator_mutex`.
//
// Widths 2 and 4 produce a kernel named `<desc.name>_vectorized<width>`, which
// receives (N, data, *scalar_val, extra_args...).
// Width 1 produces a kernel named `desc.name` and receives the same argument
// list as launch_jitted_unrolled_kernel, using trivial offset calculators and
// no-cast loader/storer.
//
// Preconditions: 0 < N <= INT32_MAX (asserted).
template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Ceil-divide in int64_t; the bound on N keeps the result within uint32_t.
  const int64_t work_per_block = block_work_size();
  const uint32_t num_blocks = (N + work_per_block - 1) / work_per_block;
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Pick the cache slot matching the vectorization width.
  at::cuda::jit::NvrtcFunction* kernel = nullptr;
  switch (vec_size) {
    case 4:
      kernel = &fn_cache.vec4;
      break;
    case 2:
      kernel = &fn_cache.vec2;
      break;
    case 1:
      kernel = &fn_cache.vec1;
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
  }
  const bool vectorized = vec_size > 1;

  if (!kernel->function) {
    std::lock_guard<std::mutex> guard(jiterator_mutex);
    // Re-check under the lock: another thread may have compiled it meanwhile.
    if (!kernel->function) {
      const auto source = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = desc.name;
      if (vectorized) {
        kernel_name += "_vectorized" + std::to_string(vec_size);
      }
      *kernel = at::cuda::jit::jit_pwise_function(source, kernel_name);
    }
  }

  if (!vectorized) {
// NVCC reports l and s as unused even though their addresses are packed below;
// this is treated as a false positive, so warning 177 is suppressed here.
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();
    auto kernel_args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *kernel, kernel_args.data(), {num_blocks, 1u, 1u}, {num_threads(), 1u, 1u});
#pragma nv_diagnostic pop
  } else {
    auto kernel_args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *kernel, kernel_args.data(), {num_blocks, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}
// Dispatches a jitted pointwise op over `iter` to one of four kernel variants,
// each cached in its own slot of `cache`.
//
// Variants (they mirror gpu_kernel_impl in the non-jitted CUDALoops.cuh):
//   Case 1: no dynamic casting, contiguous     -> cache.vec (vectorized)
//   Case 2: no dynamic casting, non-contiguous -> cache.noncontiguous
//   Case 3: dynamic casting, contiguous        -> cache.dynamic_contiguous
//   Case 4: dynamic casting, non-contiguous    -> cache.dynamic_noncontiguous
//
// Preconditions (asserted):
//   - `iter` supports 32-bit indexing;
//   - `iter` has exactly `arity` inputs and 1 output.
// An empty iterator is a no-op.
// `scalar_val` and `extra_args` are forwarded unchanged to the launchers.
template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  const int64_t numel = iter.numel();
  // Nothing to compute. Every launcher below asserts N > 0, so return here
  // instead of tripping an internal assertion (and compiling a kernel) for an
  // empty iterator.
  if (numel == 0) {
    return;
  }

  // Operand 0 is the output; operands 1..arity are the inputs.
  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = static_cast<char*>(iter.data_ptr(i));
  }

  const bool contiguous = iter.is_contiguous();

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 both need a casting storer (StoreWithCast<1>, i.e. one
  // output, operand 0) and casting loaders for the `arity` inputs.
  auto storer = memory::StoreWithCast<1>(iter);
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}
// NOTE: static to reduce chances of name collision.
//
// Entry point for a jiterator op named `name` with source `f`.
//
// Each template instantiation owns three function-local statics:
//   - a mutex guarding kernel compilation;
//   - one JittedKernelVariantCache per CUDA device, indexed by
//     iter.device().index() (assumed to be a valid index below
//     c10::cuda::device_count() — NOTE(review): not checked here);
//   - a KernelDescriptor built once, from the `f` passed on the FIRST call.
//     Later calls reuse it regardless of their `f`.
//
// `scalar_val` and the elements of `extra_args` are passed to the kernel by
// address; both are locals of this call, so those addresses are only used for
// the duration of the launch below.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  // the same compute capability
  static std::mutex compile_mutex;
  static std::vector<JittedKernelVariantCache> per_device_cache(c10::cuda::device_count());
  // TODO: Support more than 1 output
  static const auto descriptor = at::cuda::jit::make_kernel_descriptor<
      result_type, f_inputs_type, ExtraArgs...>(
      name, f, /*nInputs=*/arity, /*nOutputs=*/1);

  auto &device_cache = per_device_cache[iter.device().index()];
  auto extra_arg_ptrs = tuple_to_array(extra_args);

  jitted_gpu_kernel_generic<arity>(
      compile_mutex,
      device_cache,
      descriptor,
      scalar_pos,
      extra_arg_ptrs,
      iter,
      dynamic_casting,
      &scalar_val);
}
}} // at::native
#endif // AT_USE_JITERATOR()