-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat] Support bitset_to_csr
#2523
Open
rhdong
wants to merge
6
commits into
rapidsai:branch-25.02
Choose a base branch
from
rhdong:rhdong/bitset-to-csr
base: branch-25.02
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
06f6f29
[Feat] Support `bitset_to_csr`
rhdong 5e259e9
Merge branch 'branch-25.02' into rhdong/bitset-to-csr
rhdong ab6d71e
`masked_matmul` supports bitset mask
rhdong 162a741
Merge branch 'branch-25.02' into rhdong/bitset-to-csr
cjnolet df9faf5
Merge branch 'branch-25.02' into rhdong/bitset-to-csr
rhdong 94d90a7
optimization for review comments
rhdong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <common/benchmark.hpp> | ||
|
||
#include <raft/core/device_resources.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/core/resources.hpp> | ||
#include <raft/sparse/convert/csr.cuh> | ||
#include <raft/util/itertools.hpp> | ||
|
||
#include <rmm/device_uvector.hpp> | ||
|
||
#include <sstream> | ||
#include <vector> | ||
|
||
namespace raft::bench::sparse { | ||
|
||
template <typename index_t> | ||
struct bench_param { | ||
index_t n_repeat; | ||
index_t n_cols; | ||
float sparsity; | ||
}; | ||
|
||
template <typename index_t> | ||
inline auto operator<<(std::ostream& os, const bench_param<index_t>& params) -> std::ostream& | ||
{ | ||
os << " rows*cols=" << params.n_repeat << "*" << params.n_cols | ||
<< "\tsparsity=" << params.sparsity; | ||
return os; | ||
} | ||
|
||
template <typename bitset_t, typename index_t, typename value_t = float> | ||
struct BitsetToCsrBench : public fixture { | ||
BitsetToCsrBench(const bench_param<index_t>& p) | ||
: fixture(true), | ||
params(p), | ||
handle(stream), | ||
bitset_d(0, stream), | ||
nnz(0), | ||
indptr_d(0, stream), | ||
indices_d(0, stream), | ||
values_d(0, stream) | ||
{ | ||
index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8)); | ||
std::vector<bitset_t> bitset_h(element); | ||
nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h); | ||
|
||
bitset_d.resize(bitset_h.size(), stream); | ||
indptr_d.resize(params.n_repeat + 1, stream); | ||
indices_d.resize(nnz, stream); | ||
values_d.resize(nnz, stream); | ||
|
||
update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); | ||
|
||
resource::sync_stream(handle); | ||
} | ||
|
||
index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitset_t>& bitset) | ||
{ | ||
index_t total = static_cast<index_t>(m * n); | ||
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity)); | ||
index_t res = num_ones; | ||
|
||
for (auto& item : bitset) { | ||
item = static_cast<bitset_t>(0); | ||
} | ||
|
||
std::random_device rd; | ||
std::mt19937 gen(rd()); | ||
std::uniform_int_distribution<index_t> dis(0, total - 1); | ||
|
||
while (num_ones > 0) { | ||
index_t index = dis(gen); | ||
|
||
bitset_t& element = bitset[index / (8 * sizeof(bitset_t))]; | ||
index_t bit_position = index % (8 * sizeof(bitset_t)); | ||
|
||
if (((element >> bit_position) & 1) == 0) { | ||
element |= (static_cast<index_t>(1) << bit_position); | ||
num_ones--; | ||
} | ||
} | ||
return res; | ||
} | ||
|
||
void run_benchmark(::benchmark::State& state) override | ||
{ | ||
std::ostringstream label_stream; | ||
label_stream << params; | ||
state.SetLabel(label_stream.str()); | ||
|
||
auto bitset = raft::core::bitset_view<bitset_t, index_t>(bitset_d.data(), 1 * params.n_cols); | ||
|
||
auto csr_view = raft::make_device_compressed_structure_view<index_t, index_t, index_t>( | ||
indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz); | ||
auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, csr_view); | ||
|
||
raft::sparse::convert::bitset_to_csr<bitset_t, index_t>(handle, bitset, csr); | ||
|
||
resource::sync_stream(handle); | ||
loop_on_state(state, [this, &bitset, &csr]() { | ||
raft::sparse::convert::bitset_to_csr<bitset_t, index_t>(handle, bitset, csr); | ||
}); | ||
} | ||
|
||
protected: | ||
const raft::device_resources handle; | ||
|
||
bench_param<index_t> params; | ||
|
||
rmm::device_uvector<bitset_t> bitset_d; | ||
rmm::device_uvector<index_t> indptr_d; | ||
rmm::device_uvector<index_t> indices_d; | ||
rmm::device_uvector<value_t> values_d; | ||
|
||
index_t nnz; | ||
}; // struct BitsetToCsrBench | ||
|
||
template <typename index_t> | ||
const std::vector<bench_param<index_t>> getInputs() | ||
{ | ||
std::vector<bench_param<index_t>> param_vec; | ||
struct TestParams { | ||
index_t m; | ||
index_t n; | ||
float sparsity; | ||
}; | ||
|
||
const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>( | ||
{index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.99f, 0.9f, 0.8f, 0.5f}); | ||
|
||
param_vec.reserve(params_group.size()); | ||
for (TestParams params : params_group) { | ||
param_vec.push_back(bench_param<index_t>({params.m, params.n, params.sparsity})); | ||
} | ||
return param_vec; | ||
} | ||
|
||
template <typename index_t = int64_t> | ||
const std::vector<bench_param<index_t>> getLargeInputs() | ||
{ | ||
std::vector<bench_param<index_t>> param_vec; | ||
struct TestParams { | ||
index_t m; | ||
index_t n; | ||
float sparsity; | ||
}; | ||
|
||
const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>( | ||
{index_t(1), index_t(100)}, {index_t(100 * 1000000)}, {0.95f, 0.99f}); | ||
|
||
param_vec.reserve(params_group.size()); | ||
for (TestParams params : params_group) { | ||
param_vec.push_back(bench_param<index_t>({params.m, params.n, params.sparsity})); | ||
} | ||
return param_vec; | ||
} | ||
|
||
RAFT_BENCH_REGISTER((BitsetToCsrBench<uint32_t, int, float>), "", getInputs<int>()); | ||
RAFT_BENCH_REGISTER((BitsetToCsrBench<uint64_t, int, double>), "", getInputs<int>()); | ||
|
||
RAFT_BENCH_REGISTER((BitsetToCsrBench<uint32_t, int64_t, float>), "", getLargeInputs<int64_t>()); | ||
|
||
} // namespace raft::bench::sparse |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want to support both bitmap and bitset inputs but it appears we're removing the bitmap support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
bits
is only a naming; the code is needed to be compatible with bitset and bitmap, so I need to changebitmap
tobits
, as the compatible control point is hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not easy to understand and it’s not obvious which one is selected when true or false (future eyes are going to be confused too) . Let’s create an enum for this that we can share across benchmarks and tests. It’ll make this more straightforward for future eyes too.