Skip to content

Commit

Permalink
fix all tests csr and coo
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Dec 20, 2024
1 parent 15f53c3 commit 80b527e
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 239 deletions.
159 changes: 51 additions & 108 deletions cpp/test/preprocess_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/random/rmat_rectangular_generator.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/convert/coo.cuh>
#include <raft/sparse/convert/dense.cuh>
#include <raft/sparse/matrix/preprocessing.cuh>
#include <raft/sparse/neighbors/cross_component_nn.cuh>
#include <raft/sparse/op/filter.cuh>
Expand All @@ -40,34 +41,25 @@ struct check_zeroes {
};

template <typename T1, typename T2>
void preproc_coo(raft::resources& handle,
raft::host_vector_view<T1> h_rows,
raft::host_vector_view<T1> h_cols,
raft::host_vector_view<T2> h_elems,
raft::device_vector_view<T2> results,
int num_rows,
int num_cols,
bool tf_idf)
void preproc(raft::resources& handle,
raft::device_vector_view<T2> dense_values,
raft::device_vector_view<T2> results,
int num_rows,
int num_cols,
bool tf_idf)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
int rows_size = h_rows.size();
int cols_size = h_cols.size();
int elements_size = h_elems.size();
auto device_matrix = raft::make_device_matrix<T2, int64_t>(handle, num_rows, num_cols);
raft::matrix::fill<T2>(handle, device_matrix.view(), 0.0f);
auto host_matrix = raft::make_host_matrix<T2, int64_t>(handle, num_rows, num_cols);
raft::copy(host_matrix.data_handle(), device_matrix.data_handle(), device_matrix.size(), stream);

raft::resource::sync_stream(handle, stream);
auto host_dense_vals = raft::make_host_vector<T2, int64_t>(handle, dense_values.size());
raft::copy(
host_dense_vals.data_handle(), dense_values.data_handle(), dense_values.size(), stream);

for (int i = 0; i < elements_size; i++) {
int row = h_rows(i);
int col = h_cols(i);
float element = h_elems(i);
host_matrix(row, col) = element;
}
auto host_matrix =
raft::make_host_matrix_view<T2, int64_t>(host_dense_vals.data_handle(), num_rows, num_cols);
auto device_matrix = raft::make_device_matrix<T2, int64_t>(handle, num_rows, num_cols);

raft::copy(device_matrix.data_handle(), host_matrix.data_handle(), host_matrix.size(), stream);

auto output_cols_lengths = raft::make_device_matrix<T2, int64_t>(handle, 1, num_cols);
raft::linalg::reduce(output_cols_lengths.data_handle(),
device_matrix.data_handle(),
Expand Down Expand Up @@ -96,6 +88,7 @@ void preproc_coo(raft::resources& handle,
output_cols_length_sum.data_handle(),
output_cols_length_sum.size(),
stream);

T2 avg_col_length = T2(h_output_cols_length_sum(0)) / num_cols;

auto output_rows_freq = raft::make_device_matrix<T2, int64_t>(handle, 1, num_rows);
Expand Down Expand Up @@ -153,98 +146,48 @@ void preproc_coo(raft::resources& handle,
}
}
}
raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream);
}

// Test helper: sorts the COO triplets, fills `mask` via compute_duplicates_mask,
// and returns the sum of all mask entries.
//
// handle   raft resources; supplies the CUDA stream and the thrust policy.
// rows     device vector of row indices; its pointer is handed to coo_sort.
// columns  device vector of column indices; its pointer is handed to coo_sort.
// values   device vector of values; its pointer is handed to coo_sort.
// mask     device vector written by compute_duplicates_mask and then reduced.
//
// Returns the thrust::reduce sum over `mask`, narrowed to int.
// NOTE(review): the caller stores this as `non_dupe_nnz_count`, so presumably the
// mask is 1 for entries to keep and 0 for duplicates — confirm against
// compute_duplicates_mask.
template <typename T1, typename T2>
int get_dupe_mask_count(raft::resources& handle,
raft::device_vector_view<T1> rows,
raft::device_vector_view<T1> columns,
raft::device_vector_view<T2> values,
const raft::device_vector_view<T1>& mask)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);

// NOTE(review): the first three arguments are the lengths of the three arrays;
// confirm that this matches the dimension/nnz parameters coo_sort expects.
// Presumably the sort is in place (the same pointers are reused below) — verify.
raft::sparse::op::coo_sort(int(rows.size()),
int(columns.size()),
int(values.size()),
rows.data_handle(),
columns.data_handle(),
values.data_handle(),
stream);

// Build the mask from the (row, column) pairs; rows.size() is used as the entry count.
raft::sparse::op::compute_duplicates_mask<T1>(
mask.data_handle(), rows.data_handle(), columns.data_handle(), rows.size(), stream);

