forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SparseBinaryOpIntersectionCommon.h
483 lines (427 loc) · 19.6 KB
/
SparseBinaryOpIntersectionCommon.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
#pragma once
#include <ATen/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/Macros.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/SparseTensorUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/result_type.h>
#endif
// Prefix used in the TORCH_CHECK / TORCH_INTERNAL_ASSERT messages of this
// header. The "_cuda" spelling is chosen when GPUCC is defined, the "_cpu"
// spelling otherwise.
#if defined(GPUCC)
#define NAME "sparse_binary_op_intersection_cuda"
#else
#define NAME "sparse_binary_op_intersection_cpu"
#endif
namespace at::native {
namespace {
using at::sparse::get_sparse_impl;
// Binary search over the sorted range [first, last).
//
// is_lower = true : returns the first position p with *p >= value
//                   (a lower bound), or `last` if there is none.
// is_lower = false: returns the first position p with *p >  value
//                   (an upper bound), or `last` if there is none.
//
// ForwardIt: only legacy random access iterators (e.g. raw pointers) are
// supported. The range length is computed as `last - first` because
// std::distance(first, last) compiles but produces wrong results on CUDA.
template<class ForwardIt, class T, bool is_lower = true>
static FUNCAPI INLINE
ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
  using diff_t = typename std::iterator_traits<ForwardIt>::difference_type;
  // Invariant: the answer lies in [first, first + remaining].
  diff_t remaining = last - first;
  while (remaining > 0) {
    const diff_t half = remaining / 2;
    // Probe the midpoint of the window. std::advance(mid, half) would also
    // work on CUDA, but plain += is used for consistency with the above.
    ForwardIt RESTRICT mid = first;
    mid += half;
    // Lower bound: *mid and everything before it is too small if *mid < value.
    // Upper bound: the same holds if *mid <= value (spelled value >= *mid).
    const bool left_part_excluded = is_lower ? *mid < value : value >= *mid;
    if (left_part_excluded) {
      // Skip *mid as well: the window becomes [mid + 1, first + remaining),
      // whose end is unchanged and whose length shrinks by half + 1.
      first = ++mid;
      remaining -= half + 1;
    } else {
      // *mid may itself be the bound, so keep [first, mid] as candidates.
      remaining = half;
    }
  }
  return first;
}
// Uniform entry point over a kernel class template:
// KernelLauncher<kernel_t>::launch(iter, f) instantiates kernel_t with the
// deduced type of `f` and forwards to its static launch(iter, f).
template <template <typename func_t> class kernel_t>
struct KernelLauncher {
  template <typename func_t>
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    using concrete_kernel_t = kernel_t<func_t>;
    concrete_kernel_t::launch(iter, f);
  }
};
// Builds the TensorIterator consumed by the value-selection intersection
// kernels.
//
// Operands, in order:
//   output: a freshly allocated values tensor. Its shape is
//           infer_size(lhs_values.sizes(), rhs_values.sizes()[1:]) with the
//           leading dim replaced by lhs_select_idx.numel().
//   inputs: lhs_values, lhs_select_idx, rhs_values, rhs_select_idx,
//           intersection_counts.
// Value tensors are re-viewed with stride 0 along dim 0, so every step of the
// iterator starts from the tensor's base pointer (presumably the kernel
// offsets into it via the select indices — TODO confirm in the kernels).
// Index/count tensors are re-viewed with stride 1 along dim 0 and size 1 /
// stride 0 along the remaining output dims.
// Memory-overlap checks, same-dtype checks and output resizing are disabled.
TensorIterator make_value_selection_intersection_iter(
    const Tensor& lhs_values,
    const Tensor& lhs_select_idx,
    const Tensor& rhs_values,
    const Tensor& rhs_select_idx,
    const Tensor& intersection_counts) {
  const auto n_selected = lhs_select_idx.numel();

  // Keep lhs' nnz dim, drop rhs' nnz dim so that the dense dims broadcast,
  // then make the leading dim as long as the selection.
  auto out_sizes = infer_size(lhs_values.sizes(), rhs_values.sizes().slice(1));
  out_sizes[0] = n_selected;
  auto res_values = at::empty(out_sizes, lhs_values.options());
  const auto out_dim = res_values.dim();

  const auto as_index_operand = [out_dim](const Tensor& idx) -> Tensor {
    std::vector<int64_t> sizes(out_dim, 1);
    std::vector<int64_t> strides(out_dim, 0);
    sizes[0] = idx.numel();
    strides[0] = 1;
    return idx.as_strided(sizes, strides);
  };

  const auto as_values_operand = [n_selected](const Tensor& values) -> Tensor {
    at::DimVector sizes(values.sizes());
    at::DimVector strides(values.strides());
    sizes[0] = n_selected;
    strides[0] = 0;
    return values.as_strided(sizes, strides);
  };

  return TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_owned_output(res_values)
    .add_owned_input(as_values_operand(lhs_values))
    .add_owned_input(as_index_operand(lhs_select_idx))
    .add_owned_input(as_values_operand(rhs_values))
    .add_owned_input(as_index_operand(rhs_select_idx))
    .add_owned_input(as_index_operand(intersection_counts))
    .build();
}
// Computes the intersection of two sparse COO tensors x_ and y_ into `res`.
//
// One operand becomes `source` and the other `probably_coalesced`.
// The sparse indices of both are mapped to int64 hash values: a linear offset
// into a contiguous tensor of shape broadcasted_shape[:sparse_dim].
// Each source hash is binary-searched in the sorted hashes of
// probably_coalesced to find how many entries match and where they start.
// The result takes source's indices. Its values are produced by
// value_selection_intersection_kernel_t::apply from the matched values of
// both operands.
//
// Template parameters:
//   kernel_t - kernel template wrapped by KernelLauncher; runs the
//     elementwise hashing / intersection lambdas over a TensorIterator.
//   value_selection_intersection_kernel_t - provides a static apply(...)
//     that produces the result values.
//   index_t - element type read from the `_indices()` tensors.
//   max_static_len - forwarded to at::sparse::TensorGeometryHolder, which
//     stores the hash coefficients (presumably its static capacity, with 0
//     meaning dynamic storage — TODO confirm in SparseTensorUtils.h).
//
// Parameters:
//   res - output sparse tensor; resized and overwritten in place.
//   x_, y_ - sparse COO inputs.
//   broadcasted_shape - shape of the result; its leading sparse dims also
//     define the hash coefficients.
//   x_hash_opt_, y_hash_opt_ - optional precomputed index hashes of x_ / y_.
//     They are dropped whenever the corresponding tensor is coalesced here,
//     because coalescing invalidates them.
//   accumulate_matches - forwarded as-is to
//     value_selection_intersection_kernel_t::apply.
//   distributive_with_sum - when false, both inputs are coalesced first.
//
// Errors (TORCH_CHECK): result_type(x_, y_) cannot be cast to res' dtype.
template <
  template <typename func_t> class kernel_t,
  typename value_selection_intersection_kernel_t,
  typename index_t = int64_t,
  int64_t max_static_len = 0>
