From 7ab6bf62b69d0849c694c139a6d276928e13d4cf Mon Sep 17 00:00:00 2001 From: Fritz Goebel Date: Fri, 19 Jul 2024 12:20:14 +0000 Subject: [PATCH] Add device kernels and tests --- .../cuda_hip/distributed/matrix_kernels.cpp | 94 ++++++++++++++++++- core/distributed/matrix.cpp | 51 +++++----- core/distributed/matrix_kernels.hpp | 22 +++-- dpcpp/distributed/matrix_kernels.dp.cpp | 8 +- omp/distributed/matrix_kernels.cpp | 70 +++++++++++++- reference/distributed/matrix_kernels.cpp | 54 ++++++++--- reference/test/distributed/matrix_kernels.cpp | 48 ++++++---- test/distributed/matrix_kernels.cpp | 60 +++++++++++- 8 files changed, 324 insertions(+), 83 deletions(-) diff --git a/common/cuda_hip/distributed/matrix_kernels.cpp b/common/cuda_hip/distributed/matrix_kernels.cpp index 70824bf03d9..ab3ec9da8b1 100644 --- a/common/cuda_hip/distributed/matrix_kernels.cpp +++ b/common/cuda_hip/distributed/matrix_kernels.cpp @@ -20,6 +20,9 @@ #include "common/cuda_hip/base/thrust.hpp" #include "common/cuda_hip/components/atomic.hpp" +#include "common/unified/base/kernel_launch.hpp" +#include "core/components/format_conversion_kernels.hpp" +#include "core/components/prefix_sum_kernels.hpp" namespace gko { @@ -55,8 +58,67 @@ void count_overlap_entries( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, - array& overlap_count) GKO_NOT_IMPLEMENTED; + comm_index_type local_part, array& overlap_count, + array& overlap_positions, + array& original_positions) +{ + auto row_part_ids = row_partition->get_part_ids(); + const auto* row_range_bounds = row_partition->get_range_bounds(); + const auto* row_range_starting_indices = + row_partition->get_range_starting_indices(); + const auto num_row_ranges = row_partition->get_num_ranges(); + const auto num_input_elements = input.get_num_stored_elements(); + + auto policy = thrust_policy(exec); + + // precompute the row and column range id of each input element + auto input_row_idxs = input.get_const_row_idxs(); + array row_range_ids{exec, num_input_elements}; + thrust::upper_bound(policy, row_range_bounds + 1, + row_range_bounds + num_row_ranges + 1, input_row_idxs, + input_row_idxs + num_input_elements, + row_range_ids.get_data()); + + array row_part_ids_per_entry{exec, num_input_elements}; + run_kernel( + exec, + [] GKO_KERNEL(auto i, auto part_id, auto part_ids, auto range_ids, + auto part_ids_per_entry, auto orig_positions) { + part_ids_per_entry[i] = part_ids[range_ids[i]]; + orig_positions[i] = part_ids_per_entry[i] == part_id ? -1 : i; + }, + num_input_elements, local_part, row_part_ids, row_range_ids.get_data(), + row_part_ids_per_entry.get_data(), original_positions.get_data()); + + thrust::stable_sort_by_key( + policy, row_part_ids_per_entry.get_data(), + row_part_ids_per_entry.get_data() + num_input_elements, + original_positions.get_data()); + run_kernel( + exec, + [] GKO_KERNEL(auto i, auto orig_positions, auto overl_positions) { + overl_positions[i] = orig_positions[i] >= 0 ? 1 : 0; + }, + num_input_elements, original_positions.get_const_data(), + overlap_positions.get_data()); + + components::prefix_sum_nonnegative(exec, overlap_positions.get_data(), + num_input_elements); + size_type num_parts = row_partition->get_num_parts(); + array row_part_ptrs{exec, num_parts + 1}; + row_part_ptrs.fill(0); + components::convert_idxs_to_ptrs( + exec, row_part_ids_per_entry.get_const_data(), num_input_elements, + num_parts, row_part_ptrs.get_data()); + + run_kernel( + exec, + [] GKO_KERNEL(auto i, auto part_id, auto part_ptrs, auto count) { + count[i] = i == part_id ? 0 : part_ptrs[i + 1] - part_ptrs[i]; + }, + num_parts, local_part, row_part_ptrs.get_data(), + overlap_count.get_data()); +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_COUNT_OVERLAP_ENTRIES); @@ -68,10 +130,32 @@ void fill_overlap_send_buffers( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, array& offsets, + comm_index_type local_part, const array& overlap_positions, + const array& original_positions, array& overlap_row_idxs, - array& overlap_col_idxs, - array& overlap_values) GKO_NOT_IMPLEMENTED; + array& overlap_col_idxs, array& overlap_values) +{ + auto num_entries = input.get_num_stored_elements(); + auto input_row_idxs = input.get_const_row_idxs(); + auto input_col_idxs = input.get_const_col_idxs(); + auto input_values = input.get_const_values(); + + run_kernel( + exec, + [] GKO_KERNEL(auto i, auto in_rows, auto in_cols, auto in_vals, + auto in_pos, auto out_pos, auto out_rows, auto out_cols, + auto out_vals) { + if (in_pos[i] >= 0) { + out_rows[out_pos[i]] = in_rows[in_pos[i]]; + out_cols[out_pos[i]] = in_cols[in_pos[i]]; + out_vals[out_pos[i]] = in_vals[in_pos[i]]; + } + }, + num_entries, input_row_idxs, input_col_idxs, input_values, + original_positions.get_const_data(), overlap_positions.get_const_data(), + overlap_row_idxs.get_data(), overlap_col_idxs.get_data(), + overlap_values.get_data()); +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS); diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp index a64a07619b3..3799895abf3 100644 --- a/core/distributed/matrix.cpp +++ b/core/distributed/matrix.cpp @@ -4,17 +4,14 @@ #include "ginkgo/core/distributed/matrix.hpp" -#include -#include - #include #include #include #include #include +#include "core/components/prefix_sum_kernels.hpp" #include "core/distributed/matrix_kernels.hpp" -#include "ginkgo/core/base/mtx_io.hpp" namespace gko { @@ -271,18 +268,23 @@ void Matrix::read_distributed( device_matrix_data all_data{exec}; if (assembly_type == assembly::communicate) { - array overlap_count{exec, comm.size()}; + size_type num_entries = data.get_num_stored_elements(); + size_type num_parts = comm.size(); + array overlap_count{exec, num_parts}; + array overlap_positions{exec, num_entries}; + array original_positions{exec, num_entries}; overlap_count.fill(0); auto tmp_part = make_temporary_clone(exec, row_partition); exec->run(matrix::make_count_overlap_entries( - data, tmp_part.get(), local_part, overlap_count)); + data, tmp_part.get(), local_part, overlap_count, overlap_positions, + original_positions)); overlap_count.set_executor(exec->get_master()); std::vector overlap_send_sizes( - overlap_count.get_data(), overlap_count.get_data() + comm.size()); - std::vector overlap_send_offsets(comm.size() + 1); - std::vector overlap_recv_sizes(comm.size()); - std::vector overlap_recv_offsets(comm.size() + 1); + overlap_count.get_data(), overlap_count.get_data() + num_parts); + std::vector overlap_send_offsets(num_parts + 1); + std::vector overlap_recv_sizes(num_parts); + std::vector overlap_recv_offsets(num_parts + 1); std::partial_sum(overlap_send_sizes.begin(), overlap_send_sizes.end(), overlap_send_offsets.begin() + 1); @@ -301,14 +303,10 @@ void Matrix::read_distributed( array overlap_recv_row_idxs{exec, n_recv}; array overlap_recv_col_idxs{exec, n_recv}; array overlap_recv_values{exec, n_recv}; - auto offset_array = - make_const_array_view(exec->get_master(), comm.size() + 1, - overlap_send_offsets.data()) - .copy_to_array(); - offset_array.set_executor(exec); exec->run(matrix::make_fill_overlap_send_buffers( - data, tmp_part.get(), local_part, offset_array, - overlap_send_row_idxs, overlap_send_col_idxs, overlap_send_values)); + data, tmp_part.get(), local_part, overlap_positions, + original_positions, overlap_send_row_idxs, overlap_send_col_idxs, + overlap_send_values)); if (use_host_buffer) { overlap_send_row_idxs.set_executor(exec->get_master()); @@ -339,22 +337,21 @@ void Matrix::read_distributed( overlap_recv_values.set_executor(exec); } - size_type n_nnz = data.get_num_stored_elements(); - array all_row_idxs{exec, n_nnz + n_recv}; - array all_col_idxs{exec, n_nnz + n_recv}; - array all_values{exec, n_nnz + n_recv}; - exec->copy_from(exec, n_nnz, data.get_const_row_idxs(), + array all_row_idxs{exec, num_entries + n_recv}; + array all_col_idxs{exec, num_entries + n_recv}; + array all_values{exec, num_entries + n_recv}; + exec->copy_from(exec, num_entries, data.get_const_row_idxs(), all_row_idxs.get_data()); exec->copy_from(exec, n_recv, overlap_recv_row_idxs.get_data(), - all_row_idxs.get_data() + n_nnz); - exec->copy_from(exec, n_nnz, data.get_const_col_idxs(), + all_row_idxs.get_data() + num_entries); + exec->copy_from(exec, num_entries, data.get_const_col_idxs(), all_col_idxs.get_data()); exec->copy_from(exec, n_recv, overlap_recv_col_idxs.get_data(), - all_col_idxs.get_data() + n_nnz); - exec->copy_from(exec, n_nnz, data.get_const_values(), + all_col_idxs.get_data() + num_entries); + exec->copy_from(exec, num_entries, data.get_const_values(), all_values.get_data()); exec->copy_from(exec, n_recv, overlap_recv_values.get_data(), - all_values.get_data() + n_nnz); + all_values.get_data() + num_entries); all_data = device_matrix_data{ exec, global_dim, all_row_idxs, all_col_idxs, all_values}; all_data.sum_duplicates(); diff --git a/core/distributed/matrix_kernels.hpp b/core/distributed/matrix_kernels.hpp index fa7e891eb4e..4cdaf3e17fe 100644 --- a/core/distributed/matrix_kernels.hpp +++ b/core/distributed/matrix_kernels.hpp @@ -19,14 +19,16 @@ namespace gko { namespace kernels { -#define GKO_DECLARE_COUNT_OVERLAP_ENTRIES(ValueType, LocalIndexType, \ - GlobalIndexType) \ - void count_overlap_entries( \ - std::shared_ptr exec, \ - const device_matrix_data& input, \ - const experimental::distributed::Partition< \ - LocalIndexType, GlobalIndexType>* row_partition, \ - comm_index_type local_part, array& overlap_count) +#define GKO_DECLARE_COUNT_OVERLAP_ENTRIES(ValueType, LocalIndexType, \ + GlobalIndexType) \ + void count_overlap_entries( \ + std::shared_ptr exec, \ + const device_matrix_data& input, \ + const experimental::distributed::Partition< \ + LocalIndexType, GlobalIndexType>* row_partition, \ + comm_index_type local_part, array& overlap_count, \ + array& overlap_positions, \ + array& original_positions) #define GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS(ValueType, LocalIndexType, \ @@ -36,7 +38,9 @@ namespace kernels { const device_matrix_data& input, \ const experimental::distributed::Partition< \ LocalIndexType, GlobalIndexType>* row_partition, \ - comm_index_type local_part, array& offsets, \ + comm_index_type local_part, \ + const array& overlap_positions, \ + const array& original_positions, \ array& overlap_row_idxs, \ array& overlap_col_idxs, \ array& overlap_values) diff --git a/dpcpp/distributed/matrix_kernels.dp.cpp b/dpcpp/distributed/matrix_kernels.dp.cpp index 9225e58ad14..60fc0686473 100644 --- a/dpcpp/distributed/matrix_kernels.dp.cpp +++ b/dpcpp/distributed/matrix_kernels.dp.cpp @@ -19,8 +19,9 @@ void count_overlap_entries( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, - array& overlap_count) GKO_NOT_IMPLEMENTED; + comm_index_type local_part, array& overlap_count, + array& overlap_positions, + array& original_positions) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_COUNT_OVERLAP_ENTRIES); @@ -32,7 +33,8 @@ void fill_overlap_send_buffers( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, array& offsets, + comm_index_type local_part, const array& overlap_positions, + const array& original_positions, array& overlap_row_idxs, array& overlap_col_idxs, array& overlap_values) GKO_NOT_IMPLEMENTED; diff --git a/omp/distributed/matrix_kernels.cpp b/omp/distributed/matrix_kernels.cpp index a3e8cb60868..55ee5524116 100644 --- a/omp/distributed/matrix_kernels.cpp +++ b/omp/distributed/matrix_kernels.cpp @@ -4,6 +4,8 @@ #include "core/distributed/matrix_kernels.hpp" +#include + #include #include @@ -26,8 +28,50 @@ void count_overlap_entries( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, - array& overlap_count) GKO_NOT_IMPLEMENTED; + comm_index_type local_part, array& overlap_count, + array& overlap_positions, + array& original_positions) +{ + auto num_input_elements = input.get_num_stored_elements(); + auto input_row_idxs = input.get_const_row_idxs(); + auto row_part_ids = row_partition->get_part_ids(); + array row_part_ids_per_entry{exec, num_input_elements}; + + size_type row_range_id = 0; +#pragma omp parallel for firstprivate(row_range_id) + for (size_type i = 0; i < input.get_num_stored_elements(); ++i) { + auto global_row = input_row_idxs[i]; + row_range_id = find_range(global_row, row_partition, row_range_id); + auto row_part_id = row_part_ids[row_range_id]; + row_part_ids_per_entry.get_data()[i] = row_part_id; + if (row_part_id != local_part) { +#pragma omp atomic + overlap_count.get_data()[row_part_id]++; + original_positions.get_data()[i] = i; + } else { + original_positions.get_data()[i] = -1; + } + } + + auto comp = [row_part_ids_per_entry, local_part](auto i, auto j) { + comm_index_type a = + i == -1 ? local_part : row_part_ids_per_entry.get_const_data()[i]; + comm_index_type b = + j == -1 ? local_part : row_part_ids_per_entry.get_const_data()[j]; + return a < b; + }; + std::stable_sort(original_positions.get_data(), + original_positions.get_data() + num_input_elements, comp); + +#pragma omp parallel for + for (size_type i = 0; i < num_input_elements; i++) { + overlap_positions.get_data()[i] = + original_positions.get_const_data()[i] == -1 ? 0 : 1; + } + + components::prefix_sum_nonnegative(exec, overlap_positions.get_data(), + num_input_elements); +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_COUNT_OVERLAP_ENTRIES); @@ -39,10 +83,26 @@ void fill_overlap_send_buffers( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, array& offsets, + comm_index_type local_part, const array& overlap_positions, + const array& original_positions, array& overlap_row_idxs, - array& overlap_col_idxs, - array& overlap_values) GKO_NOT_IMPLEMENTED; + array& overlap_col_idxs, array& overlap_values) +{ + auto input_row_idxs = input.get_const_row_idxs(); + auto input_col_idxs = input.get_const_col_idxs(); + auto input_vals = input.get_const_values(); + +#pragma omp parallel for + for (size_type i = 0; i < input.get_num_stored_elements(); ++i) { + auto in_pos = original_positions.get_const_data()[i]; + if (in_pos >= 0) { + auto out_pos = overlap_positions.get_const_data()[i]; + overlap_row_idxs.get_data()[out_pos] = input_row_idxs[in_pos]; + overlap_col_idxs.get_data()[out_pos] = input_col_idxs[in_pos]; + overlap_values.get_data()[out_pos] = input_vals[in_pos]; + } + } +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_FILL_OVERLAP_SEND_BUFFERS); diff --git a/reference/distributed/matrix_kernels.cpp b/reference/distributed/matrix_kernels.cpp index d8b0f9e1d4f..6a57a64e075 100644 --- a/reference/distributed/matrix_kernels.cpp +++ b/reference/distributed/matrix_kernels.cpp @@ -4,10 +4,12 @@ #include "core/distributed/matrix_kernels.hpp" +#include +#include + #include "core/base/allocator.hpp" #include "core/base/device_matrix_data_kernels.hpp" #include "core/base/iterator_factory.hpp" -#include "ginkgo/core/distributed/partition.hpp" #include "reference/distributed/partition_helpers.hpp" @@ -23,21 +25,47 @@ void count_overlap_entries( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, array& overlap_count) + comm_index_type local_part, array& overlap_count, + array& overlap_positions, + array& original_positions) { + auto num_input_elements = input.get_num_stored_elements(); auto input_row_idxs = input.get_const_row_idxs(); auto row_part_ids = row_partition->get_part_ids(); + array row_part_ids_per_entry{exec, num_input_elements}; size_type row_range_id = 0; for (size_type i = 0; i < input.get_num_stored_elements(); ++i) { auto global_row = input_row_idxs[i]; row_range_id = find_range(global_row, row_partition, row_range_id); - row_range_id = find_range(global_row, row_partition, row_range_id); auto row_part_id = row_part_ids[row_range_id]; + row_part_ids_per_entry.get_data()[i] = row_part_id; if (row_part_id != local_part) { overlap_count.get_data()[row_part_id]++; + original_positions.get_data()[i] = i; + } else { + original_positions.get_data()[i] = -1; } } + + auto comp = [row_part_ids_per_entry, local_part](auto i, auto j) { + comm_index_type a = + i == -1 ? local_part : row_part_ids_per_entry.get_const_data()[i]; + comm_index_type b = + j == -1 ? local_part : row_part_ids_per_entry.get_const_data()[j]; + return a < b; + }; + + std::stable_sort(original_positions.get_data(), + original_positions.get_data() + num_input_elements, comp); + for (size_type i = 0; i < num_input_elements; i++) { + overlap_positions.get_data()[i] = + original_positions.get_const_data()[i] == -1 ? 0 : 1; + } + + std::exclusive_scan(overlap_positions.get_data(), + overlap_positions.get_data() + num_input_elements, + overlap_positions.get_data(), 0); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( @@ -50,26 +78,22 @@ void fill_overlap_send_buffers( const device_matrix_data& input, const experimental::distributed::Partition* row_partition, - comm_index_type local_part, array& offsets, + comm_index_type local_part, const array& overlap_positions, + const array& original_positions, array& overlap_row_idxs, array& overlap_col_idxs, array& overlap_values) { auto input_row_idxs = input.get_const_row_idxs(); auto input_col_idxs = input.get_const_col_idxs(); auto input_vals = input.get_const_values(); - auto row_part_ids = row_partition->get_part_ids(); - size_type row_range_id = 0; for (size_type i = 0; i < input.get_num_stored_elements(); ++i) { - auto global_row = input_row_idxs[i]; - row_range_id = find_range(global_row, row_partition, row_range_id); - row_range_id = find_range(global_row, row_partition, row_range_id); - auto row_part_id = row_part_ids[row_range_id]; - if (row_part_id != local_part) { - auto idx = offsets.get_data()[row_part_id]++; - overlap_row_idxs.get_data()[idx] = global_row; - overlap_col_idxs.get_data()[idx] = input_col_idxs[i]; - overlap_values.get_data()[idx] = input_vals[i]; + auto in_pos = original_positions.get_const_data()[i]; + if (in_pos >= 0) { + auto out_pos = overlap_positions.get_const_data()[i]; + overlap_row_idxs.get_data()[out_pos] = input_row_idxs[in_pos]; + overlap_col_idxs.get_data()[out_pos] = input_col_idxs[in_pos]; + overlap_values.get_data()[out_pos] = input_vals[in_pos]; } } } diff --git a/reference/test/distributed/matrix_kernels.cpp b/reference/test/distributed/matrix_kernels.cpp index 00d3fcd8895..80fc8eb3330 100644 --- a/reference/test/distributed/matrix_kernels.cpp +++ b/reference/test/distributed/matrix_kernels.cpp @@ -17,8 +17,6 @@ #include #include "core/test/utils.hpp" -#include "ginkgo/core/base/array.hpp" -#include "ginkgo/core/base/types.hpp" namespace { @@ -194,24 +192,37 @@ TYPED_TEST(Matrix, CountOverlapEntries) using git = typename TestFixture::global_index_type; using vt = typename TestFixture::value_type; using ca = gko::array; + using ga = gko::array; this->mapping = {this->ref, {1, 0, 2, 2, 0, 1, 1}}; std::vector overlap_count_ref{ ca{this->ref, I{0, 5, 3}}, ca{this->ref, I{4, 0, 3}}, ca{this->ref, I{4, 5, 0}}}; + std::vector overlap_pos_ref{ + ga{this->ref, I{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7}}, + ga{this->ref, I{0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 5, 6}}, + ga{this->ref, I{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9}}}; + std::vector original_pos_ref{ + ga{this->ref, I{-1, -1, -1, -1, 0, 1, 9, 10, 11, 4, 5, 6}}, + ga{this->ref, I{2, 3, 7, 8, -1, -1, -1, -1, -1, 4, 5, 6}}, + ga{this->ref, I{2, 3, 7, 8, 0, 1, 9, 10, 11, -1, -1, -1}}}; comm_index_type num_parts = 3; auto partition = gko::experimental::distributed::Partition::build_from_mapping( this->ref, this->mapping, num_parts); auto input = this->create_input_full_rank(); - gko::array overlap_count{ - this->ref, static_cast(num_parts)}; + ca overlap_count{this->ref, static_cast(num_parts)}; + ga overlap_positions{this->ref, input.get_num_stored_elements()}; + ga original_positions{this->ref, input.get_num_stored_elements()}; for (gko::size_type i = 0; i < num_parts; i++) { overlap_count.fill(0); gko::kernels::reference::distributed_matrix::count_overlap_entries( - this->ref, input, partition.get(), i, overlap_count); + this->ref, input, partition.get(), i, overlap_count, + overlap_positions, original_positions); GKO_ASSERT_ARRAY_EQ(overlap_count, overlap_count_ref[i]); + GKO_ASSERT_ARRAY_EQ(overlap_positions, overlap_pos_ref[i]); + GKO_ASSERT_ARRAY_EQ(original_positions, original_pos_ref[i]); } } @@ -225,10 +236,14 @@ TYPED_TEST(Matrix, FillOverlapSendBuffers) using ga = gko::array; using va = gko::array; this->mapping = {this->ref, {1, 0, 2, 2, 0, 1, 1}}; - std::vector overlap_offsets{ - ca{this->ref, I{0, 0, 5, 8}}, - ca{this->ref, I{0, 4, 4, 7}}, - ca{this->ref, I{0, 4, 9, 9}}}; + std::vector overlap_positions{ + ga{this->ref, I{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7}}, + ga{this->ref, I{0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 5, 6}}, + ga{this->ref, I{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9}}}; + std::vector original_positions{ + ga{this->ref, I{-1, -1, -1, -1, 0, 1, 9, 10, 11, 4, 5, 6}}, + ga{this->ref, I{2, 3, 7, 8, -1, -1, -1, -1, -1, 4, 5, 6}}, + ga{this->ref, I{2, 3, 7, 8, 0, 1, 9, 10, 11, -1, -1, -1}}}; std::vector overlap_row_idxs_ref{ ga{this->ref, I{0, 0, 5, 5, 6, 2, 3, 3}}, ga{this->ref, I{1, 1, 4, 4, 2, 3, 3}}, @@ -251,15 +266,14 @@ TYPED_TEST(Matrix, FillOverlapSendBuffers) gko::array overlap_col_idxs{this->ref}; gko::array overlap_values{this->ref}; for (gko::size_type i = 0; i < num_parts; i++) { - overlap_row_idxs.resize_and_reset( - overlap_offsets[i].get_data()[num_parts]); - overlap_col_idxs.resize_and_reset( - overlap_offsets[i].get_data()[num_parts]); - overlap_values.resize_and_reset( - overlap_offsets[i].get_data()[num_parts]); + auto num_entries = overlap_row_idxs_ref[i].get_size(); + overlap_row_idxs.resize_and_reset(num_entries); + overlap_col_idxs.resize_and_reset(num_entries); + overlap_values.resize_and_reset(num_entries); gko::kernels::reference::distributed_matrix::fill_overlap_send_buffers( - this->ref, input, partition.get(), i, overlap_offsets[i], - overlap_row_idxs, overlap_col_idxs, overlap_values); + this->ref, input, partition.get(), i, overlap_positions[i], + original_positions[i], overlap_row_idxs, overlap_col_idxs, + overlap_values); GKO_ASSERT_ARRAY_EQ(overlap_row_idxs, overlap_row_idxs_ref[i]); GKO_ASSERT_ARRAY_EQ(overlap_col_idxs, overlap_col_idxs_ref[i]); GKO_ASSERT_ARRAY_EQ(overlap_values, overlap_values_ref[i]); diff --git a/test/distributed/matrix_kernels.cpp b/test/distributed/matrix_kernels.cpp index ad91d699496..6de772e8006 100644 --- a/test/distributed/matrix_kernels.cpp +++ b/test/distributed/matrix_kernels.cpp @@ -48,8 +48,9 @@ class Matrix : public CommonTestFixture { { gko::device_matrix_data d_input{exec, input}; - for (comm_index_type part = 0; part < row_partition->get_num_parts(); - ++part) { + gko::size_type num_parts = row_partition->get_num_parts(); + gko::size_type num_entries = input.get_num_stored_elements(); + for (comm_index_type part = 0; part < num_parts; ++part) { gko::array local_row_idxs{ref}; gko::array local_col_idxs{ref}; gko::array local_values{ref}; @@ -62,6 +63,55 @@ class Matrix : public CommonTestFixture { gko::array d_non_local_row_idxs{exec}; gko::array d_non_local_col_idxs{exec}; gko::array d_non_local_values{exec}; + gko::array overlap_count{ref, num_parts}; + overlap_count.fill(0); + gko::array d_overlap_count{exec, num_parts}; + d_overlap_count.fill(0); + gko::array overlap_positions{ref, num_entries}; + gko::array d_overlap_positions{exec, + num_entries}; + gko::array original_positions{ref, num_entries}; + gko::array d_original_positions{exec, + num_entries}; + + gko::kernels::reference::distributed_matrix::count_overlap_entries( + ref, input, row_partition.get(), part, overlap_count, + overlap_positions, original_positions); + gko::kernels::GKO_DEVICE_NAMESPACE::distributed_matrix:: + count_overlap_entries( + exec, d_input, d_row_partition.get(), part, d_overlap_count, + d_overlap_positions, d_original_positions); + + gko::array overlap_offsets{ref, num_parts + 1}; + std::partial_sum(overlap_count.get_data(), + overlap_count.get_data() + num_parts, + overlap_offsets.get_data() + 1); + overlap_offsets.get_data()[0] = 0; + gko::array d_overlap_offsets{exec, + overlap_offsets}; + gko::size_type num_overlap_entries = + overlap_offsets.get_data()[num_parts]; + gko::array overlap_row_idxs{ref, + num_overlap_entries}; + gko::array overlap_col_idxs{ref, + num_overlap_entries}; + gko::array overlap_values{ref, num_overlap_entries}; + gko::array d_overlap_row_idxs{ + exec, num_overlap_entries}; + gko::array d_overlap_col_idxs{ + exec, num_overlap_entries}; + gko::array d_overlap_values{exec, num_overlap_entries}; + + gko::kernels::reference::distributed_matrix:: + fill_overlap_send_buffers(ref, input, row_partition.get(), part, + overlap_positions, original_positions, + overlap_row_idxs, overlap_col_idxs, + overlap_values); + gko::kernels::GKO_DEVICE_NAMESPACE::distributed_matrix:: + fill_overlap_send_buffers( + exec, d_input, d_row_partition.get(), part, + d_overlap_positions, d_original_positions, + d_overlap_row_idxs, d_overlap_col_idxs, d_overlap_values); gko::kernels::reference::distributed_matrix:: separate_local_nonlocal( @@ -75,6 +125,12 @@ class Matrix : public CommonTestFixture { d_non_local_row_idxs, d_non_local_col_idxs, d_non_local_values); + GKO_ASSERT_ARRAY_EQ(overlap_positions, d_overlap_positions); + GKO_ASSERT_ARRAY_EQ(original_positions, d_original_positions); + GKO_ASSERT_ARRAY_EQ(overlap_count, d_overlap_count); + GKO_ASSERT_ARRAY_EQ(overlap_row_idxs, d_overlap_row_idxs); + GKO_ASSERT_ARRAY_EQ(overlap_col_idxs, d_overlap_col_idxs); + GKO_ASSERT_ARRAY_EQ(overlap_values, d_overlap_values); GKO_ASSERT_ARRAY_EQ(local_row_idxs, d_local_row_idxs); GKO_ASSERT_ARRAY_EQ(local_col_idxs, d_local_col_idxs); GKO_ASSERT_ARRAY_EQ(local_values, d_local_values);