// Sum every mask entry; the result is returned to the host as an int.
int col_nnz_count = thrust::reduce(raft::resource::get_thrust_policy(handle),
mask.data_handle(),
mask.data_handle() + mask.size());
return col_nnz_count;
// NOTE(review): the next line follows the return and uses `results` and
// `out_host_vector`, which are not declared in this function. It looks like
// residue from the diff view (the tail of another function) — confirm against
// the actual file.
raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream);
}

// NOTE(review): this span looks like a diff view with the +/- markers stripped.
// Two signatures (`remove_dupes` and `calc_tfidf_bm25`) share one template line
// and one brace pair, and their bodies follow each other inside it. Presumably
// `remove_dupes` is the deleted side and `calc_tfidf_bm25` (CSR input) the added
// side — confirm against the repository before editing any of this.
template <typename T1, typename T2>
void remove_dupes(raft::resources& handle,
raft::device_vector_view<T1> rows,
raft::device_vector_view<T1> columns,
raft::device_vector_view<T2> values,
raft::device_vector_view<T1> mask,
const raft::device_vector_view<T1>& out_rows,
const raft::device_vector_view<T1>& out_cols,
const raft::device_vector_view<T2>& out_vals,
int num_rows = 128)
// Second signature: test helper taking a device CSR matrix view plus an output
// vector; `tf_idf` is forwarded unchanged to preproc at the end of the body.
void calc_tfidf_bm25(raft::resources& handle,
raft::device_csr_matrix_view<T2, T1, T1, T1> csr_in,
raft::device_vector_view<T2> results,
bool tf_idf = false)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);

// ---- The statements below use `rows`, `columns`, `values`, `mask` and `out_*`,
// ---- i.e. the parameters of remove_dupes.
// One counter per entry, all set to 1 (the literal is 1.0f although the element type is T1).
auto col_counts = raft::make_device_vector<T1, int64_t>(handle, columns.size());

thrust::fill(raft::resource::get_thrust_policy(handle),
col_counts.data_handle(),
col_counts.data_handle() + col_counts.size(),
1.0f);

auto keys_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);
auto counts_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);

// Sum the per-entry ones keyed by row index, giving a per-key entry count.
// reduce_by_key merges only consecutive equal keys, so this assumes `rows` is
// already sorted — TODO confirm with the caller.
thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
rows.data_handle(),
rows.data_handle() + rows.size(),
col_counts.data_handle(),
keys_out.data_handle(),
counts_out.data_handle());

// Cast the mask to the value type so it can be multiplied with `values`.
auto mask_out = raft::make_device_vector<T2, int64_t>(handle, rows.size());

raft::linalg::map(handle, mask_out.view(), raft::cast_op<T2>{}, raft::make_const_mdspan(mask));

// values_c = values * mask (element-wise); entries whose mask is 0 become 0.
auto values_c = raft::make_device_vector<T2, int64_t>(handle, values.size());
raft::linalg::map(handle,
values_c.view(),
raft::mul_op{},
raft::make_const_mdspan(values),
raft::make_const_mdspan(mask_out.view()));

auto keys_nnz_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);
auto counts_nnz_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);

// Sum the mask keyed by row index, giving the per-key sum of mask entries.
thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
rows.data_handle(),
rows.data_handle() + rows.size(),
mask.data_handle(),
keys_nnz_out.data_handle(),
counts_nnz_out.data_handle());

// Remove entries equal to the scalar 0 from the masked triplets, writing into out_*.
// NOTE(review): confirm which of counts_nnz_out / counts_out coo_remove_scalar
// treats as the "before" and "after" per-row counts.
raft::sparse::op::coo_remove_scalar<T2>(rows.data_handle(),
columns.data_handle(),
values_c.data_handle(),
values_c.size(),
out_rows.data_handle(),
out_cols.data_handle(),
out_vals.data_handle(),
counts_nnz_out.data_handle(),
counts_out.data_handle(),
0,
num_rows,
stream);
// ---- The statements below use `csr_in`, `results` and `tf_idf`,
// ---- i.e. the parameters of calc_tfidf_bm25.
// Matrix dimensions and the lengths of the three CSR arrays, all narrowed to int.
int num_rows = csr_in.structure_view().get_n_rows();
int num_cols = csr_in.structure_view().get_n_cols();
int rows_size = csr_in.structure_view().get_indptr().size();
int cols_size = csr_in.structure_view().get_indices().size();
int elements_size = csr_in.get_elements().size();