void _sparse_binary_op_intersection_kernel_impl(
    Tensor& res,
    const Tensor& x_,
    const Tensor& y_,
    const std::vector<int64_t>& broadcasted_shape,
    const std::optional<Tensor>& x_hash_opt_ = c10::nullopt,
    const std::optional<Tensor>& y_hash_opt_ = c10::nullopt,
    const bool accumulate_matches = true,
    const bool distributive_with_sum = true
) {
  // The common dtype check is relevant when the op is done in-place.
  // binary_op_t produces new values, and it could be that
  // new_values.dtype != res.dtype. In that case we should error out
  // as soon as possible to avoid redundant kernel runs.
  const auto common_dtype = at::result_type(x_, y_);
  TORCH_CHECK(canCast(common_dtype, res.scalar_type()),
      "Can't convert result type ", common_dtype,
      " to output ", res.scalar_type());

  using KernelLauncher = KernelLauncher<kernel_t>;
  using OptTensor = std::optional<Tensor>;

  // If the op and sum are not distributive, coalesce is required.
  const auto coalesce_if_not_distributive = [distributive_with_sum](const Tensor& t, const OptTensor& t_hash_opt) -> auto {
    // No need to coalesce in such a case.
    if (distributive_with_sum) {
      return std::make_tuple(t, t_hash_opt);
    } else {
      // Otherwise coalesce and force hash recompute.
      return std::make_tuple(t.coalesce(), static_cast<OptTensor>(c10::nullopt));
    }
  };

  Tensor x, y;
  OptTensor x_hash_opt, y_hash_opt;
  std::tie(x, x_hash_opt) = coalesce_if_not_distributive(x_, x_hash_opt_);
  std::tie(y, y_hash_opt) = coalesce_if_not_distributive(y_, y_hash_opt_);

  // Given sparse tensors x and y, we decide which one is `source` and which
  // one is `probably_coalesced`. The indices of both are hashed, and then the
  // hash values of source's indices are binary-searched in the hash values of
  // probably_coalesced's indices.
  // If probably_coalesced is coalesced, then by the property of the hashing
  // method (see below) its hash values are already sorted, so any explicit
  // sorting is avoided.
  Tensor probably_coalesced, source;
  OptTensor probably_coalesced_indices_hash_opt, source_indices_hash_opt;
  std::tie(probably_coalesced, probably_coalesced_indices_hash_opt, source, source_indices_hash_opt) = [&]() -> auto {
    // Case 1: exactly one of x and y is coalesced; search into that one.
    if ((x.is_coalesced() ^ y.is_coalesced())) {
      return x.is_coalesced()
        ? std::make_tuple(x, x_hash_opt, y, y_hash_opt)
        : std::make_tuple(y, y_hash_opt, x, x_hash_opt);
    }
    // Case 2: x and y are both coalesced or both non-coalesced.
    // Either way, searching into the larger tensor is faster.
    else {
      Tensor larger, smaller;
      OptTensor larger_hash_opt, smaller_hash_opt;
      std::tie(larger, larger_hash_opt, smaller, smaller_hash_opt) = [&]() -> auto {
        return x._nnz() >= y._nnz()
          ? std::make_tuple(x, x_hash_opt, y, y_hash_opt)
          : std::make_tuple(y, y_hash_opt, x, x_hash_opt);
      }();

      // If, under a uniform distribution, many elements of `larger` are likely
      // to be hit, it is best to coalesce it for better performance.
      const auto larger_sizes = larger.sizes();
      // prod(larger.shape[:sparse_dim]). The init value must be int64_t:
      // std::accumulate accumulates in the type of its init argument, and a
      // plain `1` (int) truncates the product for large sparse shapes, e.g.
      // 65536 x 65536 -> 0, which would then be used as a divisor below.
      const auto sparse_dim_numel = std::accumulate(
          larger_sizes.begin(),
          larger_sizes.begin() + larger.sparse_dim(),
          int64_t{1},
          std::multiplies<int64_t>());
      // If nnz > prod(larger.shape[:sparse_dim]), then by the pigeonhole
      // principle at least one bucket holds nnz / prod(larger.shape[:sparse_dim])
      // elements. This gives a lower bound for the max count in the intersection.
      // The condition is very conservative, because we do not check whether such
      // an event actually occurred. It is, however, very likely under a uniform
      // distribution, the distribution with the highest uncertainty (maximum entropy).
      // A zero-sized sparse dim makes sparse_dim_numel == 0; the division is
      // guarded in that case to avoid an integer division by zero (UB).
      const auto max_count_lower_bound = sparse_dim_numel > 0
        ? larger._nnz() / sparse_dim_numel
        : int64_t{0};
      constexpr int64_t MAX_COPIES_PER_THREAD = 50;
      return max_count_lower_bound > MAX_COPIES_PER_THREAD
        // coalesce invalidates hash values, so force-recompute
        ? std::make_tuple(larger.coalesce(), static_cast<OptTensor>(c10::nullopt), smaller, smaller_hash_opt)
        : std::make_tuple(larger, larger_hash_opt, smaller, smaller_hash_opt);
    }
  }();

  // The employed hash function maps a d-dim index to a linear offset into
  // contiguous memory large enough to fit a dense tensor of shape
  // broadcasted_shape(x.shape, y.shape):
  //   idx -> \sum_{i = 0}^d idx[i] * hash_coeffs[i],
  // where hash_coeffs are the strides of a contiguous tensor of shape
  // broadcasted_shape(x.shape, y.shape).
  // The ordering of dimensions is implicit in the definition of hash_coeffs:
  // the right-most dim changes fastest and the left-most dim changes slowest.
  // Under that ordering the hash function is bijective on valid indices,
  // and hence a perfect hash function (no collisions).

  // Need owning storage in case of the Tensor class.
  const auto hash_coeffs_storage = [&]() -> auto {
    const auto broadcasted_sparse_dim_shape = std::vector<int64_t>(
      broadcasted_shape.begin(),
      broadcasted_shape.begin() + probably_coalesced.sparse_dim()
    );
    auto strides = c10::contiguous_strides(broadcasted_sparse_dim_shape);
    return at::sparse::TensorGeometryHolder<max_static_len>(strides, strides, probably_coalesced.options());
  }();

  const auto hash_coeffs = std::get<0>(*hash_coeffs_storage);

  // A single arange, narrowed as needed, serves both operands' iterations.
  const auto nnz_arange = at::arange(
      std::max(probably_coalesced._nnz(), source._nnz()),
      source._indices().options());
  const auto probably_coalesced_nnz_arange = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());

  // non-const because of gcc-5/clang-5 issues
  auto sparse_dim = probably_coalesced.sparse_dim();

  // Apply the hash function to probably_coalesced.indices
  const auto probably_coalesced_indices_hash = [&]() -> Tensor {
    // Hash provided and still valid (not invalidated by coalesce)? Reuse it!
    if (probably_coalesced_indices_hash_opt.has_value()) {
      return (*probably_coalesced_indices_hash_opt).contiguous();
    }

    const auto indices = probably_coalesced._indices();
    // non-const because of gcc-5/clang-5 issues
    auto indices_dim_stride = indices.stride(0);
    auto indices_nnz_stride = indices.stride(1);

    auto hash = at::empty({probably_coalesced._nnz()}, indices.options().dtype(kLong));

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(hash)
      .add_input(probably_coalesced_nnz_arange)
      .build();

    {
      const auto* RESTRICT ptr_indices = indices.const_data_ptr<index_t>();

      KernelLauncher::launch(iter,
          // NOTE: capture by value required by CUDA
          [=] FUNCAPI (index_t nnz_idx) -> int64_t {
          int64_t hash = 0;
          if (!ptr_indices) {
            return hash;
          }
          const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
          for (int64_t dim = 0; dim < sparse_dim; ++dim) {
            const auto dim_hash_coeff = hash_coeffs[dim];
            const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
            hash += dim_index * dim_hash_coeff;
          }
          return hash;
      });
    }

    return hash;
  }();

  // Now that we have the hash values of probably_coalesced.indices, we need
  // to decide whether they have to be sorted.
  // The sort is not required if probably_coalesced is coalesced.
  Tensor sorted_hash, argsort_hash;
  std::tie(sorted_hash, argsort_hash) = [&]() -> std::tuple<Tensor, Tensor> {
    if (probably_coalesced.is_coalesced()) {
      // NOTE: argsort.dtype == nnz_arange.dtype
      const auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());
      return std::make_tuple(probably_coalesced_indices_hash, argsort);
    } else {
      // NOTE: we want argsort.dtype == nnz_arange.dtype, but sort() produces
      // indices of type int64_t. We convert to nnz_arange.dtype to avoid
      // issues with pointer types in the kernels below.
      Tensor sorted, argsort;
      std::tie(sorted, argsort) = probably_coalesced_indices_hash.sort();
      return std::make_tuple(sorted, argsort.to(nnz_arange.scalar_type()));
    }
  }();

  // Perform hash intersection.
  // Let s_hash = hash(source.indices) and
  //     pc_hash = hash(probably_coalesced.indices). Then
  // for i = 0, ..., len(s_hash) - 1:
  //     lb = <index of a value in pc_hash[argsort_hash] which is a lower bound for s_hash[i]>,
  //     up = <index of a value in pc_hash[argsort_hash] which is an upper bound for s_hash[i]>,
  //     intersection_count[i] = up - lb
  //     intersection_first_idx[i] = lb.
  //
  // intersection_count and intersection_first_idx are used to form the
  // indices at which intersection values are selected.
  Tensor intersection_count, intersection_first_idx;
  std::tie(intersection_count, intersection_first_idx) = [&]() -> std::tuple<Tensor, Tensor> {
    const auto source_nnz = source._nnz();
    // One allocation holds both outputs: row 0 = counts, row 1 = first indices.
    auto intersection_buffer = at::empty({2, source_nnz}, sorted_hash.options());
    auto intersection_count = intersection_buffer.select(0, 0);
    auto intersection_first_idx = intersection_buffer.select(0, 1);

    const auto source_indices = source._indices();
    const auto source_arange = nnz_arange.narrow(-1, 0, source_nnz);
    // non-const because of gcc-5/clang-5 issues
    auto indices_dim_stride = source_indices.stride(0);
    auto indices_nnz_stride = source_indices.stride(1);
    // The kernel writes through raw pointers. The iterator output is a
    // broadcast 1-element dummy that only drives the iteration.
    auto dummy = at::empty({1}, source_arange.options());

    auto hash = source_indices_hash_opt.has_value()
      ? (*source_indices_hash_opt).contiguous()
      : at::empty({0}, probably_coalesced._indices().options().dtype(kLong));
    // The precomputed hash is only read, so request a const pointer.
    const auto* RESTRICT hash_ptr = source_indices_hash_opt.has_value()
      ? hash.const_data_ptr<int64_t>()
      : nullptr;

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_owned_output(dummy.expand_as(source_arange))
      .add_input(source_arange)
      .build();

    {
      const auto* RESTRICT ptr_indices = source_indices.const_data_ptr<index_t>();
      const auto* RESTRICT ptr_sorted_hash = sorted_hash.const_data_ptr<int64_t>();
      const auto sorted_hash_len = sorted_hash.numel();
      auto* RESTRICT ptr_intersection_count = intersection_count.data_ptr<int64_t>();
      auto* RESTRICT ptr_intersection_first_idx = intersection_first_idx.data_ptr<int64_t>();

      // Fusing hash computation with hash intersection.
      KernelLauncher::launch(iter,
          // NOTE: capture by value required by CUDA
          [=] FUNCAPI (index_t nnz_idx) -> index_t {
          int64_t hash = 0;
          if (hash_ptr) {
            hash = hash_ptr[nnz_idx];
          } else if (sparse_dim) {
            // Compute hash value
            const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
            for (int64_t dim = 0; dim < sparse_dim; ++dim) {
              const auto dim_hash_coeff = hash_coeffs[dim];
              const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
              hash += dim_index * dim_hash_coeff;
            }
          }

          // Perform hash values intersection
          const auto* RESTRICT lb = find_bound<const int64_t*, int64_t, /*is_lower=*/true>(
              ptr_sorted_hash,
              ptr_sorted_hash + sorted_hash_len,
              hash
          );

          const auto* RESTRICT ub = find_bound<const int64_t*, int64_t, /*is_lower=*/false>(
              ptr_sorted_hash,
              ptr_sorted_hash + sorted_hash_len,
              hash
          );

          ptr_intersection_count[nnz_idx] = ub - lb;
          ptr_intersection_first_idx[nnz_idx] = lb - ptr_sorted_hash;

          return 0;
      });
    }

    return std::make_tuple(intersection_count, intersection_first_idx);
  }();

  const auto res_indices = source._indices().clone();
  const auto binary_op_res_dtype = at::result_type(source._values(), probably_coalesced._values());
  const auto res_values = value_selection_intersection_kernel_t::apply(
      source._values().to(binary_op_res_dtype),
      nnz_arange.narrow(-1, 0, source._nnz()),
      probably_coalesced._values().to(binary_op_res_dtype),
      intersection_first_idx.to(nnz_arange.scalar_type()),
      intersection_count,
      argsort_hash,
      accumulate_matches).to(res.scalar_type());
  const auto res_sparse_dim = source.sparse_dim();
  const auto res_dense_dim = source.dense_dim();
  const auto& res_shape = broadcasted_shape;
  const auto res_nnz = source._nnz();

  auto* res_sparse_impl = get_sparse_impl(res);
  res_sparse_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
  res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
  res_sparse_impl->set_nnz_and_narrow(res_nnz);
  // The result reuses source's indices, so it is coalesced iff source is.
  res._coalesced_(source.is_coalesced());
}
// Validates two sparse COO operands and dispatches to
// _sparse_binary_op_intersection_kernel_impl, writing the result into `res`.
//
// Checks (TORCH_CHECK):
//   - both inputs are sparse, with equal dim(), equal sparse_dim(), and
//     equal sizes over the sparse dims;
//   - both inputs' indices share the same scalar type.
// Any provided hash must be a 1-dim kLong tensor whose length equals the
// number of columns of the corresponding `_indices()` (TORCH_INTERNAL_ASSERT).
//
// NOTE(review): despite its name, `distributive_with_sum` is forwarded
// positionally into the impl's `accumulate_matches` parameter. The impl's own
// `distributive_with_sum` is left at its default (true), so no coalescing is
// triggered from here. Confirm against callers before renaming or rewiring.
template <
  template <typename func_t> class kernel_t,
  typename value_selection_intersection_kernel_t>
