Skip to content

Commit

Permalink
Add device kernels and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzgoebel committed Jul 23, 2024
1 parent 298dacb commit 7ab6bf6
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 83 deletions.
94 changes: 89 additions & 5 deletions common/cuda_hip/distributed/matrix_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

#include "common/cuda_hip/base/thrust.hpp"
#include "common/cuda_hip/components/atomic.hpp"
#include "common/unified/base/kernel_launch.hpp"
#include "core/components/format_conversion_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"


namespace gko {
Expand Down Expand Up @@ -55,8 +58,67 @@ void count_overlap_entries(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part,
array<comm_index_type>& overlap_count) GKO_NOT_IMPLEMENTED;
comm_index_type local_part, array<comm_index_type>& overlap_count,
array<GlobalIndexType>& overlap_positions,
array<GlobalIndexType>& original_positions)
{
auto row_part_ids = row_partition->get_part_ids();
const auto* row_range_bounds = row_partition->get_range_bounds();
const auto* row_range_starting_indices =
row_partition->get_range_starting_indices();
const auto num_row_ranges = row_partition->get_num_ranges();
const auto num_input_elements = input.get_num_stored_elements();

auto policy = thrust_policy(exec);

// precompute the row and column range id of each input element
auto input_row_idxs = input.get_const_row_idxs();
array<size_type> row_range_ids{exec, num_input_elements};
thrust::upper_bound(policy, row_range_bounds + 1,
row_range_bounds + num_row_ranges + 1, input_row_idxs,
input_row_idxs + num_input_elements,
row_range_ids.get_data());

array<comm_index_type> row_part_ids_per_entry{exec, num_input_elements};
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto part_id, auto part_ids, auto range_ids,
auto part_ids_per_entry, auto orig_positions) {
part_ids_per_entry[i] = part_ids[range_ids[i]];
orig_positions[i] = part_ids_per_entry[i] == part_id ? -1 : i;
},
num_input_elements, local_part, row_part_ids, row_range_ids.get_data(),
row_part_ids_per_entry.get_data(), original_positions.get_data());

thrust::stable_sort_by_key(
policy, row_part_ids_per_entry.get_data(),
row_part_ids_per_entry.get_data() + num_input_elements,
original_positions.get_data());
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto orig_positions, auto overl_positions) {
overl_positions[i] = orig_positions[i] >= 0 ? 1 : 0;
},
num_input_elements, original_positions.get_const_data(),
overlap_positions.get_data());

components::prefix_sum_nonnegative(exec, overlap_positions.get_data(),
num_input_elements);
size_type num_parts = row_partition->get_num_parts();
array<comm_index_type> row_part_ptrs{exec, num_parts + 1};
row_part_ptrs.fill(0);
components::convert_idxs_to_ptrs(
exec, row_part_ids_per_entry.get_const_data(), num_input_elements,
num_parts, row_part_ptrs.get_data());