// Re-wrap the raw CSR pointers as device vector views (no copy is made here).
auto indptr = raft::make_device_vector_view<T1, int64_t>(
csr_in.structure_view().get_indptr().data(), rows_size);
auto indices = raft::make_device_vector_view<T1, int64_t>(
csr_in.structure_view().get_indices().data(), cols_size);
auto values =
raft::make_device_vector_view<T2, int64_t>(csr_in.get_elements().data(), elements_size);
// Buffer for the dense matrix.
// NOTE(review): num_rows * num_cols is evaluated in int before being widened,
// so it can overflow for large matrices.
auto dense_values = raft::make_device_vector<T2, int64_t>(handle, num_rows * num_cols);

// NOTE(review): this cuSPARSE handle is created on every call and no matching
// cusparseDestroy appears anywhere in this span, so it looks leaked. Consider
// destroying it after the conversion, or taking the handle from `handle`.
cusparseHandle_t cu_handle;
RAFT_CUSPARSE_TRY(cusparseCreate(&cu_handle));

// Expand the CSR matrix into dense_values; the trailing `true` is the row-major flag.
// NOTE(review): num_rows is passed in the leading-dimension slot; confirm how
// csr_to_dense interprets that argument when row-major output is requested.
raft::sparse::convert::csr_to_dense(cu_handle,
num_rows,
num_cols,
elements_size,
indptr.data_handle(),
indices.data_handle(),
values.data_handle(),
num_rows,
dense_values.data_handle(),
stream,
true);

// Wait for the conversion to finish, then compute the reference values into `results`.
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
preproc<T1, T2>(handle, dense_values.view(), results, num_rows, num_cols, tf_idf);
}

template <typename T1, typename T2>
Expand Down
110 changes: 49 additions & 61 deletions cpp/test/sparse/preprocess_coo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/matrix/preprocessing.cuh>
#include <raft/sparse/selection/knn.cuh>
#include <raft/util/cudart_utils.hpp>

#include <gtest/gtest.h>
Expand All @@ -33,37 +32,6 @@
namespace raft {
namespace sparse {

/**
 * @brief Test helper: copies a device COO matrix to the host and forwards it to
 *        raft::util::preproc_coo, which writes the reference values to `results`.
 *
 * The copies are enqueued on the stream owned by `handle` and are not
 * synchronized here; synchronization is left to preproc_coo.
 *
 * @tparam T1 index type of the COO matrix (rows / columns)
 * @tparam T2 element type of the COO matrix and of `results`
 * @param handle  raft resources providing the CUDA stream
 * @param coo_in  device COO matrix view to read
 * @param results device vector passed through to preproc_coo as its output
 * @param tf_idf  forwarded unchanged to preproc_coo (presumably selects TF-IDF
 *                instead of BM25 — confirm against preproc_coo)
 */
template <typename T1, typename T2>
void calc_tfidf_bm25(raft::resources& handle,
                     raft::device_coo_matrix_view<T2, T1, T1, T1> coo_in,
                     raft::device_vector_view<T2> results,
                     bool tf_idf = false)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(handle);

  // Fetch each device-side view once and reuse it for both sizing and copying.
  auto structure = coo_in.structure_view();
  auto d_rows    = structure.get_rows();
  auto d_cols    = structure.get_cols();
  auto d_elems   = coo_in.get_elements();

  int num_rows = structure.get_n_rows();
  int num_cols = structure.get_n_cols();

  // Each size is taken from the array it describes (previously the row and
  // column sizes were read from each other's array).
  int rows_size     = d_rows.size();
  int cols_size     = d_cols.size();
  int elements_size = d_elems.size();

  auto h_rows  = raft::make_host_vector<T1, int64_t>(handle, rows_size);
  auto h_cols  = raft::make_host_vector<T1, int64_t>(handle, cols_size);
  auto h_elems = raft::make_host_vector<T2, int64_t>(handle, elements_size);

  // Device -> host copies of the three COO arrays.
  raft::copy(h_rows.data_handle(), d_rows.data(), d_rows.size(), stream);
  raft::copy(h_cols.data_handle(), d_cols.data(), d_cols.size(), stream);
  raft::copy(h_elems.data_handle(), d_elems.data(), d_elems.size(), stream);

