-
Notifications
You must be signed in to change notification settings - Fork 194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[REVIEW] Add tfidf bm25 #2353
base: branch-24.12
Are you sure you want to change the base?
[REVIEW] Add tfidf bm25 #2353
Changes from all commits
a6677ca
ad2d7d7
e3c9344
3b0a6d2
309ea1a
3740998
e987ec8
0b55c32
229b9f8
0eded98
3e5a625
ad50a7f
ed2c529
aae5e34
87a7d16
1de93ba
31ae597
08abc72
c6e6ce8
f7d2335
c16fa56
9a716b7
60936ba
a655c9a
9a66f42
69dce2d
1467154
7d1057e
dc800d6
520e12c
f626bf1
c931b61
af1515d
9147c90
59ae9d6
7dd2f6d
5797ef5
e588d7b
51f52c1
afdddfb
e9f9aa8
599651e
9e2d627
1143113
698d6c7
e0d40e5
fa44bcc
41938c4
63a506d
427ea26
ffbfbc7
2d82aca
dc01bc1
6f4745d
987ff5e
c46008c
81bb89d
ff1991f
c593f4e
0febb55
6477cd4
c836ba8
ce8253e
442cd7a
b1720c7
3365ec3
06b6df2
034d2c5
04bb007
3747291
281a029
3d66d4b
2b70436
84ffc8b
63607bd
0f462a9
1fc27f3
82cfb1f
30d0352
dd404d7
185da16
1155609
6302957
05f4af2
a1e3a48
44f3e1c
e25e2de
187e148
ec4e4a2
e6d2c1c
5120c97
81e2a41
90373ab
87a729c
63576b0
c123acb
29f14d9
226c82e
397042a
0ca6e10
b000065
3507771
766ff24
8e172af
123f3c8
81b074b
b6b5cd2
85d97d0
c59bdf9
c871023
04041c2
b022e6e
52dd0d9
8b53c8d
d0e8750
f9c9a0b
9db7cd9
a70619e
2b0202a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/device_csr_matrix.hpp> | ||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/core/resources.hpp> | ||
#include <raft/sparse/matrix/detail/preprocessing.cuh> | ||
|
||
#include <optional> | ||
|
||
namespace raft::sparse::matrix { | ||
|
||
/** | ||
* @brief Use BM25 algorithm to encode features in COO sparse matrix | ||
* @param handle: raft resource handle | ||
* @param coo_in: Input COO matrix | ||
* @param values_out: Output values array | ||
* @param k_param: K value to use for BM25 algorithm | ||
* @param b_param: B value to use for BM25 algorithm | ||
*/ | ||
template <typename T1, typename T2, typename IdxT> | ||
void encode_bm25(raft::resources& handle, | ||
raft::device_coo_matrix_view<T2, T1, T1, T1> coo_in, | ||
raft::device_vector_view<T2, IdxT> values_out, | ||
float k_param = 1.6f, | ||
float b_param = 0.75) | ||
{ | ||
return matrix::detail::encode_bm25<T1, T2, IdxT>(handle, coo_in, values_out, k_param, b_param); | ||
} | ||
|
||
/** | ||
* @brief Use BM25 algorithm to encode features in CSR sparse matrix | ||
* @param handle: raft resource handle | ||
* @param csr_in: Input CSR matrix | ||
* @param values_out: Output values array | ||
* @param k_param: K value to use for BM25 algorithm | ||
* @param b_param: B value to use for BM25 algorithm | ||
*/ | ||
template <typename T1, typename T2, typename IdxT> | ||
void encode_bm25(raft::resources& handle, | ||
raft::device_csr_matrix_view<T2, T1, T1, T1> csr_in, | ||
raft::device_vector_view<T2, IdxT> values_out, | ||
float k_param = 1.6f, | ||
float b_param = 0.75) | ||
{ | ||
return matrix::detail::encode_bm25<T1, T2, IdxT>(handle, csr_in, values_out, k_param, b_param); | ||
} | ||
|
||
/** | ||
* @brief Use TFIDF algorithm to encode features in COO sparse matrix | ||
* @param handle: raft resource handle | ||
* @param coo_in: Input COO matrix | ||
* @param values_out: Output COO values array | ||
*/ | ||
template <typename T1, typename T2, typename IdxT> | ||
void encode_tfidf(raft::resources& handle, | ||
raft::device_coo_matrix_view<T2, T1, T1, T1> coo_in, | ||
raft::device_vector_view<T2, IdxT> values_out) | ||
{ | ||
return matrix::detail::encode_tfidf<T1, T2, IdxT>(handle, coo_in, values_out); | ||
} | ||
|
||
/** | ||
* @brief Use TFIDF algorithm to encode features in CSR sparse matrix | ||
* @param handle: raft resource handle | ||
* @param csr_in: Input CSR matrix | ||
* @param values_out: Output values array | ||
*/ | ||
template <typename T1, typename T2, typename IdxT> | ||
void encode_tfidf(raft::resources& handle, | ||
raft::device_csr_matrix_view<T2, T1, T1, T1> csr_in, | ||
raft::device_vector_view<T2, IdxT> values_out) | ||
{ | ||
return matrix::detail::encode_tfidf<T1, T2, IdxT>(handle, csr_in, values_out); | ||
} | ||
|
||
} // namespace raft::sparse::matrix |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,8 +30,11 @@ | |
" Please use the sparse/spatial version instead.") | ||
#endif | ||
|
||
#include <raft/core/device_coo_matrix.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/sparse/convert/csr.cuh> | ||
#include <raft/sparse/neighbors/brute_force.cuh> | ||
#include <raft/sparse/op/sort.cuh> | ||
|
||
namespace raft::sparse::neighbors { | ||
|
||
|
@@ -59,7 +62,7 @@ namespace raft::sparse::neighbors { | |
* @param[in] metric distance metric/measure to use | ||
* @param[in] metricArg potential argument for metric (currently unused) | ||
*/ | ||
template <typename value_idx = int, typename value_t = float, int TPB_X = 32> | ||
template <typename value_idx = int, typename value_t = float> | ||
void brute_force_knn(const value_idx* idxIndptr, | ||
const value_idx* idxIndices, | ||
const value_t* idxData, | ||
|
@@ -103,4 +106,171 @@ void brute_force_knn(const value_idx* idxIndptr, | |
metricArg); | ||
} | ||
|
||
/** | ||
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors | ||
* using some distance implementation | ||
* @param[in] csr_idx index csr matrix | ||
* @param[in] csr_query query csr matrix | ||
* @param[out] output_indices dense matrix for output indices (size n_query_rows * k) | ||
* @param[out] output_dists dense matrix for output distances (size n_query_rows * k) | ||
* @param[in] k the number of neighbors to query | ||
* @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to | ||
* @param[in] batch_size_index maximum number of rows to use from index matrix per batch | ||
* @param[in] batch_size_query maximum number of rows to use from query matrix per batch | ||
* @param[in] metric distance metric/measure to use | ||
* @param[in] metricArg potential argument for metric (currently unused) | ||
*/ | ||
template <typename value_idx = int, typename value_t = float> | ||
void brute_force_knn(raft::device_csr_matrix<value_t, | ||
value_idx, | ||
value_idx, | ||
value_idx, | ||
raft::device_uvector_policy, | ||
raft::PRESERVING> csr_idx, | ||
raft::device_csr_matrix<value_t, | ||
value_idx, | ||
value_idx, | ||
value_idx, | ||
raft::device_uvector_policy, | ||
raft::PRESERVING> csr_query, | ||
device_vector_view<value_idx> output_indices, | ||
device_vector_view<value_t> output_dists, | ||
int k, | ||
raft::resources const& handle, | ||
size_t batch_size_index = 2 << 14, // approx 1M | ||
size_t batch_size_query = 2 << 14, | ||
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, | ||
float metricArg = 0) | ||
{ | ||
auto idxIndptr = csr_idx.structure_view().get_indptr(); | ||
auto idxIndices = csr_idx.structure_view().get_indices(); | ||
auto idxData = csr_idx.view().get_elements(); | ||
|
||
auto queryIndptr = csr_query.structure_view().get_indptr(); | ||
auto queryIndices = csr_query.structure_view().get_indices(); | ||
auto queryData = csr_query.view().get_elements(); | ||
|
||
brute_force::knn<value_idx, value_t>(idxIndptr.data(), | ||
idxIndices.data(), | ||
idxData.data(), | ||
idxIndices.size(), | ||
idxIndptr.size() - 1, | ||
csr_idx.structure_view().get_n_cols(), | ||
queryIndptr.data(), | ||
queryIndices.data(), | ||
queryData.data(), | ||
queryIndices.size(), | ||
queryIndptr.size() - 1, | ||
csr_query.structure_view().get_n_cols(), | ||
output_indices.data_handle(), | ||
output_dists.data_handle(), | ||
k, | ||
handle, | ||
batch_size_index, | ||
batch_size_query, | ||
metric, | ||
metricArg); | ||
} | ||
|
||
/** | ||
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors | ||
* using some distance implementation | ||
* @param[in] coo_idx index coo matrix | ||
* @param[in] coo_query query coo matrix | ||
* @param[out] output_indices dense matrix for output indices (size n_query_rows * k) | ||
* @param[out] output_dists dense matrix for output distances (size n_query_rows * k) | ||
* @param[in] k the number of neighbors to query | ||
* @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to | ||
* @param[in] batch_size_index maximum number of rows to use from index matrix per batch | ||
* @param[in] batch_size_query maximum number of rows to use from query matrix per batch | ||
* @param[in] metric distance metric/measure to use | ||
* @param[in] metricArg potential argument for metric (currently unused) | ||
*/ | ||
template <typename value_idx = int, typename value_t = float> | ||
void brute_force_knn(raft::device_coo_matrix<value_t, | ||
value_idx, | ||
value_idx, | ||
value_idx, | ||
raft::device_uvector_policy, | ||
raft::PRESERVING> coo_idx, | ||
raft::device_coo_matrix<value_t, | ||
value_idx, | ||
value_idx, | ||
value_idx, | ||
raft::device_uvector_policy, | ||
raft::PRESERVING> coo_query, | ||
device_vector_view<value_idx> output_indices, | ||
device_vector_view<value_t> output_dists, | ||
int k, | ||
raft::resources const& handle, | ||
size_t batch_size_index = 2 << 14, // approx 1M | ||
size_t batch_size_query = 2 << 14, | ||
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, | ||
float metricArg = 0) | ||
{ | ||
cudaStream_t stream = raft::resource::get_cuda_stream(handle); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could add a judgment for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rhdong do you think I should raise and error or just return before performing bfknn? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should depend on the logic: to return directly (keeping no change on the outputs), if it is normal to have zero-size input, or you could use |
||
auto idxRows = coo_idx.structure_view().get_rows(); | ||
auto idxCols = coo_idx.structure_view().get_cols(); | ||
auto idxData = coo_idx.view().get_elements(); | ||
|
||
auto queryRows = coo_query.structure_view().get_rows(); | ||
auto queryCols = coo_query.structure_view().get_cols(); | ||
auto queryData = coo_query.view().get_elements(); | ||
|
||
raft::sparse::op::coo_sort(int(idxRows.size()), | ||
int(idxCols.size()), | ||
int(idxData.size()), | ||
idxRows.data(), | ||
idxCols.data(), | ||
idxRows.data(), | ||
stream); | ||
|
||
raft::sparse::op::coo_sort(int(queryRows.size()), | ||
int(queryCols.size()), | ||
int(queryData.size()), | ||
queryRows.data(), | ||
queryCols.data(), | ||
queryData.data(), | ||
stream); | ||
// + 1 is to account for the 0 at the beginning of the csr representation | ||
auto idxRowsCsr = raft::make_device_vector<value_idx, int64_t>( | ||
handle, coo_query.structure_view().get_n_rows() + 1); | ||
auto queryRowsCsr = raft::make_device_vector<value_idx, int64_t>( | ||
handle, coo_query.structure_view().get_n_rows() + 1); | ||
|
||
raft::sparse::convert::sorted_coo_to_csr(idxRows.data(), | ||
int(idxRows.size()), | ||
idxRowsCsr.data_handle(), | ||
coo_idx.structure_view().get_n_rows() + 1, | ||
stream); | ||
|
||
raft::sparse::convert::sorted_coo_to_csr(queryRows.data(), | ||
int(queryRows.size()), | ||
queryRowsCsr.data_handle(), | ||
coo_query.structure_view().get_n_rows() + 1, | ||
stream); | ||
|
||
brute_force::knn<value_idx, value_t>(idxRowsCsr.data_handle(), | ||
idxCols.data(), | ||
idxData.data(), | ||
idxCols.size(), | ||
idxRowsCsr.size() - 1, | ||
coo_idx.structure_view().get_n_cols(), | ||
queryRowsCsr.data_handle(), | ||
queryCols.data(), | ||
queryData.data(), | ||
queryCols.size(), | ||
queryRowsCsr.size() - 1, | ||
coo_query.structure_view().get_n_cols(), | ||
output_indices.data_handle(), | ||
output_dists.data_handle(), | ||
k, | ||
handle, | ||
batch_size_index, | ||
batch_size_query, | ||
metric, | ||
metricArg); | ||
} | ||
|
||
}; // namespace raft::sparse::neighbors |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -319,6 +319,8 @@ if(BUILD_TESTS) | |
sparse/spgemmi.cu | ||
sparse/spmm.cu | ||
sparse/symmetrize.cu | ||
sparse/preprocess_csr.cu | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I only needed to add these two files but it seems like the formatter wanted me to make the other changes? |
||
sparse/preprocess_coo.cu | ||
) | ||
|
||
ConfigureTest( | ||
|
@@ -327,8 +329,16 @@ if(BUILD_TESTS) | |
) | ||
|
||
ConfigureTest( | ||
NAME SPARSE_NEIGHBORS_TEST PATH sparse/neighbors/cross_component_nn.cu | ||
sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu LIB EXPLICIT_INSTANTIATE_ONLY | ||
NAME | ||
SPARSE_NEIGHBORS_TEST | ||
PATH | ||
sparse/neighbors/cross_component_nn.cu | ||
sparse/neighbors/brute_force.cu | ||
sparse/neighbors/brute_force_coo.cu | ||
sparse/neighbors/brute_force_csr.cu | ||
sparse/neighbors/knn_graph.cu | ||
LIB | ||
EXPLICIT_INSTANTIATE_ONLY | ||
) | ||
|
||
ConfigureTest( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should add the comments for the template parameters.