forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ReflectionPad.cu
453 lines (367 loc) · 13.6 KB
/
ReflectionPad.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
// keeping THC headers for atomicAdd
#include <THC/THCAtomics.cuh>
#include <thrust/pair.h>
namespace at {
namespace native {
namespace {
using at::cuda::detail::canUse32BitIndexMath;
// Maps one position of a 1-D reflection-padded row back to its source element.
//
// Expects a 3-D grid of 1-D blocks: blockIdx.y and blockIdx.z together pick
// the row (plane) being processed, so that row starts at row * input_w in the
// input buffer and at row * output_w in the output buffer (both buffers are
// assumed to be laid out contiguously).
//
// Parameters:
//   input_w  - width of one unpadded input row
//   output_w - width of one padded output row
//   output_x - column inside the output row
//   pad_l    - left padding; a negative value crops instead of padding
// Returns (flat input index, flat output index).
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping1d(
    int64_t input_w, int64_t output_w,
    int64_t output_x,
    int64_t pad_l) {
  const int64_t row = blockIdx.y + blockIdx.z * gridDim.y;

  // Shifts that compensate for a negative (cropping) or positive left pad.
  const int64_t src_begin = ::max(int64_t(0), -pad_l);
  const int64_t dst_begin = ::max(int64_t(0), pad_l);

  // Branch-free reflection. These are the distances from output_x to the
  // output columns holding input columns 0 and input_w - 1. Combining them
  // folds padded positions back into the input row, assuming |pad| < input_w
  // as the host templates check.
  const int64_t dist_to_first = ::abs(output_x - pad_l);
  const int64_t dist_to_last = ::abs(output_x - (input_w + pad_l - 1));
  const int64_t src_x = dist_to_first - dist_to_last - output_x
      + 2 * pad_l + input_w - 1 - dst_begin + src_begin;

  return thrust::pair<int64_t, int64_t>(
      row * input_w + src_x, row * output_w + output_x);
}
// Maps one position of a 2-D reflection-padded plane back to its source element.
//
// Expects a 3-D grid of 1-D blocks: blockIdx.y and blockIdx.z together pick
// the plane, and output_xy is a flattened row-major position inside the padded
// output plane. Both buffers are assumed to be laid out contiguously.
//
// Parameters:
//   input_dim_x / input_dim_y   - width / height of the unpadded input plane
//   output_dim_x / output_dim_y - width / height of the padded output plane
//   pad_l / pad_t               - left / top padding (negative values crop)
//   output_xy                   - flattened index inside the output plane
// Returns (flat input index, flat output index).
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
    int64_t input_dim_x, int64_t input_dim_y,
    int64_t output_dim_x, int64_t output_dim_y,
    int64_t pad_l, int64_t pad_t,
    int64_t output_xy) {
  const int64_t plane = blockIdx.y + blockIdx.z * gridDim.y;
  const int64_t in_plane_base = plane * input_dim_x * input_dim_y;
  const int64_t out_plane_base = plane * output_dim_x * output_dim_y;

  // Split the flattened output position into column / row.
  const int64_t out_col = output_xy % output_dim_x;
  const int64_t out_row = output_xy / output_dim_x;

  // Horizontal axis: branch-free reflection of the column.
  const int64_t col_src_begin = ::max(int64_t(0), -pad_l);
  const int64_t col_dst_begin = ::max(int64_t(0), pad_l);
  const int64_t in_col = ::abs(out_col - pad_l)
      - ::abs(out_col - (input_dim_x + pad_l - 1))
      - out_col
      + 2 * pad_l + input_dim_x - 1
      - col_dst_begin + col_src_begin;

  // Vertical axis: branch-free reflection of the row.
  const int64_t row_src_begin = ::max(int64_t(0), -pad_t);
  const int64_t row_dst_begin = ::max(int64_t(0), pad_t);
  const int64_t in_row = ::abs(out_row - pad_t)
      - ::abs(out_row - (input_dim_y + pad_t - 1))
      - out_row
      + 2 * pad_t + input_dim_y - 1
      - row_dst_begin + row_src_begin;

  return thrust::pair<int64_t, int64_t>(
      in_plane_base + in_row * input_dim_x + in_col,
      out_plane_base + out_row * output_dim_x + out_col);
}
// Forward kernel for 1-D reflection padding.
//
// Launch layout, as set up by reflection_pad1d_out_template:
//   - 1-D blocks whose x dimension covers the padded width;
//   - blockIdx.y selects the plane and blockIdx.z the batch entry.
// Each thread copies one output element from its reflected source element.
// `input` is only ever read, so it is taken as a pointer-to-const.
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  // The last block along x may be only partially filled.
  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    output[index_pair.second] = input[index_pair.first];
  }
}
// Backward kernel for 1-D reflection padding.
//
// Launch layout, as set up by reflection_pad1d_backward_out_template:
//   - 1-D blocks covering the padded width;
//   - blockIdx.y selects the plane and blockIdx.z the batch entry.
// Every padded output position adds its gradient into the input element it
// was reflected from. Several output positions can map to the same input
// element, hence the atomicAdd.
// grad_input is expected to start zero-filled (the callers in this file zero
// it first). grad_output is only read, so it is a pointer-to-const.
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  // The last block along x may be only partially filled.
  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    atomicAdd(
      &grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}