  raft::util::preproc_coo<T1, T2>(
    handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf);
}

template <typename Type_f, typename Index_>
struct SparsePreprocessInputs {
int n_rows;
Expand Down Expand Up @@ -94,51 +62,71 @@ class SparsePreprocessCoo
auto rows = raft::make_device_vector<Index_, int64_t>(handle, params.nnz_edges);
auto columns = raft::make_device_vector<Index_, int64_t>(handle, params.nnz_edges);
auto values = raft::make_device_vector<Type_f, int64_t>(handle, params.nnz_edges);
auto mask = raft::make_device_vector<Index_, int64_t>(handle, params.nnz_edges);

rmm::device_uvector<Index_> rows_uvec(rows.size(), stream);
rmm::device_uvector<Index_> cols_uvec(rows.size(), stream);
rmm::device_uvector<Type_f> vals_uvec(rows.size(), stream);

raft::util::create_dataset<Index_, Type_f>(
handle, rows.view(), columns.view(), values.view(), 5, params.n_rows, params.n_cols);
int non_dupe_nnz_count = raft::util::get_dupe_mask_count<Index_, Type_f>(
handle, rows.view(), columns.view(), values.view(), mask.view());

auto rows_nnz = raft::make_device_vector<Index_, int64_t>(handle, non_dupe_nnz_count);
auto columns_nnz = raft::make_device_vector<Index_, int64_t>(handle, non_dupe_nnz_count);
auto values_nnz = raft::make_device_vector<Type_f, int64_t>(handle, non_dupe_nnz_count);
raft::util::remove_dupes<Index_, Type_f>(handle,
rows.view(),
columns.view(),
values.view(),
mask.view(),
rows_nnz.view(),
columns_nnz.view(),
values_nnz.view(),
num_rows);

auto coo_struct_view = raft::make_device_coordinate_structure_view(rows_nnz.data_handle(),
columns_nnz.data_handle(),
num_rows,
num_cols,
int(values_nnz.size()));

raft::sparse::op::coo_sort(int(rows.size()),
int(columns.size()),
int(values.size()),
rows.data_handle(),
columns.data_handle(),
values.data_handle(),
stream);

raft::copy(rows_uvec.data(), rows.data_handle(), rows.size(), stream);
raft::copy(cols_uvec.data(), columns.data_handle(), columns.size(), stream);
raft::copy(vals_uvec.data(), values.data_handle(), values.size(), stream);

raft::sparse::COO<Type_f, Index_> coo(stream);
raft::sparse::op::max_duplicates(handle,
coo,
rows_uvec.data(),
cols_uvec.data(),
vals_uvec.data(),
params.nnz_edges,
num_rows,
num_cols);

auto rows_csr = raft::make_device_vector<Index_, int64_t>(handle, num_rows + 1);

raft::sparse::convert::sorted_coo_to_csr(
coo.rows(), coo.nnz, rows_csr.data_handle(), num_rows + 1, stream);

auto csr_struct_view = raft::make_device_compressed_structure_view(
rows_csr.data_handle(), coo.cols(), num_rows, num_cols, coo.nnz);

auto csr_matrix =
raft::make_device_csr_matrix<Type_f, Index_, Index_, Index_>(handle, csr_struct_view);
raft::update_device<Type_f>(
csr_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream);

auto coo_struct_view = raft::make_device_coordinate_structure_view(
coo.rows(), coo.cols(), num_rows, num_cols, int(coo.nnz));
auto c_matrix =
raft::make_device_coo_matrix<Type_f, Index_, Index_, Index_>(handle, coo_struct_view);
raft::update_device<Type_f>(
c_matrix.view().get_elements().data(), values_nnz.data_handle(), values_nnz.size(), stream);
raft::update_device<Type_f>(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream);

auto result = raft::make_device_vector<Type_f, int64_t>(handle, values_nnz.size());
auto bm25_vals = raft::make_device_vector<Type_f, int64_t>(handle, values_nnz.size());
auto tfidf_vals = raft::make_device_vector<Type_f, int64_t>(handle, values_nnz.size());
auto result = raft::make_device_vector<Type_f, int64_t>(handle, coo.nnz);

if (bm25_on) {
auto bm25_vals = raft::make_device_vector<Type_f, int64_t>(handle, coo.nnz);
sparse::matrix::encode_bm25<Index_, Type_f>(handle, c_matrix.view(), result.view());
calc_tfidf_bm25<Index_, Type_f>(handle, c_matrix.view(), bm25_vals.view());
raft::util::calc_tfidf_bm25<Index_, Type_f>(handle, csr_matrix.view(), bm25_vals.view());
ASSERT_TRUE(raft::devArrMatch<Type_f>(bm25_vals.data_handle(),
result.data_handle(),
result.size(),
raft::CompareApprox<Type_f>(2e-5),
stream));
} else {
auto tfidf_vals = raft::make_device_vector<Type_f, int64_t>(handle, coo.nnz);
sparse::matrix::encode_tfidf<Index_, Type_f>(handle, c_matrix.view(), result.view());
calc_tfidf_bm25<Index_, Type_f>(handle, c_matrix.view(), tfidf_vals.view(), true);
raft::util::calc_tfidf_bm25<Index_, Type_f>(
handle, csr_matrix.view(), tfidf_vals.view(), true);
ASSERT_TRUE(raft::devArrMatch<Type_f>(tfidf_vals.data_handle(),
result.data_handle(),
result.size(),
Expand Down
Loading

0 comments on commit 80b527e

Please sign in to comment.