run_kernel(
exec,
[] GKO_KERNEL(auto i, auto part_id, auto part_ptrs, auto count) {
count[i] = i == part_id ? 0 : part_ptrs[i + 1] - part_ptrs[i];
},
num_parts, local_part, row_part_ptrs.get_data(),
overlap_count.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_COUNT_OVERLAP_ENTRIES);
Expand All @@ -68,10 +130,32 @@ void fill_overlap_send_buffers(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part, array<comm_index_type>& offsets,
comm_index_type local_part, const array<GlobalIndexType>& overlap_positions,
const array<GlobalIndexType>& original_positions,
array<GlobalIndexType>& overlap_row_idxs,
array<GlobalIndexType>& overlap_col_idxs,
array<ValueType>& overlap_values) GKO_NOT_IMPLEMENTED;
array<GlobalIndexType>& overlap_col_idxs, array<ValueType>& overlap_values)
{
auto num_entries = input.get_num_stored_elements();
auto input_row_idxs = input.get_const_row_idxs();
auto input_col_idxs = input.get_const_col_idxs();
auto input_values = input.get_const_values();

run_kernel(
exec,
[] GKO_KERNEL(auto i, auto in_rows, auto in_cols, auto in_vals,
auto in_pos, auto out_pos, auto out_rows, auto out_cols,
auto out_vals) {
if (in_pos[i] >= 0) {
out_rows[out_pos[i]] = in_rows[in_pos[i]];
out_cols[out_pos[i]] = in_cols[in_pos[i]];
out_vals[out_pos[i]] = in_vals[in_pos[i]];
}
},
num_entries, input_row_idxs, input_col_idxs, input_values,
original_positions.get_const_data(), overlap_positions.get_const_data(),
overlap_row_idxs.get_data(), overlap_col_idxs.get_data(),
overlap_values.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS);
Expand Down
51 changes: 24 additions & 27 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@

#include "ginkgo/core/distributed/matrix.hpp"

#include <numeric>
#include <vector>

#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/diagonal.hpp>

#include "core/components/prefix_sum_kernels.hpp"
#include "core/distributed/matrix_kernels.hpp"
#include "ginkgo/core/base/mtx_io.hpp"


namespace gko {
Expand Down Expand Up @@ -271,18 +268,23 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(

device_matrix_data<value_type, global_index_type> all_data{exec};
if (assembly_type == assembly::communicate) {
array<comm_index_type> overlap_count{exec, comm.size()};
size_type num_entries = data.get_num_stored_elements();
size_type num_parts = comm.size();
array<comm_index_type> overlap_count{exec, num_parts};
array<global_index_type> overlap_positions{exec, num_entries};
array<global_index_type> original_positions{exec, num_entries};
overlap_count.fill(0);
auto tmp_part = make_temporary_clone(exec, row_partition);
exec->run(matrix::make_count_overlap_entries(
data, tmp_part.get(), local_part, overlap_count));
data, tmp_part.get(), local_part, overlap_count, overlap_positions,
original_positions));

overlap_count.set_executor(exec->get_master());
std::vector<comm_index_type> overlap_send_sizes(
overlap_count.get_data(), overlap_count.get_data() + comm.size());
std::vector<comm_index_type> overlap_send_offsets(comm.size() + 1);
std::vector<comm_index_type> overlap_recv_sizes(comm.size());
std::vector<comm_index_type> overlap_recv_offsets(comm.size() + 1);
overlap_count.get_data(), overlap_count.get_data() + num_parts);
std::vector<comm_index_type> overlap_send_offsets(num_parts + 1);
std::vector<comm_index_type> overlap_recv_sizes(num_parts);
std::vector<comm_index_type> overlap_recv_offsets(num_parts + 1);

std::partial_sum(overlap_send_sizes.begin(), overlap_send_sizes.end(),
overlap_send_offsets.begin() + 1);
Expand All @@ -301,14 +303,10 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
array<global_index_type> overlap_recv_row_idxs{exec, n_recv};
array<global_index_type> overlap_recv_col_idxs{exec, n_recv};
array<value_type> overlap_recv_values{exec, n_recv};
auto offset_array =
make_const_array_view(exec->get_master(), comm.size() + 1,
overlap_send_offsets.data())
.copy_to_array();
offset_array.set_executor(exec);
exec->run(matrix::make_fill_overlap_send_buffers(
data, tmp_part.get(), local_part, offset_array,
overlap_send_row_idxs, overlap_send_col_idxs, overlap_send_values));
data, tmp_part.get(), local_part, overlap_positions,
original_positions, overlap_send_row_idxs, overlap_send_col_idxs,
overlap_send_values));

if (use_host_buffer) {
overlap_send_row_idxs.set_executor(exec->get_master());
Expand Down Expand Up @@ -339,22 +337,21 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
overlap_recv_values.set_executor(exec);
}