// Forward kernel for 2-D reflection padding.
//
// Launch layout, as set up by reflection_pad2d_out_template:
//   - 1-D blocks whose x dimension walks the flattened padded plane;
//   - blockIdx.y selects the plane and blockIdx.z the batch entry.
// Each thread copies one output element from its reflected source element.
// `input` is only ever read, so it is taken as a pointer-to-const.
template<typename scalar_t>
__global__ void reflection_pad2d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  // The last block along x may be only partially filled.
  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
        input_dim_x, input_dim_y,
        output_dim_x, output_dim_y,
        pad_l, pad_t,
        output_xy);
    output[index_pair.second] = input[index_pair.first];
  }
}
// Backward kernel for 2-D reflection padding.
//
// Launch layout, as set up by reflection_pad2d_backward_out_template:
//   - 1-D blocks walking the flattened padded plane;
//   - blockIdx.y selects the plane and blockIdx.z the batch entry.
// Every padded output position adds its gradient into the input element it
// was reflected from. Several output positions can hit the same input
// element, hence the atomicAdd.
// grad_input is expected to start zero-filled (the callers in this file zero
// it first). grad_output is only read, so it is a pointer-to-const.
template <typename scalar_t>
__global__ void reflection_pad2d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  // The last block along x may be only partially filled.
  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
        input_dim_x, input_dim_y,
        output_dim_x, output_dim_y,
        pad_l, pad_t,
        output_xy);
    atomicAdd(&grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}
