Skip to content

Commit

Permalink
add shortcuts for common reduction boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Jan 30, 2022
1 parent 493daef commit 200ac00
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 26 deletions.
10 changes: 10 additions & 0 deletions common/unified/base/kernel_launch_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common/unified/base/kernel_launch.hpp"


#define GKO_KERNEL_REDUCE_SUM(ValueType) \
[] GKO_KERNEL(auto a, auto b) { return a + b; }, \
[] GKO_KERNEL(auto a) { return a; }, ValueType \
{}
#define GKO_KERNEL_REDUCE_MAX(ValueType) \
[] GKO_KERNEL(auto a, auto b) { return a > b ? a : b; }, \
[] GKO_KERNEL(auto a) { return a; }, ValueType \
{}


#if defined(GKO_COMPILING_CUDA)
#include "cuda/base/kernel_launch_reduction.cuh"
#elif defined(GKO_COMPILING_HIP)
Expand Down
3 changes: 1 addition & 2 deletions common/unified/components/reduce_array_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ void reduce_add_array(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto i, auto array, auto result) {
return i == 0 ? (array[i] + result[0]) : array[i];
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, ValueType{}, result.get_data(),
GKO_KERNEL_REDUCE_SUM(ValueType), result.get_data(),
array.get_num_elems(), array, result);
}

Expand Down
3 changes: 1 addition & 2 deletions common/unified/distributed/partition_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ void count_ranges(std::shared_ptr<const DefaultExecutor> exec,
auto prev_part = i == 0 ? comm_index_type{-1} : mapping[i - 1];
return cur_part != prev_part ? 1 : 0;
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, size_type{}, result.get_data(),
GKO_KERNEL_REDUCE_SUM(size_type), result.get_data(),
mapping.get_num_elems(), mapping);
num_ranges = exec->copy_val_to_host(result.get_const_data());
}
Expand Down
25 changes: 9 additions & 16 deletions common/unified/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,8 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto i, auto j, auto x, auto y) {
return x(i, j) * y(i, j);
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, ValueType{}, result->get_values(),
x->get_size(), x, y);
GKO_KERNEL_REDUCE_SUM(ValueType), result->get_values(), x->get_size(),
x, y);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL);
Expand All @@ -271,9 +270,8 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto i, auto j, auto x, auto y) {
return conj(x(i, j)) * y(i, j);
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, ValueType{}, result->get_values(),
x->get_size(), x, y);
GKO_KERNEL_REDUCE_SUM(ValueType), result->get_values(), x->get_size(),
x, y);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL);
Expand Down Expand Up @@ -301,9 +299,8 @@ void compute_norm1(std::shared_ptr<const DefaultExecutor> exec,
{
run_kernel_col_reduction(
exec, [] GKO_KERNEL(auto i, auto j, auto x) { return abs(x(i, j)); },
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, remove_complex<ValueType>{},
result->get_values(), x->get_size(), x);
GKO_KERNEL_REDUCE_SUM(remove_complex<ValueType>), result->get_values(),
x->get_size(), x);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL);
Expand All @@ -318,8 +315,7 @@ void compute_max_nnz_per_row(std::shared_ptr<const DefaultExecutor> exec,
count_nonzeros_per_row(exec, source, partial.get_data());
run_kernel_reduction(
exec, [] GKO_KERNEL(auto i, auto partial) { return partial[i]; },
[] GKO_KERNEL(auto a, auto b) { return a > b ? a : b; },
[] GKO_KERNEL(auto a) { return a; }, size_type{},
GKO_KERNEL_REDUCE_MAX(size_type),
partial.get_data() + source->get_size()[0], source->get_size()[0],
partial);
result = exec->copy_val_to_host(partial.get_const_data() +
Expand Down Expand Up @@ -351,8 +347,7 @@ void compute_slice_sets(std::shared_ptr<const DefaultExecutor> exec,
stride_factor)
: size_type{};
},
[] GKO_KERNEL(auto a, auto b) { return a > b ? a : b; },
[] GKO_KERNEL(auto a) { return a; }, size_type{}, slice_lengths, 1,
GKO_KERNEL_REDUCE_MAX(size_type), slice_lengths, 1,
gko::dim<2>{num_slices, slice_size}, row_nnz, slice_size, stride_factor,
num_rows);
exec->copy(num_slices, slice_lengths, slice_sets);
Expand All @@ -373,9 +368,7 @@ void count_nonzeros_per_row(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto i, auto j, auto mtx) {
return is_nonzero(mtx(i, j)) ? 1 : 0;
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, IndexType{}, result, 1,
mtx->get_size(), mtx);
GKO_KERNEL_REDUCE_SUM(IndexType), result, 1, mtx->get_size(), mtx);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
Expand Down
6 changes: 2 additions & 4 deletions common/unified/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ void compute_max_row_nnz(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto i, auto row_ptrs) {
return row_ptrs[i + 1] - row_ptrs[i];
},
[] GKO_KERNEL(auto a, auto b) { return a > b ? a : b; },
[] GKO_KERNEL(auto a) { return a; }, size_type{}, result.get_data(),
GKO_KERNEL_REDUCE_MAX(size_type), result.get_data(),
row_ptrs.get_num_elems() - 1, row_ptrs);
max_nnz = exec->copy_val_to_host(result.get_const_data());
}
Expand Down Expand Up @@ -169,8 +168,7 @@ void count_nonzeros_per_row(std::shared_ptr<const DefaultExecutor> exec,
const auto ell_idx = ell_col * ell_stride + row;
return is_nonzero(in_vals[ell_idx]) ? 1 : 0;
},
[] GKO_KERNEL(auto a, auto b) { return a + b; },
[] GKO_KERNEL(auto a) { return a; }, IndexType{}, result,
GKO_KERNEL_REDUCE_SUM(IndexType), result,
dim<2>{source->get_num_stored_elements_per_row(),
source->get_size()[0]},
static_cast<int64>(source->get_stride()), source->get_const_values());
Expand Down
3 changes: 1 addition & 2 deletions common/unified/matrix/sellp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ void compute_slice_sets(std::shared_ptr<const DefaultExecutor> exec,
stride_factor)
: size_type{};
},
[] GKO_KERNEL(auto a, auto b) { return a > b ? a : b; },
[] GKO_KERNEL(auto a) { return a; }, size_type{}, slice_lengths, 1,
GKO_KERNEL_REDUCE_MAX(size_type), slice_lengths, 1,
gko::dim<2>{num_slices, slice_size}, row_ptrs, slice_size,
stride_factor, num_rows);
exec->copy(num_slices, slice_lengths, slice_sets);
Expand Down

0 comments on commit 200ac00

Please sign in to comment.