forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LossMultiMargin.cpp
340 lines (309 loc) · 8.97 KB
/
LossMultiMargin.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
namespace at {
namespace native {
namespace {
// Loss of a single sample: for every class other than `target_idx`, take the
// hinge z = margin - input[target_idx] + input[class]; each positive z
// contributes z (p == 1) or z*z (otherwise), multiplied by
// weight[target_idx] when a weight array is supplied. The accumulated total
// is divided by the number of classes `dim`.
template <typename scalar_t>
inline scalar_t multi_margin_inner_sum_cpu(
    const scalar_t* input_data,
    const scalar_t* weight_data,
    const int p,
    const scalar_t margin,
    const int64_t dim,
    const int64_t target_idx) {
  // margin - x[target] is the same for every class, so compute it once.
  const scalar_t threshold = margin - input_data[target_idx];
  scalar_t total = 0;
  for (int64_t cls = 0; cls < dim; ++cls) {
    if (cls == target_idx) {
      continue;
    }
    const scalar_t violation = threshold + input_data[cls];
    if (!(violation > 0)) {
      continue; // inactive hinge (also skips NaN)
    }
    scalar_t term = (p == 1) ? violation : violation * violation;
    if (weight_data) {
      term *= weight_data[target_idx];
    }
    total += term;
  }
  return total / dim;
}
// Reads the class index stored at target_data[index] and returns it after
// verifying it lies in [0, dim); fails via TORCH_CHECK otherwise.
inline int64_t target_index_checked(
    const int64_t* target_data,
    const int64_t index,
    const int64_t dim) {
  const int64_t target = target_data[index];
  const bool in_range = target >= 0 && target < dim;
  TORCH_CHECK(in_range, "target out of range");
  return target;
}
// Forward kernel. `input_data` holds `nframe` rows of `dim` scores laid out
// back to back; `target_data` holds one class index per row. With
// Reduction::None and a 1-d `output`, each row's loss is written to its own
// output slot. Otherwise the per-row losses are summed in the accumulation
// type (averaged over `nframe` for Reduction::Mean) and stored in the single
// output element.
template <typename scalar_t>
static inline void multi_margin_loss_cpu_kernel(
    Tensor& output,
    scalar_t* input_data,
    int64_t* target_data,
    const int p,
    scalar_t margin,
    scalar_t* weight_data,
    const int64_t nframe,
    const int64_t dim,
    const int64_t reduction) {
  using accscalar_t = at::acc_type<scalar_t, false>;

  // A 0-dim output cannot be viewed through a 1-d TensorAccessor, so that
  // case goes through the single-element path below instead.
  const bool per_sample = reduction == Reduction::None && output.dim() > 0;
  if (per_sample) {
    auto out = output.accessor<scalar_t, 1>();
    const scalar_t* row = input_data;
    for (int64_t frame = 0; frame < nframe; ++frame, row += dim) {
      const int64_t target = target_index_checked(target_data, frame, dim);
      out[frame] = multi_margin_inner_sum_cpu<scalar_t>(
          row, weight_data, p, margin, dim, target);
    }
    return;
  }

  scalar_t* out_ptr = output.data_ptr<scalar_t>();
  accscalar_t total = 0;
  const scalar_t* row = input_data;
  for (int64_t frame = 0; frame < nframe; ++frame, row += dim) {
    const int64_t target = target_index_checked(target_data, frame, dim);
    total += multi_margin_inner_sum_cpu<scalar_t>(
        row, weight_data, p, margin, dim, target);
  }
  if (reduction == Reduction::Mean) {
    total /= nframe;
  }
  *out_ptr = total;
}
// Validates arguments, resizes `output`, and dispatches the forward kernel.
//
// `input` is a 0-d/1-d tensor (one sample of `dim` scores) or a 2-d
// (nframe, dim) matrix; `target` holds one int64 class index per sample.
// When defined, `weight` supplies one rescaling factor per class and is
// indexed by the target class, so it must be 1-d with exactly `dim`
// elements. It is made contiguous because the kernel walks it through a
// flat pointer; this matches what the backward pass does.
void multi_margin_loss_out_cpu_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& target,
    int p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  const auto ndims = input.dim();
  TORCH_CHECK(
      input.numel() > 0 && ndims <= 2,
      "non-empty vector or matrix expected, got size: ",
      input.sizes());
  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
  int64_t nframe, dim;
  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
  } else {
    nframe = input.size(0);
    dim = input.size(1);
  }
  TORCH_CHECK(
      target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
      "inconsistent target size, got: ",
      target.sizes());
  // The kernel reads weight_data[target_idx] for target_idx in [0, dim);
  // a weight with fewer elements would be read out of bounds.
  TORCH_CHECK(
      !weight.defined() || (weight.dim() <= 1 && weight.numel() == dim),
      "inconsistent weight size, expected ",
      dim,
      " but got ",
      weight.sizes());
  // produce a scalar output for 1d input
  if (reduction == Reduction::None && target.dim() > 0) {
    output.resize_({nframe});
  } else {
    output.resize_({});
  }
  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  // A strided weight would be misread through a raw pointer; only touch it
  // when one was actually supplied.
  Tensor weight_contiguous;
  if (weight.defined()) {
    weight_contiguous = weight.contiguous();
  }
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
        auto input_data = input_contiguous.data_ptr<scalar_t>();
        auto target_data = target_contiguous.data_ptr<int64_t>();
        auto weight_data = weight_contiguous.defined()
            ? weight_contiguous.data_ptr<scalar_t>()
            : nullptr;
        multi_margin_loss_cpu_kernel<scalar_t>(
            output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            weight_data,
            nframe,
            dim,
            reduction);
      });
}
template <typename scalar_t>
static void multi_margin_loss_backward_cpu_kernel(
scalar_t* grad_input_data,
const Tensor& grad_output,
scalar_t* input_data,
int64_t* target_data,
int p,
scalar_t margin,
scalar_t g,
scalar_t* weight_data,
int64_t nframe,
int64_t dim,
int64_t reduction) {
scalar_t* grad_input_row_data = grad_input_data;
for (int64_t t = 0; t < nframe; t++) {
int64_t target_idx = target_index_checked(target_data, t, dim);
scalar_t input_target = input_data[target_idx];
scalar_t grad_input_target = 0;
for (int64_t d = 0; d < dim; d++) {
scalar_t z = margin - input_target + input_data[d];
if (d == target_idx) {
continue;
}
if (z > 0) {
scalar_t h = (p == 1) ? g : 2 * g * z;
if (weight_data != nullptr) {
h *= weight_data[target_idx];
}
grad_input_target -= h;
grad_input_row_data[d] = h;
} else {
grad_input_row_data[d] = 0;
}
}
grad_input_row_data[target_idx] = grad_input_target;
input_data += dim;
grad_input_row_data += dim;
}
if (reduction != Reduction::None || grad_output.dim() == 0) {
assert(
reduction != Reduction::None || grad_output.dim() > 0 ||
nframe == 1); // check 1d scalar fallback-case
const auto d = *grad_output.data_ptr<scalar_t>();
for (int64_t t = 0; t < nframe * dim; t++) {
grad_input_data[t] *= d;
}
} else {
auto grad_output_acc = grad_output.accessor<scalar_t, 1>();
for (int64_t t = 0; t < nframe; t++) {
for (int64_t d = 0; d < dim; d++) {
grad_input_data[t * dim + d] *= grad_output_acc[t];
}
}
}
}
// Validates arguments, resizes `grad_input` to match `input`, and dispatches
// the backward kernel.
//
// Shape rules match the forward pass. `input` is 0-d/1-d (one sample) or 2-d
// (nframe, dim), and `target` holds one class index per sample. When
// defined, `weight` must be 1-d with `dim` elements, because the kernel
// indexes it by target class.
//
// `g` folds in the forward normalization:
//   - 1/dim per sample for Sum and None;
//   - additionally divided by nframe for Mean.
void multi_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  const auto ndims = input.dim();
  TORCH_CHECK(
      input.numel() > 0 && ndims <= 2,
      "non-empty vector or matrix expected, got size: ",
      input.sizes());
  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
  int64_t nframe, dim;
  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
  } else {
    nframe = input.size(0);
    dim = input.size(1);
  }
  TORCH_CHECK(
      target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
      "inconsistent target size, got: ",
      target.sizes());
  // The kernel reads weight_data[target_idx] for target_idx in [0, dim);
  // a weight with fewer elements would be read out of bounds.
  TORCH_CHECK(
      !weight.defined() || (weight.dim() <= 1 && weight.numel() == dim),
      "inconsistent weight size, expected ",
      dim,
      " but got ",
      weight.sizes());
  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  // Only materialize a dense weight when one was actually supplied.
  Tensor weight_contiguous;
  if (weight.defined()) {
    weight_contiguous = weight.contiguous();
  }
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_backward_cpu_kernel", [&] {
        auto grad_input_data = grad_input.data_ptr<scalar_t>();
        auto input_data = input_contiguous.data_ptr<scalar_t>();
        auto target_data = target_contiguous.data_ptr<int64_t>();
        auto weight_data = weight_contiguous.defined()
            ? weight_contiguous.data_ptr<scalar_t>()
            : nullptr;
        scalar_t g = reduction == Reduction::Mean
            ? static_cast<scalar_t>(1. / (nframe * dim))
            : static_cast<scalar_t>(1. / dim);
        multi_margin_loss_backward_cpu_kernel<scalar_t>(
            grad_input_data,
            grad_output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            g,
            weight_data,
            nframe,
            dim,
            reduction);
      });
}
} // namespace
// Functional forward entry point: allocates a fresh output (resized by the
// template to its final shape) and computes the multi-margin loss into it.
Tensor multi_margin_loss_cpu(
    const Tensor& input,
    const Tensor& target,
    Scalar p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  Tensor result = at::empty({0}, input.options());
  const int norm = p.toInt();
  multi_margin_loss_out_cpu_template(
      result, input, target, norm, margin, weight, reduction);
  return result;
}
// Out-variant forward entry point: writes the loss into the caller-provided
// `output` (resized as needed) and returns that same tensor.
Tensor& multi_margin_loss_cpu_out(
    Tensor& output,
    const Tensor& input,
    const Tensor& target,
    Scalar p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  const int norm = p.toInt();
  multi_margin_loss_out_cpu_template(
      output, input, target, norm, margin, weight, reduction);
  return output;
}
// Functional backward entry point: allocates a fresh gradient tensor (the
// template resizes it to `input`'s shape) and fills it.
Tensor multi_margin_loss_cpu_backward(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    Scalar p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  Tensor grad = at::empty({0}, input.options());
  const int norm = p.toInt();
  multi_margin_loss_backward_out_cpu_template(
      grad, grad_output, input, target, norm, margin, weight, reduction);
  return grad;
}
// Out-variant backward entry point: writes the input gradient into the
// caller-provided `grad_input` and returns that same tensor.
Tensor& multi_margin_loss_cpu_backward_out(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    Scalar p,
    Scalar margin,
    const Tensor& weight,
    int64_t reduction) {
  const int norm = p.toInt();
  multi_margin_loss_backward_out_cpu_template(
      grad_input, grad_output, input, target, norm, margin, weight, reduction);
  return grad_input;
}
} // namespace native
} // namespace at