From 80b527ead25359aeea49fcc770054daa7bb35b35 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Fri, 20 Dec 2024 14:29:03 -0500 Subject: [PATCH] fix all tests csr and coo --- cpp/test/preprocess_utils.cu | 159 ++++++++++-------------------- cpp/test/sparse/preprocess_coo.cu | 110 +++++++++------------ cpp/test/sparse/preprocess_csr.cu | 111 ++++++++------------- 3 files changed, 141 insertions(+), 239 deletions(-) diff --git a/cpp/test/preprocess_utils.cu b/cpp/test/preprocess_utils.cu index 6f0fcbdf17..a240e2e4e2 100644 --- a/cpp/test/preprocess_utils.cu +++ b/cpp/test/preprocess_utils.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -40,34 +41,25 @@ struct check_zeroes { }; template -void preproc_coo(raft::resources& handle, - raft::host_vector_view h_rows, - raft::host_vector_view h_cols, - raft::host_vector_view h_elems, - raft::device_vector_view results, - int num_rows, - int num_cols, - bool tf_idf) +void preproc(raft::resources& handle, + raft::device_vector_view dense_values, + raft::device_vector_view results, + int num_rows, + int num_cols, + bool tf_idf) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - int rows_size = h_rows.size(); - int cols_size = h_cols.size(); - int elements_size = h_elems.size(); - auto device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); - raft::matrix::fill(handle, device_matrix.view(), 0.0f); - auto host_matrix = raft::make_host_matrix(handle, num_rows, num_cols); - raft::copy(host_matrix.data_handle(), device_matrix.data_handle(), device_matrix.size(), stream); - raft::resource::sync_stream(handle, stream); + auto host_dense_vals = raft::make_host_vector(handle, dense_values.size()); + raft::copy( + host_dense_vals.data_handle(), dense_values.data_handle(), dense_values.size(), stream); - for (int i = 0; i < elements_size; i++) { - int row = h_rows(i); - int col = h_cols(i); - float element = h_elems(i); - host_matrix(row, col) = element; - } + auto host_matrix = + raft::make_host_matrix_view(host_dense_vals.data_handle(), num_rows, num_cols); + auto device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); raft::copy(device_matrix.data_handle(), host_matrix.data_handle(), host_matrix.size(), stream); + auto output_cols_lengths = raft::make_device_matrix(handle, 1, num_cols); raft::linalg::reduce(output_cols_lengths.data_handle(), device_matrix.data_handle(), @@ -96,6 +88,7 @@ void preproc_coo(raft::resources& handle, output_cols_length_sum.data_handle(), output_cols_length_sum.size(), stream); + T2 avg_col_length = T2(h_output_cols_length_sum(0)) / num_cols; auto output_rows_freq = raft::make_device_matrix(handle, 1, num_rows); @@ -153,98 +146,48 @@ void preproc_coo(raft::resources& handle, } } } - raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream); -} - -template -int get_dupe_mask_count(raft::resources& handle, - raft::device_vector_view rows, - raft::device_vector_view columns, - raft::device_vector_view values, - const raft::device_vector_view& mask) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - raft::sparse::op::coo_sort(int(rows.size()), - int(columns.size()), - int(values.size()), - rows.data_handle(), - columns.data_handle(), - values.data_handle(), - stream); - - raft::sparse::op::compute_duplicates_mask( - mask.data_handle(), rows.data_handle(), columns.data_handle(), rows.size(), stream); - int col_nnz_count = thrust::reduce(raft::resource::get_thrust_policy(handle), - mask.data_handle(), - mask.data_handle() + mask.size()); - return col_nnz_count; + raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream); } template -void remove_dupes(raft::resources& handle, - raft::device_vector_view rows, - raft::device_vector_view columns, - raft::device_vector_view values, - raft::device_vector_view mask, - const raft::device_vector_view& out_rows, - const raft::device_vector_view& out_cols, - const raft::device_vector_view& out_vals, - int num_rows = 128) +void calc_tfidf_bm25(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view results, + bool tf_idf = false) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto col_counts = raft::make_device_vector(handle, columns.size()); - - thrust::fill(raft::resource::get_thrust_policy(handle), - col_counts.data_handle(), - col_counts.data_handle() + col_counts.size(), - 1.0f); - - auto keys_out = raft::make_device_vector(handle, num_rows); - auto counts_out = raft::make_device_vector(handle, num_rows); - - thrust::reduce_by_key(raft::resource::get_thrust_policy(handle), - rows.data_handle(), - rows.data_handle() + rows.size(), - col_counts.data_handle(), - keys_out.data_handle(), - counts_out.data_handle()); - - auto mask_out = raft::make_device_vector(handle, rows.size()); - - raft::linalg::map(handle, mask_out.view(), raft::cast_op{}, raft::make_const_mdspan(mask)); - - auto values_c = raft::make_device_vector(handle, values.size()); - raft::linalg::map(handle, - values_c.view(), - raft::mul_op{}, - raft::make_const_mdspan(values), - raft::make_const_mdspan(mask_out.view())); - - auto keys_nnz_out = raft::make_device_vector(handle, num_rows); - auto counts_nnz_out = raft::make_device_vector(handle, num_rows); - - thrust::reduce_by_key(raft::resource::get_thrust_policy(handle), - rows.data_handle(), - rows.data_handle() + rows.size(), - mask.data_handle(), - keys_nnz_out.data_handle(), - counts_nnz_out.data_handle()); - - raft::sparse::op::coo_remove_scalar(rows.data_handle(), - columns.data_handle(), - values_c.data_handle(), - values_c.size(), - out_rows.data_handle(), - out_cols.data_handle(), - out_vals.data_handle(), - counts_nnz_out.data_handle(), - counts_out.data_handle(), - 0, - num_rows, - stream); + int num_rows = csr_in.structure_view().get_n_rows(); + int num_cols = csr_in.structure_view().get_n_cols(); + int rows_size = csr_in.structure_view().get_indptr().size(); + int cols_size = csr_in.structure_view().get_indices().size(); + int elements_size = csr_in.get_elements().size(); + + auto indptr = raft::make_device_vector_view( + csr_in.structure_view().get_indptr().data(), rows_size); + auto indices = raft::make_device_vector_view( + csr_in.structure_view().get_indices().data(), cols_size); + auto values = + raft::make_device_vector_view(csr_in.get_elements().data(), elements_size); + auto dense_values = raft::make_device_vector(handle, num_rows * num_cols); + + cusparseHandle_t cu_handle; + RAFT_CUSPARSE_TRY(cusparseCreate(&cu_handle)); + + raft::sparse::convert::csr_to_dense(cu_handle, + num_rows, + num_cols, + elements_size, + indptr.data_handle(), + indices.data_handle(), + values.data_handle(), + num_rows, + dense_values.data_handle(), + stream, + true); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + preproc(handle, dense_values.view(), results, num_rows, num_cols, tf_idf); } template diff --git a/cpp/test/sparse/preprocess_coo.cu b/cpp/test/sparse/preprocess_coo.cu index 48cf9ae64c..c8ddfc920b 100644 --- a/cpp/test/sparse/preprocess_coo.cu +++ b/cpp/test/sparse/preprocess_coo.cu @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -33,37 +32,6 @@ namespace raft { namespace sparse { -template -void calc_tfidf_bm25(raft::resources& handle, - raft::device_coo_matrix_view coo_in, - raft::device_vector_view results, - bool tf_idf = false) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - int num_rows = coo_in.structure_view().get_n_rows(); - int num_cols = coo_in.structure_view().get_n_cols(); - int rows_size = coo_in.structure_view().get_cols().size(); - int cols_size = coo_in.structure_view().get_rows().size(); - int elements_size = coo_in.get_elements().size(); - - auto h_rows = raft::make_host_vector(handle, rows_size); - auto h_cols = raft::make_host_vector(handle, cols_size); - auto h_elems = raft::make_host_vector(handle, elements_size); - - raft::copy(h_rows.data_handle(), - coo_in.structure_view().get_rows().data(), - coo_in.structure_view().get_rows().size(), - stream); - raft::copy(h_cols.data_handle(), - coo_in.structure_view().get_cols().data(), - coo_in.structure_view().get_cols().size(), - stream); - raft::copy( - h_elems.data_handle(), coo_in.get_elements().data(), coo_in.get_elements().size(), stream); - raft::util::preproc_coo( - handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf); -} - template struct SparsePreprocessInputs { int n_rows; @@ -94,51 +62,71 @@ class SparsePreprocessCoo auto rows = raft::make_device_vector(handle, params.nnz_edges); auto columns = raft::make_device_vector(handle, params.nnz_edges); auto values = raft::make_device_vector(handle, params.nnz_edges); - auto mask = raft::make_device_vector(handle, params.nnz_edges); + + rmm::device_uvector rows_uvec(rows.size(), stream); + rmm::device_uvector cols_uvec(rows.size(), stream); + rmm::device_uvector vals_uvec(rows.size(), stream); raft::util::create_dataset( handle, rows.view(), columns.view(), values.view(), 5, params.n_rows, params.n_cols); - int non_dupe_nnz_count = raft::util::get_dupe_mask_count( - handle, rows.view(), columns.view(), values.view(), mask.view()); - - auto rows_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - auto columns_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - auto values_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - raft::util::remove_dupes(handle, - rows.view(), - columns.view(), - values.view(), - mask.view(), - rows_nnz.view(), - columns_nnz.view(), - values_nnz.view(), - num_rows); - - auto coo_struct_view = raft::make_device_coordinate_structure_view(rows_nnz.data_handle(), - columns_nnz.data_handle(), - num_rows, - num_cols, - int(values_nnz.size())); + + raft::sparse::op::coo_sort(int(rows.size()), + int(columns.size()), + int(values.size()), + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + stream); + + raft::copy(rows_uvec.data(), rows.data_handle(), rows.size(), stream); + raft::copy(cols_uvec.data(), columns.data_handle(), columns.size(), stream); + raft::copy(vals_uvec.data(), values.data_handle(), values.size(), stream); + + raft::sparse::COO coo(stream); + raft::sparse::op::max_duplicates(handle, + coo, + rows_uvec.data(), + cols_uvec.data(), + vals_uvec.data(), + params.nnz_edges, + num_rows, + num_cols); + + auto rows_csr = raft::make_device_vector(handle, num_rows + 1); + + raft::sparse::convert::sorted_coo_to_csr( + coo.rows(), coo.nnz, rows_csr.data_handle(), num_rows + 1, stream); + + auto csr_struct_view = raft::make_device_compressed_structure_view( + rows_csr.data_handle(), coo.cols(), num_rows, num_cols, coo.nnz); + + auto csr_matrix = + raft::make_device_csr_matrix(handle, csr_struct_view); + raft::update_device( + csr_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream); + + auto coo_struct_view = raft::make_device_coordinate_structure_view( + coo.rows(), coo.cols(), num_rows, num_cols, int(coo.nnz)); auto c_matrix = raft::make_device_coo_matrix(handle, coo_struct_view); - raft::update_device( - c_matrix.view().get_elements().data(), values_nnz.data_handle(), values_nnz.size(), stream); + raft::update_device(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream); - auto result = raft::make_device_vector(handle, values_nnz.size()); - auto bm25_vals = raft::make_device_vector(handle, values_nnz.size()); - auto tfidf_vals = raft::make_device_vector(handle, values_nnz.size()); + auto result = raft::make_device_vector(handle, coo.nnz); if (bm25_on) { + auto bm25_vals = raft::make_device_vector(handle, coo.nnz); sparse::matrix::encode_bm25(handle, c_matrix.view(), result.view()); - calc_tfidf_bm25(handle, c_matrix.view(), bm25_vals.view()); + raft::util::calc_tfidf_bm25(handle, csr_matrix.view(), bm25_vals.view()); ASSERT_TRUE(raft::devArrMatch(bm25_vals.data_handle(), result.data_handle(), result.size(), raft::CompareApprox(2e-5), stream)); } else { + auto tfidf_vals = raft::make_device_vector(handle, coo.nnz); sparse::matrix::encode_tfidf(handle, c_matrix.view(), result.view()); - calc_tfidf_bm25(handle, c_matrix.view(), tfidf_vals.view(), true); + raft::util::calc_tfidf_bm25( + handle, csr_matrix.view(), tfidf_vals.view(), true); ASSERT_TRUE(raft::devArrMatch(tfidf_vals.data_handle(), result.data_handle(), result.size(), diff --git a/cpp/test/sparse/preprocess_csr.cu b/cpp/test/sparse/preprocess_csr.cu index c5751f44cb..ff45051f67 100644 --- a/cpp/test/sparse/preprocess_csr.cu +++ b/cpp/test/sparse/preprocess_csr.cu @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -32,41 +31,6 @@ namespace raft { namespace sparse { -template -void calc_tfidf_bm25(raft::resources& handle, - raft::device_csr_matrix_view csr_in, - raft::device_vector_view results, - bool tf_idf = false) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - int num_rows = csr_in.structure_view().get_n_rows(); - int num_cols = csr_in.structure_view().get_n_cols(); - int rows_size = csr_in.structure_view().get_indptr().size(); - int cols_size = csr_in.structure_view().get_indices().size(); - int elements_size = csr_in.get_elements().size(); - - auto h_rows = raft::make_host_vector(handle, rows_size); - auto h_cols = raft::make_host_vector(handle, cols_size); - auto h_elems = raft::make_host_vector(handle, elements_size); - - auto indptr = raft::make_device_vector_view( - csr_in.structure_view().get_indptr().data(), csr_in.structure_view().get_indptr().size()); - auto indices = raft::make_device_vector_view( - csr_in.structure_view().get_indices().data(), csr_in.structure_view().get_indices().size()); - auto values = raft::make_device_vector_view(csr_in.get_elements().data(), - csr_in.get_elements().size()); - auto rows = raft::make_device_vector(handle, values.size()); - - raft::sparse::convert::csr_to_coo( - indptr.data_handle(), num_rows, rows.data_handle(), rows.size(), stream); - - raft::copy(h_rows.data_handle(), rows.data_handle(), rows.size(), stream); - raft::copy(h_cols.data_handle(), indices.data_handle(), cols_size, stream); - raft::copy(h_elems.data_handle(), values.data_handle(), values.size(), stream); - raft::util::preproc_coo( - handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf); -} - template struct SparsePreprocessInputs { int n_rows; @@ -97,47 +61,54 @@ class SparsePreprocessCSR auto rows = raft::make_device_vector(handle, params.nnz_edges); auto columns = raft::make_device_vector(handle, params.nnz_edges); auto values = raft::make_device_vector(handle, params.nnz_edges); - auto mask = raft::make_device_vector(handle, params.nnz_edges); + + rmm::device_uvector rows_uvec(rows.size(), stream); + rmm::device_uvector cols_uvec(rows.size(), stream); + rmm::device_uvector vals_uvec(rows.size(), stream); raft::util::create_dataset( handle, rows.view(), columns.view(), values.view(), 5, params.n_rows, params.n_cols); - int non_dupe_nnz_count = raft::util::get_dupe_mask_count( - handle, rows.view(), columns.view(), values.view(), mask.view()); - - auto rows_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - auto columns_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - auto values_nnz = raft::make_device_vector(handle, non_dupe_nnz_count); - raft::util::remove_dupes(handle, - rows.view(), - columns.view(), - values.view(), - mask.view(), - rows_nnz.view(), - columns_nnz.view(), - values_nnz.view(), - num_rows); - auto rows_csr = raft::make_device_vector(handle, non_dupe_nnz_count); - raft::sparse::convert::sorted_coo_to_csr( - rows_nnz.data_handle(), non_dupe_nnz_count, rows_csr.data_handle(), num_rows, stream); - auto csr_struct_view = raft::make_device_compressed_structure_view(rows_csr.data_handle(), - columns_nnz.data_handle(), - num_rows, - num_cols, - int(values_nnz.size())); + raft::sparse::op::coo_sort(int(rows.size()), + int(columns.size()), + int(values.size()), + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + stream); + + raft::copy(rows_uvec.data(), rows.data_handle(), rows.size(), stream); + raft::copy(cols_uvec.data(), columns.data_handle(), columns.size(), stream); + raft::copy(vals_uvec.data(), values.data_handle(), values.size(), stream); + + raft::sparse::COO coo(stream); + raft::sparse::op::max_duplicates(handle, + coo, + rows_uvec.data(), + cols_uvec.data(), + vals_uvec.data(), + params.nnz_edges, + num_rows, + num_cols); + + auto rows_csr = raft::make_device_vector(handle, num_rows + 1); + + raft::sparse::convert::sorted_coo_to_csr( + coo.rows(), coo.nnz, rows_csr.data_handle(), num_rows + 1, stream); + auto csr_struct_view = raft::make_device_compressed_structure_view( + rows_csr.data_handle(), coo.cols(), num_rows, num_cols, coo.nnz); auto c_matrix = raft::make_device_csr_matrix(handle, csr_struct_view); - raft::update_device( - c_matrix.view().get_elements().data(), values_nnz.data_handle(), values_nnz.size(), stream); + raft::update_device(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream); - auto result = raft::make_device_vector(handle, values_nnz.size()); - auto bm25_vals = raft::make_device_vector(handle, values_nnz.size()); - auto tfidf_vals = raft::make_device_vector(handle, values_nnz.size()); + auto result = raft::make_device_vector(handle, coo.nnz); + auto bm25_vals = raft::make_device_vector(handle, coo.nnz); + auto tfidf_vals = raft::make_device_vector(handle, coo.nnz); if (bm25_on) { sparse::matrix::encode_bm25(handle, c_matrix.view(), result.view()); - calc_tfidf_bm25(handle, c_matrix.view(), bm25_vals.view()); + raft::util::calc_tfidf_bm25(handle, c_matrix.view(), bm25_vals.view()); ASSERT_TRUE(raft::devArrMatch(bm25_vals.data_handle(), result.data_handle(), result.size(), @@ -145,7 +116,7 @@ class SparsePreprocessCSR stream)); } else { sparse::matrix::encode_tfidf(handle, c_matrix.view(), result.view()); - calc_tfidf_bm25(handle, c_matrix.view(), tfidf_vals.view(), true); + raft::util::calc_tfidf_bm25(handle, c_matrix.view(), tfidf_vals.view(), true); ASSERT_TRUE(raft::devArrMatch(tfidf_vals.data_handle(), result.data_handle(), result.size(), @@ -185,9 +156,9 @@ const std::vector> sparse_preprocess_inputs = const std::vector> sparse_preprocess_inputs_big = { { - 12, // n_rows_factor - 12, // n_cols_factor - 100000 // nnz_edges + 12, // n_rows_factor + 12, // n_cols_factor + 1000000 // nnz_edges - 6475 }, };