Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Add tfidf bm25 #2353

Open
wants to merge 133 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 122 commits
Commits
Show all changes
133 commits
Select commit Hold shift + click to select a range
a6677ca
update master references
ajschmidt8 Jul 14, 2020
ad2d7d7
REL DOC Updates for main branch switch
mike-wendt Jul 16, 2020
e3c9344
Merge pull request #272 from rapidsai/branch-21.06
ajschmidt8 Jun 10, 2021
3b0a6d2
Merge pull request #321 from rapidsai/branch-21.08
ajschmidt8 Sep 16, 2021
309ea1a
REL v21.08.00 release
GPUtester Apr 6, 2022
3740998
Merge pull request #612 from rapidsai/branch-22.04
raydouglass Apr 6, 2022
e987ec8
REL v22.04.00 release
GPUtester Apr 6, 2022
0b55c32
Add `conda` compilers (#702)
ajschmidt8 Jun 7, 2022
229b9f8
update changelog
raydouglass Jun 7, 2022
0eded98
Merge pull request #708 from rapidsai/branch-22.06
raydouglass Jun 7, 2022
3e5a625
FIX update-version.sh
raydouglass Jun 7, 2022
ad50a7f
Merge pull request #709 from rapidsai/branch-22.06
raydouglass Jun 7, 2022
ed2c529
REL v22.06.00 release
GPUtester Jun 7, 2022
aae5e34
Merge pull request #782 from rapidsai/branch-22.08
raydouglass Aug 17, 2022
87a7d16
REL v22.08.00 release
GPUtester Aug 17, 2022
1de93ba
Merge pull request #908 from rapidsai/branch-22.10
raydouglass Oct 12, 2022
31ae597
REL v22.10.00 release
GPUtester Oct 12, 2022
08abc72
[HOTFIX] Update cuda-python dependency to 11.7.1 (#963)
cjnolet Nov 4, 2022
c6e6ce8
Merge pull request #988 from rapidsai/branch-22.10
raydouglass Nov 4, 2022
f7d2335
REL v22.10.01 release
GPUtester Nov 4, 2022
c16fa56
Merge pull request #1063 from rapidsai/branch-22.12
raydouglass Dec 8, 2022
9a716b7
REL v22.12.00 release
GPUtester Dec 8, 2022
60936ba
Merge pull request #1101 from rapidsai/branch-22.12
raydouglass Dec 14, 2022
a655c9a
REL v22.12.01 release
GPUtester Dec 14, 2022
9a66f42
Merge pull request #1250 from rapidsai/branch-23.02
raydouglass Feb 9, 2023
69dce2d
REL v23.02.00 release
raydouglass Feb 9, 2023
1467154
Merge pull request #1405 from rapidsai/branch-23.04
raydouglass Apr 12, 2023
7d1057e
REL v23.04.00 release
raydouglass Apr 12, 2023
dc800d6
REL v23.04.01 release
raydouglass Apr 21, 2023
520e12c
REL Merge pull request #1486 from rapidsai/branch-23.04
raydouglass May 3, 2023
f626bf1
Merge pull request #1549 from rapidsai/branch-23.06
raydouglass Jun 7, 2023
c931b61
REL v23.06.00 release
raydouglass Jun 7, 2023
af1515d
Merge pull request #1589 from rapidsai/branch-23.06
raydouglass Jun 12, 2023
9147c90
REL v23.06.01 release
raydouglass Jun 12, 2023
59ae9d6
Merge pull request #1636 from rapidsai/branch-23.06
raydouglass Jul 5, 2023
7dd2f6d
REL v23.06.02 release
raydouglass Jul 5, 2023
5797ef5
Merge pull request #1692 from rapidsai/branch-23.08
raydouglass Aug 9, 2023
e588d7b
REL v23.08.00 release
raydouglass Aug 9, 2023
51f52c1
Merge pull request #1863 from rapidsai/branch-23.10
raydouglass Oct 11, 2023
afdddfb
REL v23.10.00 release
raydouglass Oct 11, 2023
e9f9aa8
Merge pull request #2020 from rapidsai/branch-23.12
raydouglass Dec 6, 2023
599651e
REL v23.12.00 release
raydouglass Dec 6, 2023
9e2d627
REL Revert update-version.sh changes for release
raydouglass Dec 6, 2023
1143113
Merge pull request #2134 from rapidsai/branch-24.02
raydouglass Feb 12, 2024
698d6c7
REL v24.02.00 release
raydouglass Feb 12, 2024
e0d40e5
Merge pull request #2240 from rapidsai/branch-24.04
raydouglass Apr 10, 2024
fa44bcc
REL v24.04.00 release
raydouglass Apr 10, 2024
41938c4
Merge pull request #2341 from rapidsai/branch-24.06
raydouglass Jun 5, 2024
63a506d
REL v24.06.00 release
raydouglass Jun 5, 2024
427ea26
add in support for preprocessing with bm25 and tfidf
jperez999 Jun 5, 2024
ffbfbc7
add in test cases and header file
jperez999 Jun 6, 2024
2d82aca
add tfidf coo support
jperez999 Jun 25, 2024
dc01bc1
add in header for coo tfidf
jperez999 Jun 25, 2024
6f4745d
add bm25 test support coo in and refactor tfidf support
jperez999 Jun 26, 2024
987ff5e
add in long test for coo to csr convert test
jperez999 Jun 28, 2024
c46008c
remove unneeded print statement
jperez999 Jun 28, 2024
81bb89d
remove unneeded test
jperez999 Jun 28, 2024
ff1991f
add csr and coo matrix bfknn apis
jperez999 Jul 3, 2024
c593f4e
add knn to preprocess tests
jperez999 Jul 3, 2024
0febb55
all tests in place and refactor code
jperez999 Jul 4, 2024
6477cd4
add in cmake for test files
jperez999 Jul 4, 2024
c836ba8
adjust tests, coo now passes all checks
jperez999 Jul 4, 2024
ce8253e
csr and coo tests passing, refactor feature preprocessing
jperez999 Jul 6, 2024
442cd7a
refactor names to make more generic
jperez999 Jul 7, 2024
b1720c7
further refactor to feature and id variable names
jperez999 Jul 7, 2024
3365ec3
add documentation and refactor to use num rows and num cols from matrix
jperez999 Jul 8, 2024
06b6df2
update tests to reflect values given refactor
jperez999 Jul 8, 2024
034d2c5
add documentation
jperez999 Jul 8, 2024
04bb007
removed unnecessary imports and variables
jperez999 Jul 8, 2024
3747291
fix function docs to reflect behavior more correctly
jperez999 Jul 9, 2024
281a029
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 Jul 10, 2024
3d66d4b
Update docs/source/contributing.md
jperez999 Jul 10, 2024
2b70436
Update .github/PULL_REQUEST_TEMPLATE.md
jperez999 Jul 10, 2024
84ffc8b
Update .github/PULL_REQUEST_TEMPLATE.md
jperez999 Jul 10, 2024
63607bd
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 Jul 12, 2024
0f462a9
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 Jul 31, 2024
1fc27f3
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Jul 31, 2024
82cfb1f
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Aug 7, 2024
30d0352
Merge pull request #2399 from rapidsai/branch-24.08
raydouglass Aug 7, 2024
dd404d7
REL v24.08.00 release
raydouglass Aug 7, 2024
185da16
REL v24.08.01 release
raydouglass Aug 7, 2024
1155609
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Aug 14, 2024
6302957
Merge branch 'branch-24.10' into add-tfidf-bm25
cjnolet Aug 29, 2024
05f4af2
fix preprocessing and make tests run on r random at generation
jperez999 Sep 11, 2024
a1e3a48
remove unnecessary imports
jperez999 Sep 11, 2024
44f3e1c
remove log for tf
jperez999 Sep 11, 2024
e25e2de
added more template changes
jperez999 Sep 11, 2024
187e148
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Sep 11, 2024
ec4e4a2
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Sep 18, 2024
e6d2c1c
remove excess thrust calls
jperez999 Sep 24, 2024
5120c97
add better comment on inputs for tests
jperez999 Sep 24, 2024
81e2a41
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999 Sep 24, 2024
90373ab
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 Sep 24, 2024
87a729c
fixed scale errors
jperez999 Sep 26, 2024
63576b0
remove vector based public apis
jperez999 Sep 26, 2024
c123acb
add in bfknn tests for csr and coo sparse matrices
jperez999 Sep 27, 2024
29f14d9
Merge branch 'branch-24.12' into add-tfidf-bm25
rhdong Sep 27, 2024
226c82e
Merge pull request #2460 from rapidsai/branch-24.10
raydouglass Oct 9, 2024
397042a
REL v24.10.00 release
raydouglass Oct 9, 2024
0ca6e10
Merge branch 'branch-24.12' into add-tfidf-bm25
jperez999 Oct 16, 2024
b000065
remove unused functions
jperez999 Oct 17, 2024
3507771
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999 Oct 17, 2024
766ff24
Merge branch 'branch-24.12' into add-tfidf-bm25
jperez999 Oct 22, 2024
8e172af
Merge branch 'branch-24.12' into add-tfidf-bm25
benfred Oct 25, 2024
123f3c8
Merge branch 'branch-24.12' into add-tfidf-bm25
jperez999 Oct 27, 2024
81b074b
add comment for thrust call replacement and merge main
jperez999 Oct 30, 2024
b6b5cd2
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999 Oct 30, 2024
85d97d0
getting rid of changes to pull request template, not part of this PR
jperez999 Oct 30, 2024
c59bdf9
revert contributing md changes
jperez999 Oct 30, 2024
c871023
remove change to pre-commit-config.yaml
jperez999 Oct 30, 2024
04041c2
remove all changes to conda env files
jperez999 Oct 30, 2024
b022e6e
revert changes to python pyproject files
jperez999 Oct 30, 2024
52dd0d9
remove extra comment symbol in file
jperez999 Oct 30, 2024
8b53c8d
complete reversion of file to main
jperez999 Oct 30, 2024
d0e8750
revert dependencies file from merge
jperez999 Oct 30, 2024
f9c9a0b
file revert
jperez999 Oct 30, 2024
9db7cd9
revert contributing md
jperez999 Oct 30, 2024
a70619e
revert contributing md
jperez999 Oct 30, 2024
2b0202a
add in review comments
jperez999 Oct 31, 2024
c2473a9
change return to float
jperez999 Nov 19, 2024
eb733ae
Merge branch 'branch-24.12' into add-tfidf-bm25
jperez999 Nov 26, 2024
15f53c3
adding in comment fixes, failing big csr tests still
jperez999 Dec 4, 2024
80b527e
fix all tests csr and coo
jperez999 Dec 20, 2024
1c15944
Merge branch 'branch-25.02' into add-tfidf-bm25
jperez999 Jan 15, 2025
c14307d
overhaul sparse encoding
jperez999 Jan 23, 2025
7442b10
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999 Jan 23, 2025
5e90c5f
migrate class out of detail
jperez999 Jan 23, 2025
d1365ec
add basic hash function to ensure stability no matter size
jperez999 Jan 23, 2025
7eb6f76
add documentation for details functions.
jperez999 Jan 24, 2025
47c8288
full working sparse encoder, batchable
jperez999 Jan 26, 2025
a013006
merge in 25.02
jperez999 Jan 27, 2025
70f5666
remove build sh from template folder
jperez999 Jan 27, 2025
2ff4c8d
add comments spareencoder class
jperez999 Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
539 changes: 539 additions & 0 deletions cpp/include/raft/sparse/matrix/detail/preprocessing.cuh

Large diffs are not rendered by default.

107 changes: 107 additions & 0 deletions cpp/include/raft/sparse/matrix/preprocessing.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/matrix/detail/preprocessing.cuh>

#include <optional>

namespace raft::sparse::matrix {

/**
* @brief Use BM25 algorithm to encode features in COO sparse matrix
* @tparam IndexType is the type of the edges index in the coo matrix
* @tparam ValueType is the type of the values array in the coo matrix
* @tparam IdxT is the type of the indices of arrays in matrix
* @param handle: raft resource handle
* @param coo_in: Input COO matrix
* @param values_out: Output values array
* @param k_param: K value to use for BM25 algorithm
* @param b_param: B value to use for BM25 algorithm
*/
template <typename IndexType, typename ValueType, typename IdxT>
void encode_bm25(raft::resources& handle,
raft::device_coo_matrix_view<ValueType, IndexType, IndexType, IndexType> coo_in,
raft::device_vector_view<ValueType, IdxT> values_out,
float k_param = 1.6f,
float b_param = 0.75)
{
return matrix::detail::encode_bm25<IndexType, ValueType, IdxT>(
handle, coo_in, values_out, k_param, b_param);
}

/**
* @brief Use BM25 algorithm to encode features in CSR sparse matrix
* @param handle: raft resource handle
* @tparam IndexType is the type of the edges index in the csr matrix
* @tparam ValueType is the type of the values array in the csr matrix
* @tparam IdxT is the type of the indices of arrays in matrix
* @param csr_in: Input CSR matrix
* @param values_out: Output values array
* @param k_param: K value to use for BM25 algorithm
* @param b_param: B value to use for BM25 algorithm
*/
template <typename IndexType, typename ValueType, typename IdxT>
void encode_bm25(raft::resources& handle,
raft::device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType> csr_in,
raft::device_vector_view<ValueType, IdxT> values_out,
float k_param = 1.6f,
float b_param = 0.75)
{
return matrix::detail::encode_bm25<IndexType, ValueType, IdxT>(
handle, csr_in, values_out, k_param, b_param);
}

/**
* @brief Use TFIDF algorithm to encode features in COO sparse matrix
* @tparam IndexType is the type of the edges index in the coo matrix
* @tparam ValueType is the type of the values array in the coo matrix
* @tparam IdxT is the type of the indices of arrays in matrix
* @param handle: raft resource handle
* @param coo_in: Input COO matrix
* @param values_out: Output COO values array
*/
template <typename IndexType, typename ValueType, typename IdxT>
void encode_tfidf(raft::resources& handle,
raft::device_coo_matrix_view<ValueType, IndexType, IndexType, IndexType> coo_in,
raft::device_vector_view<ValueType, IdxT> values_out)
{
return matrix::detail::encode_tfidf<IndexType, ValueType, IdxT>(handle, coo_in, values_out);
}

/**
* @brief Use TFIDF algorithm to encode features in CSR sparse matrix
* @tparam IndexType is the type of the edges index in the csr matrix
* @tparam ValueType is the type of the values array in the csr matrix
* @tparam IdxT is the type of the indices of arrays in matrix
* @param handle: raft resource handle
* @param csr_in: Input CSR matrix
* @param values_out: Output values array
*/
template <typename IndexType, typename ValueType, typename IdxT>
void encode_tfidf(raft::resources& handle,
raft::device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType> csr_in,
raft::device_vector_view<ValueType, IdxT> values_out)
{
return matrix::detail::encode_tfidf<IndexType, ValueType, IdxT>(handle, csr_in, values_out);
}

} // namespace raft::sparse::matrix
4 changes: 3 additions & 1 deletion cpp/include/raft/sparse/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,8 @@ namespace raft::sparse::neighbors::brute_force {
/**
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors
* using some distance implementation
* template parameter value_idx is the type of the Indptr and Indices arrays.
* template parameter value_t is the type of the Data array.
* @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1)
* @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz)
* @param[in] idxData csr data array of the index matrix (size idxNNZ)
Expand Down
184 changes: 183 additions & 1 deletion cpp/include/raft/sparse/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@
" Please use the sparse/spatial version instead.")
#endif

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/neighbors/brute_force.cuh>
#include <raft/sparse/op/sort.cuh>

namespace raft::sparse::neighbors {

Expand Down Expand Up @@ -59,7 +62,7 @@ namespace raft::sparse::neighbors {
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float, int TPB_X = 32>
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(const value_idx* idxIndptr,
const value_idx* idxIndices,
const value_t* idxData,
Expand Down Expand Up @@ -103,4 +106,183 @@ void brute_force_knn(const value_idx* idxIndptr,
metricArg);
}

/**
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors
* using some distance implementation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add the comments for the template parameters.

* @tparam value_idx is the type of the edges index in the csr matrix
* @tparam value_t is the type of the values array in the csr matrix
* @param[in] csr_idx index csr matrix
* @param[in] csr_query query csr matrix
* @param[out] output_indices dense matrix for output indices (size n_query_rows * k)
* @param[out] output_dists dense matrix for output distances (size n_query_rows * k)
* @param[in] k the number of neighbors to query
* @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to
* @param[in] batch_size_index maximum number of rows to use from index matrix per batch
* @param[in] batch_size_query maximum number of rows to use from query matrix per batch
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(raft::device_csr_matrix<value_t,
value_idx,
value_idx,
value_idx,
raft::device_uvector_policy,
raft::PRESERVING> csr_idx,
raft::device_csr_matrix<value_t,
value_idx,
value_idx,
value_idx,
raft::device_uvector_policy,
raft::PRESERVING> csr_query,
device_vector_view<value_idx> output_indices,
device_vector_view<value_t> output_dists,
int k,
raft::resources const& handle,
size_t batch_size_index = 2 << 14, // approx 1M
size_t batch_size_query = 2 << 14,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
float metricArg = 0)
{
auto idxIndptr = csr_idx.structure_view().get_indptr();
auto idxIndices = csr_idx.structure_view().get_indices();
auto idxData = csr_idx.view().get_elements();

RAFT_EXPECTS(idxData.size() > 0, "No Values were detected in the Index CSR Matrix.");

auto queryIndptr = csr_query.structure_view().get_indptr();
auto queryIndices = csr_query.structure_view().get_indices();
auto queryData = csr_query.view().get_elements();

RAFT_EXPECTS(queryData.size() > 0, "No Values were detected in the Query CSR Matrix.");

brute_force::knn<value_idx, value_t>(idxIndptr.data(),
idxIndices.data(),
idxData.data(),
idxIndices.size(),
idxIndptr.size() - 1,
csr_idx.structure_view().get_n_cols(),
queryIndptr.data(),
queryIndices.data(),
queryData.data(),
queryIndices.size(),
queryIndptr.size() - 1,
csr_query.structure_view().get_n_cols(),
output_indices.data_handle(),
output_dists.data_handle(),
k,
handle,
batch_size_index,
batch_size_query,
metric,
metricArg);
}

/**
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors
* using some distance implementation
* @tparam value_idx is the type of the edges index in the coo matrix
* @tparam value_t is the type of the values array in the coo matrix
* @param[in] coo_idx index coo matrix
* @param[in] coo_query query coo matrix
* @param[out] output_indices dense matrix for output indices (size n_query_rows * k)
* @param[out] output_dists dense matrix for output distances (size n_query_rows * k)
* @param[in] k the number of neighbors to query
* @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to
* @param[in] batch_size_index maximum number of rows to use from index matrix per batch
* @param[in] batch_size_query maximum number of rows to use from query matrix per batch
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(raft::device_coo_matrix<value_t,
value_idx,
value_idx,
value_idx,
raft::device_uvector_policy,
raft::PRESERVING> coo_idx,
raft::device_coo_matrix<value_t,
value_idx,
value_idx,
value_idx,
raft::device_uvector_policy,
raft::PRESERVING> coo_query,
device_vector_view<value_idx> output_indices,
device_vector_view<value_t> output_dists,
int k,
raft::resources const& handle,
size_t batch_size_index = 2 << 14, // approx 1M
size_t batch_size_query = 2 << 14,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
float metricArg = 0)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);

Copy link
Member

@rhdong rhdong Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add a judgment for 0 size data for idx and query, though it should happen rarely. (Considering the following code includes the logic of size() - 1)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rhdong do you think I should raise and error or just return before performing bfknn?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should depend on the logic: to return directly (keeping no change on the outputs), if it is normal to have zero-size input, or you could use RAFT_EXPECTS to notify the caller.

auto idxRows = coo_idx.structure_view().get_rows();
auto idxCols = coo_idx.structure_view().get_cols();
auto idxData = coo_idx.view().get_elements();

RAFT_EXPECTS(idxData.size() > 0, "No Values were detected in the Index COO Matrix.");

auto queryRows = coo_query.structure_view().get_rows();
auto queryCols = coo_query.structure_view().get_cols();
auto queryData = coo_query.view().get_elements();

RAFT_EXPECTS(queryData.size() > 0, "No Values were detected in the Query COO Matrix.");

raft::sparse::op::coo_sort(int(idxRows.size()),
int(idxCols.size()),
int(idxData.size()),
idxRows.data(),
idxCols.data(),
idxRows.data(),
stream);

raft::sparse::op::coo_sort(int(queryRows.size()),
int(queryCols.size()),
int(queryData.size()),
queryRows.data(),
queryCols.data(),
queryData.data(),
stream);
// + 1 is to account for the 0 at the beginning of the csr representation
auto idxRowsCsr = raft::make_device_vector<value_idx, int64_t>(
handle, coo_query.structure_view().get_n_rows() + 1);
auto queryRowsCsr = raft::make_device_vector<value_idx, int64_t>(
handle, coo_query.structure_view().get_n_rows() + 1);

raft::sparse::convert::sorted_coo_to_csr(idxRows.data(),
int(idxRows.size()),
idxRowsCsr.data_handle(),
coo_idx.structure_view().get_n_rows() + 1,
stream);

raft::sparse::convert::sorted_coo_to_csr(queryRows.data(),
int(queryRows.size()),
queryRowsCsr.data_handle(),
coo_query.structure_view().get_n_rows() + 1,
stream);

brute_force::knn<value_idx, value_t>(idxRowsCsr.data_handle(),
idxCols.data(),
idxData.data(),
idxCols.size(),
idxRowsCsr.size() - 1,
coo_idx.structure_view().get_n_cols(),
queryRowsCsr.data_handle(),
queryCols.data(),
queryData.data(),
queryCols.size(),
queryRowsCsr.size() - 1,
coo_query.structure_view().get_n_cols(),
output_indices.data_handle(),
output_dists.data_handle(),
k,
handle,
batch_size_index,
batch_size_query,
metric,
metricArg);
}

}; // namespace raft::sparse::neighbors
41 changes: 41 additions & 0 deletions cpp/template/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash

# Copyright (c) 2023-2024, NVIDIA CORPORATION.

# raft empty project template build script

# Abort script on first error
set -e

PARALLEL_LEVEL=${PARALLEL_LEVEL:=`nproc`}

BUILD_TYPE=Release
BUILD_DIR=build/

RAFT_REPO_REL=""
EXTRA_CMAKE_ARGS=""
set -e


if [[ ${RAFT_REPO_REL} != "" ]]; then
RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`"
EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}"
fi

if [ "$1" == "clean" ]; then
rm -rf build
exit 0
fi

mkdir -p $BUILD_DIR
cd $BUILD_DIR

cmake \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DRAFT_NVTX=OFF \
-DCMAKE_CUDA_ARCHITECTURES="RAPIDS" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
${EXTRA_CMAKE_ARGS} \
../

cmake --build . -j${PARALLEL_LEVEL}
2 changes: 2 additions & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ if(BUILD_TESTS)
sparse/spgemmi.cu
sparse/spmm.cu
sparse/symmetrize.cu
sparse/preprocess_csr.cu
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only needed to add these two files but it seems like the formatter wanted me to make the other changes?

sparse/preprocess_coo.cu
)

ConfigureTest(
Expand Down
Loading
Loading