-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat] Support bitset_to_csr
#2523
base: branch-25.02
Are you sure you want to change the base?
Conversation
RAFT_EXPECTS(c_true_nnz == c_indices_d.size(), | ||
"Something wrong. The c_true_nnz != c_indices_d.size()!"); | ||
|
||
update_device(c_data_d.data(), values.data(), c_true_nnz, stream); | ||
update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream); | ||
update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream); | ||
update_device(bitmap_d.data(), bitmap_h.data(), element, stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want to support both bitmap and bitset inputs but it appears we're removing the bitmap support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bits
is only a naming; the code is needed to be compatible with bitset and bitmap, so I need to change bitmap
to bits
, as the compatible control point is here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not easy to understand and it’s not obvious which one is selected when true or false (future eyes are going to be confused too) . Let’s create an enum for this that we can share across benchmarks and tests. It’ll make this more straightforward for future eyes too.
* @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt). | ||
* @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt). | ||
*/ | ||
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is including code that's running kernels. We shouldn't have this in an hpp file. That's also very misleading to users because they are going to import an hpp file thinking no kernels are being compiled. Please remove this hpp file and just have users include the cuh
file. I understand it looks like we've used this for the bitmap as well, but this was an oversight. hpp files should never include cuh files.
cpp/test/sparse/masked_matmul.cu
Outdated
@@ -87,7 +87,8 @@ bool isCuSparseVersionGreaterThan_12_0_1() | |||
template <typename value_t, | |||
typename output_t, | |||
typename index_t, | |||
typename bitmap_t = uint32_t, | |||
bool bitmap_or_bitset = true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this expecting true
when it's one or the other? This is a little confusing. Please make an enum for this so that it's more clear to future eyes who might need to update this code.
* @param[out] csr Output parameter where the resulting CSR matrix is stored. In the | ||
* bitset, each '1' bit corresponds to a non-zero element in the CSR matrix. | ||
*/ | ||
template <typename bitset_t, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please provide a brief usage example. All public API functions should have a nice copy/paste usage example for the docs. Please also make sure sure these functions (both bitset and bitmap) are being included in the docs.
This API,
bitset_to_csr,
will be utilized to implement the `bitset'- based filter for prefiltered Brute Force.