// Host-side launcher for the 1-D reflection-padding forward pass.
//
// Accepts a non-empty 2-D (C, W) or 3-D (N, C, W) input and a padding list
// ordered (left, right). It validates the padding, resizes `output` to the
// padded shape, and fills it with reflection_pad1d_out_kernel on the current
// CUDA stream.
void reflection_pad1d_out_template(
    Tensor &output, const Tensor &input_, IntArrayRef padding) {
  AT_CHECK(canUse32BitIndexMath(input_),
    "input tensor must fit into 32-bit index math");
  AT_CHECK(input_.numel() > 0 &&
    (input_.ndimension() == 2 || input_.ndimension() == 3), "non-empty 2D "
    "or 3D (batch mode) tensor expected for input, but got: ", input_);

  // A 3-D input carries a leading batch dimension that shifts the others.
  const bool batched = input_.ndimension() == 3;
  const int64_t plane_dim = batched ? 1 : 0;
  const int64_t width_dim = batched ? 2 : 1;
  const int64_t batch_count = batched ? input_.size(0) : 1;

  const int64_t left = padding[0];
  const int64_t right = padding[1];
  const int64_t planes = input_.size(plane_dim);
  const int64_t in_width = input_.size(width_dim);
  const int64_t out_width = in_width + left + right;

  AT_CHECK(left < in_width && right < in_width, "Padding size should be less "
    "than the corresponding input dimension, but got: padding (", left, ", ",
    right, ") at dimension ", width_dim, " of input ", input_);
  AT_CHECK(out_width >= 1,
    "input (W: ", in_width, ")is too small. Calculated output W: ", out_width);

  if (batched) {
    output.resize_({batch_count, planes, out_width});
  } else {
    output.resize_({planes, out_width});
  }

  // One thread per output column, at most 256 per block. The grid's y and z
  // axes enumerate planes and batch entries.
  dim3 block(out_width < 256 ? out_width : 256);
  dim3 grid((int) ::ceil(out_width / 256.0), planes, batch_count);

  Tensor input = input_.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(), "reflection_pad1d_out_template", [&] {
      reflection_pad1d_out_kernel<<<
        grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          input.data<scalar_t>(), output.data<scalar_t>(),
          in_width, left, right);
    }
  );
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side launcher for the 1-D reflection-padding backward pass.
//
// Accumulates grad_output into grad_input. grad_output has the padded shape
// of `input`, i.e. (C, W_out) or (N, C, W_out). grad_input must already have
// `input`'s shape and be zero-filled (the callers in this file do that),
// because the kernel only adds into it. Padding is ordered (left, right).
void reflection_pad1d_backward_out_template(
    Tensor & grad_input, const Tensor & grad_output_,
    const Tensor & input, IntArrayRef padding) {
  AT_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  AT_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  // A 3-D input carries a leading batch dimension that shifts the others.
  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  Tensor grad_output = grad_output_.contiguous();

  // The kernel indexes grad_output using input's batch/plane geometry, so a
  // mismatch in any dimension (not only the width) would read out of bounds.
  AT_CHECK(grad_output.ndimension() == input.ndimension(),
    "gradOutput dimensionality unexpected. Expected: ", input.ndimension(),
    ", Got: ", grad_output.ndimension());
  AT_CHECK(output_w == grad_output.size(dim_w),
    "gradOutput width unexpected. Expected: ", output_w, ", Got: ",
    grad_output.size(dim_w));
  AT_CHECK(nplane == grad_output.size(dim_plane),
    "gradOutput planes unexpected. Expected: ", nplane, ", Got: ",
    grad_output.size(dim_plane));
  if (input.ndimension() == 3) {
    AT_CHECK(nbatch == grad_output.size(0),
      "gradOutput batch size unexpected. Expected: ", nbatch, ", Got: ",
      grad_output.size(0));
  }

  // One thread per output column, at most 256 per block. The grid's y and z
  // axes enumerate planes and batch entries.
  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_input.scalar_type(), "reflection_pad1d_backward_out_template", [&] {
      reflection_pad1d_backward_out_kernel<<<
        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
          grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
          input_w, pad_l, pad_r);
    }
  );
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side launcher for the 2-D reflection-padding forward pass.
//
// Accepts a non-empty 3-D (C, H, W) or 4-D (N, C, H, W) input and a padding
// list ordered (left, right, top, bottom). It validates the padding, resizes
// `output` to the padded shape, and fills it with reflection_pad2d_out_kernel
// on the current CUDA stream.
void reflection_pad2d_out_template(
    Tensor &output, const Tensor &input_, IntArrayRef padding) {
  AT_CHECK(canUse32BitIndexMath(input_),
    "input tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  AT_CHECK(input_.numel() > 0 &&
    (input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
    "4D (batch mode) tensor expected for input, but got: ", input_);

  // A 4-D input carries a leading batch dimension that shifts the others.
  if (input_.ndimension() == 4) {
    nbatch = input_.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input_.size(plane_dim);
  int input_h = input_.size(dim_h);
  int input_w = input_.size(dim_w);

  AT_CHECK(pad_l < input_w && pad_r < input_w,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
    " of input ", input_.sizes());

  AT_CHECK(pad_t < input_h && pad_b < input_h,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
    " of input ", input_.sizes());

  int output_h = input_h + pad_t + pad_b;
  int output_w  = input_w + pad_l + pad_r;

  // Both padded extents must be positive. A zero or negative extent in
  // either axis would otherwise reach resize_ or yield a zero-thread launch.
  AT_CHECK(output_w >= 1 && output_h >= 1,
    "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
    "output H: ", output_h, " W: ", output_w);

  if (input_.ndimension() == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    output.resize_({nbatch, nplane, output_h, output_w});
  }

  Tensor input = input_.contiguous();

  // One thread per element of the flattened padded plane, at most 256 per
  // block. The grid's y and z axes enumerate planes and batch entries.
  int output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
  dim3 grid_size(
    (int) std::ceil(output_plane_size/256.0), nplane, nbatch);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(), "reflection_pad2d_out_template", [&] {
      reflection_pad2d_out_kernel<<<
        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
          input.data<scalar_t>(), output.data<scalar_t>(),
          input_w, input_h,
          pad_t, pad_b, pad_l, pad_r);
    }
  );
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side launcher for the 2-D reflection-padding backward pass.
//
// Accumulates grad_output into grad_input. grad_output has the padded shape
// of `input`, i.e. (C, H_out, W_out) or (N, C, H_out, W_out). grad_input must
// already have `input`'s shape and be zero-filled (the callers in this file
// do that), because the kernel only adds into it. Padding is ordered
// (left, right, top, bottom).
void reflection_pad2d_backward_out_template(
    Tensor &grad_input, const Tensor &grad_output_,
    const Tensor &input, IntArrayRef padding) {
  AT_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  AT_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  // A 4-D input carries a leading batch dimension that shifts the others.
  if (input.ndimension() == 4) {
    nbatch = input.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input.size(plane_dim);
  int input_h = input.size(dim_h);
  int input_w = input.size(dim_w);

  int output_h = input_h + pad_t + pad_b;
  int output_w  = input_w + pad_l + pad_r;

  // The kernel indexes grad_output using input's batch/plane geometry, so a
  // mismatch in any dimension (not only H/W) would read out of bounds.
  AT_CHECK(grad_output_.ndimension() == input.ndimension(),
    "grad_output dimensionality unexpected. Expected: ", input.ndimension(),
    ", Got: ", grad_output_.ndimension());
  AT_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
    "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
  AT_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
    "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
  AT_CHECK(nplane == grad_output_.size(plane_dim), "grad_output planes "
    "unexpected. Expected: ", nplane, ", Got: ", grad_output_.size(plane_dim));
  if (input.ndimension() == 4) {
    AT_CHECK(nbatch == grad_output_.size(0), "grad_output batch size "
      "unexpected. Expected: ", nbatch, ", Got: ", grad_output_.size(0));
  }

  Tensor grad_output = grad_output_.contiguous();

  // One thread per element of the flattened padded plane, at most 256 per
  // block. The grid's y and z axes enumerate planes and batch entries.
  int output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
  dim3 grid_size(
    (int) std::ceil(output_plane_size/256.0), nplane, nbatch);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {
      reflection_pad2d_backward_out_kernel<<<
        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
          grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
          input_w, input_h,
          pad_t, pad_b, pad_l, pad_r);
    }
  );
  AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
// out= entry point for 1-D reflection padding on CUDA.
// `output` is resized to the padded shape and filled in place by
// reflection_pad1d_out_template; the same tensor is handed back.
Tensor& reflection_pad1d_out_cuda(
    Tensor& output, const Tensor& input, IntArrayRef padding) {
  reflection_pad1d_out_template(output, input, padding);
  return output;
}
// Functional entry point for 1-D reflection padding on CUDA.
// Starts from an empty tensor with input's options; the template resizes it
// to the padded shape before filling it.
Tensor reflection_pad1d_cuda(const Tensor& input, IntArrayRef padding) {
  Tensor result = at::empty({0}, input.options());
  reflection_pad1d_out_template(result, input, padding);
  return result;
}
// out= entry point for the 1-D reflection-padding backward pass on CUDA.
// The kernel accumulates with atomicAdd, so grad_input is first reshaped to
// input's size and cleared to zero.
Tensor& reflection_pad1d_backward_out_cuda(
    Tensor& grad_input, const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  grad_input.resize_as_(input).zero_();
  reflection_pad1d_backward_out_template(
      grad_input, grad_output, input, padding);
  return grad_input;
}
// Functional entry point for the 1-D reflection-padding backward pass on CUDA.
// Returns a fresh gradient shaped like `input`. It starts at zero because the
// kernel only accumulates into it.
Tensor reflection_pad1d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  Tensor result = at::zeros_like(input);
  reflection_pad1d_backward_out_template(
      result, grad_output, input, padding);
  return result;
}
// out= entry point for 2-D reflection padding on CUDA.
// `output` is resized to the padded shape and filled in place by
// reflection_pad2d_out_template; the same tensor is handed back.
Tensor& reflection_pad2d_out_cuda(
    Tensor& output, const Tensor& input, IntArrayRef padding) {
  reflection_pad2d_out_template(output, input, padding);
  return output;
}
// Functional entry point for 2-D reflection padding on CUDA.
// Starts from an empty tensor with input's options; the template resizes it
// to the padded shape before filling it.
Tensor reflection_pad2d_cuda(const Tensor& input, IntArrayRef padding) {
  Tensor result = at::empty({0}, input.options());
  reflection_pad2d_out_template(result, input, padding);
  return result;
}
// out= entry point for the 2-D reflection-padding backward pass on CUDA.
// The kernel accumulates with atomicAdd, so grad_input is first reshaped to
// input's size and cleared to zero.
Tensor& reflection_pad2d_backward_out_cuda(
    Tensor& grad_input, const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  grad_input.resize_as_(input).zero_();
  reflection_pad2d_backward_out_template(
      grad_input, grad_output, input, padding);
  return grad_input;
}
// Functional entry point for the 2-D reflection-padding backward pass on CUDA.
// Returns a fresh gradient shaped like `input`. It starts at zero because the
// kernel only accumulates into it.
Tensor reflection_pad2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  Tensor result = at::zeros_like(input);
  reflection_pad2d_backward_out_template(
      result, grad_output, input, padding);
  return result;
}
} // namespace native
} // namespace at