void _sparse_binary_op_intersection_kernel_out(
    Tensor& res,
    const Tensor& x,
    const Tensor& y,
    const std::optional<Tensor>& x_hash_opt = c10::nullopt,
    const std::optional<Tensor>& y_hash_opt = c10::nullopt,
    // Forwarded as the impl's `accumulate_matches` (see the note above).
    const bool distributive_with_sum = true
) {
  const bool has_matching_sparse_layout =
      x.is_sparse() && y.is_sparse()
      && x.dim() == y.dim()
      && x.sparse_dim() == y.sparse_dim()
      && x.sizes().slice(0, x.sparse_dim()) == y.sizes().slice(0, y.sparse_dim());
  TORCH_CHECK(
      has_matching_sparse_layout,
      NAME, "(): expects sparse inputs with equal dimensionality, ",
      "number of sparse dimensions, and shape of sparse dimensions");

  const bool has_matching_index_dtype =
      x._indices().scalar_type() == y._indices().scalar_type();
  TORCH_CHECK(
      has_matching_index_dtype,
      NAME, "(): expects inputs' indices to be of the same dtype (i.e. long or int)");

  const auto check_hash_validity = [](const Tensor& t, const std::optional<Tensor>& t_hash_opt) {
    if (t_hash_opt.has_value()) {
      const auto& t_hash = *t_hash_opt;
      TORCH_INTERNAL_ASSERT(
          t_hash.dim() == 1 && t_hash.scalar_type() == kLong && t_hash.size(-1) == t._indices().size(-1),
          NAME, "(): explicit hash values need to be a 1-dim Long tensor with the ",
          "NSE matching that of the corresponding sparse tensor.");
    }
  };
  check_hash_validity(x, x_hash_opt);
  check_hash_validity(y, y_hash_opt);

  const auto broadcasted_shape = infer_size(x.sizes(), y.sizes());

  // COO indices are only 64-bit integers for now.
  using index_t = int64_t;

  // Inputs with fewer than 8 sparse dims instantiate the impl with
  // max_static_len = 8; others fall back to the default (0).
  constexpr int64_t max_sparse_dims = 8;
  if (x.sparse_dim() < max_sparse_dims) {
    _sparse_binary_op_intersection_kernel_impl<
      // MSVC rejects max_sparse_dims as a template argument here, claiming
      // it is not known at compile time, so the literal is spelled out.
      kernel_t, value_selection_intersection_kernel_t, index_t, 8>(
        res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt,
        /*accumulate_matches=*/distributive_with_sum);
  } else {
    _sparse_binary_op_intersection_kernel_impl<
      kernel_t, value_selection_intersection_kernel_t, index_t>(
        res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt,
        /*accumulate_matches=*/distributive_with_sum);
  }
}
} // anonymous namespace
} // at::native