size_type n_nnz = data.get_num_stored_elements();
array<global_index_type> all_row_idxs{exec, n_nnz + n_recv};
array<global_index_type> all_col_idxs{exec, n_nnz + n_recv};
array<value_type> all_values{exec, n_nnz + n_recv};
exec->copy_from(exec, n_nnz, data.get_const_row_idxs(),
array<global_index_type> all_row_idxs{exec, num_entries + n_recv};
array<global_index_type> all_col_idxs{exec, num_entries + n_recv};
array<value_type> all_values{exec, num_entries + n_recv};
exec->copy_from(exec, num_entries, data.get_const_row_idxs(),
all_row_idxs.get_data());
exec->copy_from(exec, n_recv, overlap_recv_row_idxs.get_data(),
all_row_idxs.get_data() + n_nnz);
exec->copy_from(exec, n_nnz, data.get_const_col_idxs(),
all_row_idxs.get_data() + num_entries);
exec->copy_from(exec, num_entries, data.get_const_col_idxs(),
all_col_idxs.get_data());
exec->copy_from(exec, n_recv, overlap_recv_col_idxs.get_data(),
all_col_idxs.get_data() + n_nnz);
exec->copy_from(exec, n_nnz, data.get_const_values(),
all_col_idxs.get_data() + num_entries);
exec->copy_from(exec, num_entries, data.get_const_values(),
all_values.get_data());
exec->copy_from(exec, n_recv, overlap_recv_values.get_data(),
all_values.get_data() + n_nnz);
all_values.get_data() + num_entries);
all_data = device_matrix_data<value_type, global_index_type>{
exec, global_dim, all_row_idxs, all_col_idxs, all_values};
all_data.sum_duplicates();
Expand Down
22 changes: 13 additions & 9 deletions core/distributed/matrix_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ namespace gko {
namespace kernels {


#define GKO_DECLARE_COUNT_OVERLAP_ENTRIES(ValueType, LocalIndexType, \
GlobalIndexType) \
void count_overlap_entries( \
std::shared_ptr<const DefaultExecutor> exec, \
const device_matrix_data<ValueType, GlobalIndexType>& input, \
const experimental::distributed::Partition< \
LocalIndexType, GlobalIndexType>* row_partition, \
comm_index_type local_part, array<comm_index_type>& overlap_count)
#define GKO_DECLARE_COUNT_OVERLAP_ENTRIES(ValueType, LocalIndexType, \
GlobalIndexType) \
void count_overlap_entries( \
std::shared_ptr<const DefaultExecutor> exec, \
const device_matrix_data<ValueType, GlobalIndexType>& input, \
const experimental::distributed::Partition< \
LocalIndexType, GlobalIndexType>* row_partition, \
comm_index_type local_part, array<comm_index_type>& overlap_count, \
array<GlobalIndexType>& overlap_positions, \
array<GlobalIndexType>& original_positions)


#define GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS(ValueType, LocalIndexType, \
Expand All @@ -36,7 +38,9 @@ namespace kernels {
const device_matrix_data<ValueType, GlobalIndexType>& input, \
const experimental::distributed::Partition< \
LocalIndexType, GlobalIndexType>* row_partition, \
comm_index_type local_part, array<comm_index_type>& offsets, \
comm_index_type local_part, \
const array<GlobalIndexType>& overlap_positions, \
const array<GlobalIndexType>& original_positions, \
array<GlobalIndexType>& overlap_row_idxs, \
array<GlobalIndexType>& overlap_col_idxs, \
array<ValueType>& overlap_values)
Expand Down
8 changes: 5 additions & 3 deletions dpcpp/distributed/matrix_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ void count_overlap_entries(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part,
array<comm_index_type>& overlap_count) GKO_NOT_IMPLEMENTED;
comm_index_type local_part, array<comm_index_type>& overlap_count,
array<GlobalIndexType>& overlap_positions,
array<GlobalIndexType>& original_positions) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_COUNT_OVERLAP_ENTRIES);
Expand All @@ -32,7 +33,8 @@ void fill_overlap_send_buffers(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part, array<comm_index_type>& offsets,
comm_index_type local_part, const array<GlobalIndexType>& overlap_positions,
const array<GlobalIndexType>& original_positions,
array<GlobalIndexType>& overlap_row_idxs,
array<GlobalIndexType>& overlap_col_idxs,
array<ValueType>& overlap_values) GKO_NOT_IMPLEMENTED;
Expand Down
70 changes: 65 additions & 5 deletions omp/distributed/matrix_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "core/distributed/matrix_kernels.hpp"

#include <algorithm>

#include <omp.h>

#include <ginkgo/core/base/exception_helpers.hpp>
Expand All @@ -26,8 +28,50 @@ void count_overlap_entries(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part,
array<comm_index_type>& overlap_count) GKO_NOT_IMPLEMENTED;
comm_index_type local_part, array<comm_index_type>& overlap_count,
array<GlobalIndexType>& overlap_positions,
array<GlobalIndexType>& original_positions)
{
auto num_input_elements = input.get_num_stored_elements();
auto input_row_idxs = input.get_const_row_idxs();
auto row_part_ids = row_partition->get_part_ids();
array<comm_index_type> row_part_ids_per_entry{exec, num_input_elements};

size_type row_range_id = 0;
#pragma omp parallel for firstprivate(row_range_id)
for (size_type i = 0; i < input.get_num_stored_elements(); ++i) {
auto global_row = input_row_idxs[i];
row_range_id = find_range(global_row, row_partition, row_range_id);
auto row_part_id = row_part_ids[row_range_id];
row_part_ids_per_entry.get_data()[i] = row_part_id;
if (row_part_id != local_part) {
#pragma omp atomic
overlap_count.get_data()[row_part_id]++;
original_positions.get_data()[i] = i;
} else {
original_positions.get_data()[i] = -1;
}
}

auto comp = [row_part_ids_per_entry, local_part](auto i, auto j) {
comm_index_type a =
i == -1 ? local_part : row_part_ids_per_entry.get_const_data()[i];
comm_index_type b =
j == -1 ? local_part : row_part_ids_per_entry.get_const_data()[j];
return a < b;
};
std::stable_sort(original_positions.get_data(),
original_positions.get_data() + num_input_elements, comp);

#pragma omp parallel for
for (size_type i = 0; i < num_input_elements; i++) {
overlap_positions.get_data()[i] =
original_positions.get_const_data()[i] == -1 ? 0 : 1;
}

components::prefix_sum_nonnegative(exec, overlap_positions.get_data(),
num_input_elements);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_COUNT_OVERLAP_ENTRIES);
Expand All @@ -39,10 +83,26 @@ void fill_overlap_send_buffers(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
comm_index_type local_part, array<comm_index_type>& offsets,
comm_index_type local_part, const array<GlobalIndexType>& overlap_positions,
const array<GlobalIndexType>& original_positions,
array<GlobalIndexType>& overlap_row_idxs,
array<GlobalIndexType>& overlap_col_idxs,
array<ValueType>& overlap_values) GKO_NOT_IMPLEMENTED;
array<GlobalIndexType>& overlap_col_idxs, array<ValueType>& overlap_values)
{
auto input_row_idxs = input.get_const_row_idxs();
auto input_col_idxs = input.get_const_col_idxs();
auto input_vals = input.get_const_values();

#pragma omp parallel for
for (size_type i = 0; i < input.get_num_stored_elements(); ++i) {
auto in_pos = original_positions.get_const_data()[i];
if (in_pos >= 0) {
auto out_pos = overlap_positions.get_const_data()[i];
overlap_row_idxs.get_data()[out_pos] = input_row_idxs[in_pos];
overlap_col_idxs.get_data()[out_pos] = input_col_idxs[in_pos];
overlap_values.get_data()[out_pos] = input_vals[in_pos];
}
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS);
Expand Down
Loading

0 comments on commit 7ab6bf6

Please sign in to comment.