diff --git a/cpp/bench/prims/linalg/masked_matmul.cu b/cpp/bench/prims/linalg/masked_matmul.cu index eda9cb1710..b96e14a25d 100644 --- a/cpp/bench/prims/linalg/masked_matmul.cu +++ b/cpp/bench/prims/linalg/masked_matmul.cu @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include @@ -49,11 +49,14 @@ inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams& { os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n << "\tsparsity=" << params.sparsity; - if (params.sparsity == 1.0) { os << "<-inner product for comparison"; } + if (params.sparsity == 0.0) { os << "<-inner product for comparison"; } return os; } -template +template struct MaskedMatmulBench : public fixture { MaskedMatmulBench(const MaskedMatmulBenchParams& p) : fixture(true), @@ -64,15 +67,15 @@ struct MaskedMatmulBench : public fixture { c_indptr_d(0, stream), c_indices_d(0, stream), c_data_d(0, stream), - bitmap_d(0, stream), + bits_d(0, stream), c_dense_data_d(0, stream) { - index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8)); - std::vector bitmap_h(element); + index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bits_t) * 8)); + std::vector bits_h(element); a_data_d.resize(params.m * params.k, stream); b_data_d.resize(params.k * params.n, stream); - bitmap_d.resize(element, stream); + bits_d.resize(element, stream); raft::random::RngState rng(2024ULL); raft::random::uniform( @@ -82,7 +85,13 @@ struct MaskedMatmulBench : public fixture { std::vector c_dense_data_h(params.m * params.n); - c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h); + if constexpr (bitmap_or_bitset) { + c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h); + } else { + c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h); + repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1); + c_true_nnz *= params.m; + } std::vector values(c_true_nnz); std::vector indices(c_true_nnz); @@ -93,24 +102,49 @@ struct MaskedMatmulBench : public fixture { c_indices_d.resize(c_true_nnz, stream); c_dense_data_d.resize(params.m * params.n, stream); - cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr); + cpu_convert_to_csr(bits_h, params.m, params.n, indices, indptr); RAFT_EXPECTS(c_true_nnz == c_indices_d.size(), "Something wrong. The c_true_nnz != c_indices_d.size()!"); update_device(c_data_d.data(), values.data(), c_true_nnz, stream); update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream); update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream); - update_device(bitmap_d.data(), bitmap_h.data(), element, stream); + update_device(bits_d.data(), bits_h.data(), element, stream); + } + + void repeat_cpu_bitset_inplace(std::vector& inout, size_t input_bits, size_t repeat) + { + size_t output_bit_index = input_bits; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bits_t) * 8); + size_t input_bit_offset = i % (sizeof(bits_t) * 8); + bool bit = (inout[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bits_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bits_t) * 8); + + inout[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } } - index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bits) { index_t total = static_cast(m * n); - index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t num_ones = static_cast((total * 1.0f) * (1.0f - sparsity)); index_t res = num_ones; - for (auto& item : bitmap) { - item = static_cast(0); + if (sparsity == 0.0f) { + std::fill(bits.begin(), bits.end(), 0xffffffff); + return num_ones; + } + + for (auto& item : bits) { + item = static_cast(0); } std::random_device rd; @@ -120,8 +154,8 @@ struct MaskedMatmulBench : public fixture { while (num_ones > 0) { index_t index = dis(gen); - bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; - index_t bit_position = index % (8 * sizeof(bitmap_t)); + bits_t& element = bits[index / (8 * sizeof(bits_t))]; + index_t bit_position = index % (8 * sizeof(bits_t)); if (((element >> bit_position) & 1) == 0) { element |= (static_cast(1) << bit_position); @@ -131,7 +165,7 @@ struct MaskedMatmulBench : public fixture { return res; } - void cpu_convert_to_csr(std::vector& bitmap, + void cpu_convert_to_csr(std::vector& bits, index_t rows, index_t cols, std::vector& indices, @@ -142,14 +176,14 @@ struct MaskedMatmulBench : public fixture { indptr[offset_indptr++] = 0; index_t index = 0; - bitmap_t element = 0; + bits_t element = 0; index_t bit_position = 0; for (index_t i = 0; i < rows; ++i) { for (index_t j = 0; j < cols; ++j) { index = i * cols + j; - element = bitmap[index / (8 * sizeof(bitmap_t))]; - bit_position = index % (8 * sizeof(bitmap_t)); + element = bits[index / (8 * sizeof(bits_t))]; + bit_position = index % (8 * sizeof(bits_t)); if (((element >> bit_position) & 1)) { indices[offset_values] = static_cast(j); @@ -181,13 +215,17 @@ struct MaskedMatmulBench : public fixture { params.n, static_cast(c_indices_d.size())); - auto mask = - raft::core::bitmap_view(bitmap_d.data(), params.m, params.n); - auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); - if (params.sparsity < 1.0) { - raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + if (params.sparsity > 0.0) { + if constexpr (bitmap_or_bitset) { + auto mask = + raft::core::bitmap_view(bits_d.data(), params.m, params.n); + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } else { + auto mask = raft::core::bitset_view(bits_d.data(), params.n); + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } } else { raft::distance::pairwise_distance(handle, a_data_d.data(), @@ -201,12 +239,16 @@ struct MaskedMatmulBench : public fixture { } resource::sync_stream(handle); - raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); - resource::sync_stream(handle); - - loop_on_state(state, [this, &a, &b, &mask, &c]() { - if (params.sparsity < 1.0) { - raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + loop_on_state(state, [this, &a, &b, &c]() { + if (params.sparsity > 0.0) { + if constexpr (bitmap_or_bitset) { + auto mask = + raft::core::bitmap_view(bits_d.data(), params.m, params.n); + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } else { + auto mask = raft::core::bitset_view(bits_d.data(), params.n); + raft::sparse::linalg::masked_matmul(handle, a, b, mask, c); + } } else { raft::distance::pairwise_distance(handle, a_data_d.data(), @@ -228,7 +270,7 @@ struct MaskedMatmulBench : public fixture { rmm::device_uvector a_data_d; rmm::device_uvector b_data_d; - rmm::device_uvector bitmap_d; + rmm::device_uvector bits_d; rmm::device_uvector c_dense_data_d; @@ -253,7 +295,7 @@ static std::vector> getInputs() raft::util::itertools::product({size_t(10), size_t(1024)}, {size_t(128), size_t(1024)}, {size_t(1024 * 1024)}, - {0.01f, 0.1f, 0.2f, 0.5f, 1.0f}); + {0.99f, 0.9f, 0.8f, 0.5f, 0.0f}); param_vec.reserve(params_group.size()); for (TestParams params : params_group) { @@ -263,6 +305,7 @@ static std::vector> getInputs() return param_vec; } -RAFT_BENCH_REGISTER((MaskedMatmulBench), "", getInputs()); +RAFT_BENCH_REGISTER((MaskedMatmulBench), "", getInputs()); +RAFT_BENCH_REGISTER((MaskedMatmulBench), "", getInputs()); } // namespace raft::bench::linalg diff --git a/cpp/bench/prims/sparse/bitset_to_csr.cu b/cpp/bench/prims/sparse/bitset_to_csr.cu new file mode 100644 index 0000000000..fef2d44d3e --- /dev/null +++ b/cpp/bench/prims/sparse/bitset_to_csr.cu @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_repeat; + index_t n_cols; + float sparsity; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_repeat << "*" << params.n_cols + << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct BitsetToCsrBench : public fixture { + BitsetToCsrBench(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + bitset_d(0, stream), + nnz(0), + indptr_d(0, stream), + indices_d(0, stream), + values_d(0, stream) + { + index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8)); + std::vector bitset_h(element); + nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h); + + bitset_d.resize(bitset_h.size(), stream); + indptr_d.resize(params.n_repeat + 1, stream); + indices_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + + resource::sync_stream(handle); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitset) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * (1.0f - sparsity)); + index_t res = num_ones; + + for (auto& item : bitset) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitset_t& element = bitset[index / (8 * sizeof(bitset_t))]; + index_t bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto bitset = raft::core::bitset_view(bitset_d.data(), 1 * params.n_cols); + + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + raft::sparse::convert::bitset_to_csr(handle, bitset, csr); + + resource::sync_stream(handle); + loop_on_state(state, [this, &bitset, &csr]() { + raft::sparse::convert::bitset_to_csr(handle, bitset, csr); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + + rmm::device_uvector bitset_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + index_t nnz; +}; // struct BitsetToCsrBench + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.99f, 0.9f, 0.8f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +template +const std::vector> getLargeInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(1), index_t(100)}, {index_t(100 * 1000000)}, {0.95f, 0.99f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getInputs()); +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getInputs()); + +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getLargeInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 081192ed44..5237edd383 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -129,6 +130,80 @@ void bitmap_to_csr(raft::resources const& handle, detail::bitmap_to_csr(handle, bitmap, csr); } +/** + * @brief Converts a bitset matrix to a Compressed Sparse Row (CSR) format matrix. + * + * The bitset format inherently supports only a single-row matrix (rows=1). If the CSR matrix + * requires multiple rows, the data from the bitset will be repeated for each row in the output. + * + * Example usage: + * + * @code{.cpp} + * #include + * #include + * #include + * + * #include + * + * using bitset_t = uint32_t; + * using index_t = int; + * using value_t = float; + * using nnz_t = index_t; + * + * raft::resources handle; + * auto stream = resource::get_cuda_stream(handle); + * index_t n_rows = 3; + * index_t n_cols = 30; + * + * nnz_t nnz_for_bitset = 4; + * nnz_t nnz_for_csr = nnz_for_bitset * n_rows; + * + * index_t bitset_size = (n_cols + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); // = 1 + * + * rmm::device_uvector bitset_d(bitset_size, stream); + * std::vector bitset_h = { + * bitset_t(0b11001010), + * }; // nnz_for_bitset = 4; + * + * raft::copy(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + * + * auto bitset_view = raft::core::bitset_view(bitset_d.data(), n_cols); + * auto csr = raft::make_device_csr_matrix(handle, n_rows, n_cols, nnz_for_csr); + * + * raft::sparse::convert::bitset_to_csr(handle, bitset_view, csr); + * resource::sync_stream(handle); + * + * // Results: + * // csr.indptr = [0, 4, 8, 12]; + * // csr.indices = [1, 3, 6, 7, + * // 1, 3, 6, 7, + * // 1, 3, 6, 7]; + * // csr.values = [1, 1, 1, 1, + * // 1, 1, 1, 1, + * // 1, 1, 1, 1]; + * @endcode + * + * @tparam bitset_t The data type of the elements in the bitset matrix. + * @tparam index_t The data type used for indexing the elements in the matrices. + * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to + * raft::device_csr_matrix. + * + * @param[in] handle The RAFT handle containing the CUDA stream for operations. + * @param[in] bitset The bitset matrix view, to be converted to CSR format. + * @param[out] csr Output parameter where the resulting CSR matrix is stored. In the + * bitset, each '1' bit corresponds to a non-zero element in the CSR matrix. + */ +template >> +void bitset_to_csr(raft::resources const& handle, + raft::core::bitset_view bitset, + csr_matrix_t& csr) +{ + detail::bitset_to_csr(handle, bitset, csr); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh new file mode 100644 index 0000000000..72abd02f7e --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // detail::popc +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +template +__global__ void repeat_csr_kernel(const index_t* indptr, + const index_t* indices, + index_t* repeated_indptr, + index_t* repeated_indices, + nnz_t nnz, + index_t repeat_count) +{ + int global_id = blockIdx.x * blockDim.x + threadIdx.x; + bool guard = global_id < nnz; + index_t* repeated_indices_addr = repeated_indices + global_id; + + for (index_t i = global_id; i < repeat_count; i += gridDim.x * blockDim.x) { + repeated_indptr[i] = (i + 2) * nnz; + } + + __syncthreads(); + + int block_offset = blockIdx.x * blockDim.x; + + index_t item; + int idx = block_offset + threadIdx.x; + item = (idx < nnz) ? indices[idx] : -1; + + __syncthreads(); + + for (index_t row = 0; row < repeat_count; ++row) { + index_t start_offset = row * nnz; + if (guard) { repeated_indices_addr[start_offset] = item; } + } +} + +template +void gpu_repeat_csr(raft::resources const& handle, + const index_t* d_indptr, + const index_t* d_indices, + nnz_t nnz, + index_t repeat_count, + index_t* d_repeated_indptr, + index_t* d_repeated_indices) +{ + auto stream = resource::get_cuda_stream(handle); + index_t repeat_csr_tpb = 256; + index_t grid = (nnz + repeat_csr_tpb - 1) / (repeat_csr_tpb); + + repeat_csr_kernel<<>>( + d_indptr, d_indices, d_repeated_indptr, d_repeated_indices, nnz, repeat_count); +} + +template >> +void bitset_to_csr(raft::resources const& handle, + raft::core::bitset_view bitset, + csr_matrix_t& csr) +{ + using row_t = typename csr_matrix_t::row_type; + using nnz_t = typename csr_matrix_t::nnz_type; + + auto csr_view = csr.structure_view(); + + if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) { + return; + } + + RAFT_EXPECTS(bitset.size() == csr_view.get_n_cols(), + "Number of size in bitset must be equal to " + "number of columns in csr"); + + auto thrust_policy = resource::get_thrust_policy(handle); + auto stream = resource::get_cuda_stream(handle); + + index_t* indptr = csr_view.get_indptr().data(); + index_t* indices = csr_view.get_indices().data(); + + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); + + size_t sub_nnz_size = 0; + index_t bits_per_sub_col = 0; + + // Get buffer size and number of bits per each sub-columns + calc_nnz_by_rows(handle, + bitset.data(), + row_t(1), + csr_view.get_n_cols(), + static_cast(nullptr), + sub_nnz_size, + bits_per_sub_col); + + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); + rmm::device_uvector sub_nnz(sub_nnz_size + 1, stream, device_memory); + + calc_nnz_by_rows(handle, + bitset.data(), + row_t(1), + csr_view.get_n_cols(), + sub_nnz.data(), + sub_nnz_size, + bits_per_sub_col); + + thrust::exclusive_scan( + thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data()); + + index_t bitset_nnz = 0; + if constexpr (is_device_csr_sparsity_owning_v) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + &bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle); + csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows()); + } else { + bitset_nnz = csr_view.get_nnz() / csr_view.get_n_rows(); + } + constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; + fill_indices_by_rows(handle, + bitset.data(), + indptr, + 1, + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices, + sub_nnz.data(), + bits_per_sub_col, + sub_nnz_size); + if (csr_view.get_n_rows() > 1) { + gpu_repeat_csr(handle, + indptr, + indices, + bitset_nnz, + csr_view.get_n_rows() - 1, + indptr + 2, + indices + bitset_nnz); + } + + thrust::fill_n(thrust_policy, + csr.get_elements().data(), + csr_view.get_nnz(), + typename csr_matrix_t::element_type(1)); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh index 276960628d..0364daff83 100644 --- a/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh +++ b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include #include @@ -100,6 +101,69 @@ void masked_matmul(raft::resources const& handle, } } +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view& A, + raft::device_matrix_view& B, + raft::core::bitset_view& mask, + raft::device_csr_matrix_view& C, + std::optional> alpha, + std::optional> beta) +{ + index_t m = A.extent(0); + index_t n = B.extent(0); + index_t dim = A.extent(1); + + auto compressed_C_view = C.structure_view(); + + RAFT_EXPECTS(A.extent(1) == B.extent(1), "The dim of A must be equal to the dim of B."); + RAFT_EXPECTS(A.extent(0) == compressed_C_view.get_n_rows(), + "Number of rows in C must match the number of rows in A."); + RAFT_EXPECTS(B.extent(0) == compressed_C_view.get_n_cols(), + "Number of columns in C must match the number of columns in B."); + + auto stream = raft::resource::get_cuda_stream(handle); + + auto C_matrix = raft::make_device_csr_matrix(handle, compressed_C_view); + + // fill C + raft::sparse::convert::bitset_to_csr(handle, mask, C_matrix); + + if (m > 10 || alpha.has_value() || beta.has_value()) { + auto C_view = raft::make_device_csr_matrix_view( + C.get_elements().data(), compressed_C_view); + + // create B col_major view + auto B_col_major = raft::make_device_matrix_view( + B.data_handle(), dim, n); + + output_t default_alpha = static_cast(1.0f); + output_t default_beta = static_cast(0.0f); + + if (!alpha.has_value()) { alpha = raft::make_host_scalar_view(&default_alpha); } + if (!beta.has_value()) { beta = raft::make_host_scalar_view(&default_beta); } + + raft::sparse::linalg::sddmm(handle, + A, + B_col_major, + C_view, + raft::linalg::Operation::NON_TRANSPOSE, + raft::linalg::Operation::NON_TRANSPOSE, + *alpha, + *beta); + } else { + raft::sparse::distance::detail::faster_dot_on_csr(handle, + C.get_elements().data(), + compressed_C_view.get_nnz(), + compressed_C_view.get_indptr().data(), + compressed_C_view.get_indices().data(), + A.data_handle(), + B.data_handle(), + compressed_C_view.get_n_rows(), + dim); + } +} + } // namespace detail } // namespace linalg } // namespace sparse diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.cuh b/cpp/include/raft/sparse/linalg/masked_matmul.cuh new file mode 100644 index 0000000000..288068dae2 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/masked_matmul.cuh @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain A copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @defgroup masked_matmul Masked Matrix Multiplication + * @{ + */ + +/** + * @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise + * multiplication with the sparsity pattern defined by the mask, resulting in the computation + * C = alpha * ((A * B) ∘ spy(mask)) + beta * C. + * + * This function multiplies two dense matrices A and B, and then applies an element-wise + * multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha + * and added to beta times the original matrix C. + * + * @tparam value_t Data type of elements in the input matrices (e.g., half, float, double) + * @tparam output_t Data type of elements in the output matrices (e.g., float, double) + * @tparam index_t Type used for matrix indices + * @tparam nnz_t Type used for the number of non-zero entries in CSR format + * @tparam bitmap_t Type of the bitmap used for the mask + * + * @param[in] handle RAFT handle for resource management + * @param[in] A Input dense matrix (device_matrix_view) with shape [m, k] + * @param[in] B Input dense matrix (device_matrix_view) with shape [n, k] + * @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape + * [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is + * included (1) or masked out (0). + * @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m, + * n] + * @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if + * std::nullopt) + * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) + */ +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::core::bitmap_view mask, + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) +{ + detail::masked_matmul(handle, A, B, mask, C, alpha, beta); +} + +/** + * @brief Computes a sparse matrix product with a masked sparsity pattern and scaling. + * + * This function computes the result of: + * C = alpha * ((A * B) ∘ spy(mask)) + beta * C + * where: + * - A and B are dense input matrices. + * - "mask" defines the sparsity pattern for element-wise multiplication. + * - The result is scaled by alpha and added to beta times the original C. + * + * **Special behavior of the mask**: + * - The `bitset` mask represents a single row of data, with its bits indicating whether + * each corresponding element in (A * B) is included (1) or masked out (0). + * - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated + * across all rows of `C`. For example, if `C` has `n_rows` rows, the same `bitset` + * pattern is applied to all rows. + * + * @tparam value_t Data type of input matrix elements (e.g., half, float, double). + * @tparam output_t Data type of output matrix elements (e.g., float, double). + * @tparam index_t Type for matrix indices. + * @tparam nnz_t Type for non-zero entries in CSR format. + * @tparam bitmap_t Type for the bitmap mask. + * + * @param[in] handle RAFT handle for managing resources. + * @param[in] A Dense input matrix [m, k] (row-major). + * @param[in] B Dense input matrix [n, k] (row-major). + * @param[in] mask Bitmap view representing a single row [1, n], where each bit + * indicates if the corresponding element in (A * B) is included (1) + * or masked out (0). The pattern is repeated for all rows of `C`. + * @param[inout] C Output sparse matrix in CSR format [m, n]. + * @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt). + * @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt). + */ +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::core::bitset_view mask, + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) +{ + detail::masked_matmul(handle, A, B, mask, C, alpha, beta); +} + +/** @} */ // end of masked_matmul + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.hpp b/cpp/include/raft/sparse/linalg/masked_matmul.hpp index 6cf6e834b9..32322b90f6 100644 --- a/cpp/include/raft/sparse/linalg/masked_matmul.hpp +++ b/cpp/include/raft/sparse/linalg/masked_matmul.hpp @@ -13,60 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once - -#include - -namespace raft { -namespace sparse { -namespace linalg { - /** - * @defgroup masked_matmul Masked Matrix Multiplication - * @{ + * This file is deprecated and will be removed in future release. + * Please use the cuh version instead. */ /** - * @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise - * multiplication with the sparsity pattern defined by the mask, resulting in the computation - * C = alpha * ((A * B) ∘ spy(mask)) + beta * C. - * - * This function multiplies two dense matrices A and B, and then applies an element-wise - * multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha - * and added to beta times the original matrix C. - * - * @tparam value_t Data type of elements in the input matrices (e.g., half, float, double) - * @tparam output_t Data type of elements in the output matrices (e.g., float, double) - * @tparam index_t Type used for matrix indices - * @tparam nnz_t Type used for the number of non-zero entries in CSR format - * @tparam bitmap_t Type of the bitmap used for the mask - * - * @param[in] handle RAFT handle for resource management - * @param[in] A Input dense matrix (device_matrix_view) with shape [m, k] - * @param[in] B Input dense matrix (device_matrix_view) with shape [n, k] - * @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape - * [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is - * included (1) or masked out (0). - * @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m, - * n] - * @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if - * std::nullopt) - * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) + * DISCLAIMER: this file is deprecated: use masked_matmul.cuh instead */ -template -void masked_matmul(raft::resources const& handle, - raft::device_matrix_view A, - raft::device_matrix_view B, - raft::core::bitmap_view mask, - raft::device_csr_matrix_view C, - std::optional> alpha = std::nullopt, - std::optional> beta = std::nullopt) -{ - detail::masked_matmul(handle, A, B, mask, C, alpha, beta); -} -/** @} */ // end of masked_matmul +#pragma once + +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the cuh version instead.") +#endif -} // end namespace linalg -} // end namespace sparse -} // end namespace raft +#include diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index c1a495ea3d..eed9262a17 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -17,6 +17,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include @@ -477,5 +478,289 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, BitmapToCSRTestLOnLargeSize, ::testing::ValuesIn(bitmaptocsr_large_inputs)); +/******************************** bitset to csr ********************************/ + +template +struct BitsetToCSRInputs { + index_t n_repeat; + index_t n_cols; + float sparsity; + bool owning; +}; + +template +class BitsetToCSRTest : public ::testing::TestWithParam> { + public: + BitsetToCSRTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + bitset_d(0, stream), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + indptr_expected_d(0, stream), + indices_expected_d(0, stream), + values_expected_d(0, stream) + { + } + + protected: + void repeat_cpu_bitset(std::vector& input, + size_t input_bits, + size_t repeat, + std::vector& output) + { + const size_t output_bits = input_bits * repeat; + const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); + + std::memset(output.data(), 0, output_units * sizeof(bitset_t)); + + size_t output_bit_index = 0; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bitset_t) * 8); + size_t input_bit_offset = i % (sizeof(bitset_t) * 8); + bool bit = (input[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8); + + output[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitset) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitset) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitset_t& element = bitset[index / (8 * sizeof(bitset_t))]; + index_t bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitset, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitset_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitset[index / (8 * sizeof(bitset_t))]; + bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + bool csr_compare(const std::vector& row_ptrs1, + const std::vector& col_indices1, + const std::vector& row_ptrs2, + const std::vector& col_indices2) + { + if (row_ptrs1.size() != row_ptrs2.size()) { return false; } + + if (col_indices1.size() != col_indices2.size()) { return false; } + + if (!std::equal(row_ptrs1.begin(), row_ptrs1.end(), row_ptrs2.begin())) { return false; } + + for (size_t i = 0; i < row_ptrs1.size() - 1; ++i) { + size_t start_idx = row_ptrs1[i]; + size_t end_idx = row_ptrs1[i + 1]; + + std::vector cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx); + std::vector cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx); + + std::sort(cols1.begin(), cols1.end()); + std::sort(cols2.begin(), cols2.end()); + + if (cols1 != cols2) { return false; } + } + + return true; + } + + void SetUp() override + { + index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8)); + std::vector bitset_h(element); + std::vector bitset_repeat_h(element * params.n_repeat); + + nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h); + + repeat_cpu_bitset(bitset_h, size_t(params.n_cols), size_t(params.n_repeat), bitset_repeat_h); + nnz *= params.n_repeat; + + std::vector indices_h(nnz); + std::vector indptr_h(params.n_repeat + 1); + + cpu_convert_to_csr(bitset_repeat_h, params.n_repeat, params.n_cols, indices_h, indptr_h); + + bitset_d.resize(bitset_h.size(), stream); + indptr_d.resize(params.n_repeat + 1, stream); + indices_d.resize(nnz, stream); + + indptr_expected_d.resize(params.n_repeat + 1, stream); + indices_expected_d.resize(nnz, stream); + values_expected_d.resize(nnz, stream); + + thrust::fill_n(resource::get_thrust_policy(handle), values_expected_d.data(), nnz, value_t{1}); + + values_d.resize(nnz, stream); + + update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_expected_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto bitset = raft::core::bitset_view(bitset_d.data(), params.n_cols); + + if (params.owning) { + auto csr = + raft::make_device_csr_matrix(handle, params.n_repeat, params.n_cols, nnz); + auto csr_view = csr.structure_view(); + + convert::bitset_to_csr(handle, bitset, csr); + raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream); + raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } else { + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + convert::bitset_to_csr(handle, bitset, csr); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } + resource::sync_stream(handle); + + std::vector indices_h(indices_expected_d.size(), 0); + std::vector indices_expected_h(indices_expected_d.size(), 0); + update_host(indices_h.data(), indices_d.data(), indices_h.size(), stream); + update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); + + std::vector indptr_h(indptr_expected_d.size(), 0); + std::vector indptr_expected_h(indptr_expected_d.size(), 0); + update_host(indptr_h.data(), indptr_d.data(), indptr_h.size(), stream); + update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); + + resource::sync_stream(handle); + + ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); + ASSERT_TRUE(raft::devArrMatch( + values_expected_d.data(), values_d.data(), nnz, raft::Compare(), stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + BitsetToCSRInputs params; + + rmm::device_uvector bitset_d; + + index_t nnz; + + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + rmm::device_uvector indptr_expected_d; + rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; +}; + +using BitsetToCSRTestI = BitsetToCSRTest; +TEST_P(BitsetToCSRTestI, Result) { Run(); } + +using BitsetToCSRTestL = BitsetToCSRTest; +TEST_P(BitsetToCSRTestL, Result) { Run(); } + +using BitsetToCSRTestLOnLargeSize = BitsetToCSRTest; +TEST_P(BitsetToCSRTestLOnLargeSize, Result) { Run(); } + +template +const std::vector> bitsettocsr_inputs = { + {0, 0, 0.2, false}, + {10, 32, 0.4, false}, + {10, 3, 0.2, false}, + {32, 1024, 0.4, false}, + {1024, 1048576, 0.01, false}, + {1024, 1024, 0.4, false}, + {64 * 1024 + 10, 2, 0.3, false}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, false}, // No peeling-remainder + {17, 16, 0.3, false}, // Check peeling-remainder + {18, 16, 0.3, false}, // Check peeling-remainder + {32 + 9, 33, 0.2, false}, // Check peeling-remainder + {2, 33, 0.2, false}, // Check peeling-remainder + {0, 0, 0.2, true}, + {10, 32, 0.4, true}, + {10, 3, 0.2, true}, + {32, 1024, 0.4, true}, + {1024, 1048576, 0.01, true}, + {1024, 1024, 0.4, true}, + {64 * 1024 + 10, 2, 0.3, true}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, true}, // No peeling-remainder + {17, 16, 0.3, true}, // Check peeling-remainder + {18, 16, 0.3, true}, // Check peeling-remainder + {32 + 9, 33, 0.2, true}, // Check peeling-remainder + {2, 33, 0.2, true}, // Check peeling-remainder +}; + +template +const std::vector> bitsettocsr_large_inputs = { + {100, 100000000, 0.01, true}, {100, 100000000, 0.05, false}, {100, 100000000 + 17, 0.05, false}}; + +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestI, + ::testing::ValuesIn(bitsettocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestL, + ::testing::ValuesIn(bitsettocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestLOnLargeSize, + ::testing::ValuesIn(bitsettocsr_large_inputs)); + } // namespace sparse } // namespace raft diff --git a/cpp/test/sparse/masked_matmul.cu b/cpp/test/sparse/masked_matmul.cu index f883beae32..5ee1677015 100644 --- a/cpp/test/sparse/masked_matmul.cu +++ b/cpp/test/sparse/masked_matmul.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -46,6 +46,8 @@ struct MaskedMatmulInputs { unsigned long long int seed; }; +enum class BitsLayout { Bitset, Bitmap }; + template struct sum_abs_op { __host__ __device__ value_t operator()(const value_t& x, const value_t& y) const @@ -87,7 +89,8 @@ bool isCuSparseVersionGreaterThan_12_0_1() template class MaskedMatmulTest @@ -98,7 +101,7 @@ class MaskedMatmulTest stream(resource::get_cuda_stream(handle)), a_data_d(0, resource::get_cuda_stream(handle)), b_data_d(0, resource::get_cuda_stream(handle)), - bitmap_d(0, resource::get_cuda_stream(handle)), + bits_d(0, resource::get_cuda_stream(handle)), c_indptr_d(0, resource::get_cuda_stream(handle)), c_indices_d(0, resource::get_cuda_stream(handle)), c_data_d(0, resource::get_cuda_stream(handle)), @@ -107,14 +110,14 @@ class MaskedMatmulTest } protected: - index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bits) { index_t total = static_cast(m * n); index_t num_ones = static_cast((total * 1.0f) * sparsity); index_t res = num_ones; - for (auto& item : bitmap) { - item = static_cast(0); + for (auto& item : bits) { + item = static_cast(0); } std::random_device rd; @@ -124,8 +127,8 @@ class MaskedMatmulTest while (num_ones > 0) { index_t index = dis(gen); - bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; - index_t bit_position = index % (8 * sizeof(bitmap_t)); + bits_t& element = bits[index / (8 * sizeof(bits_t))]; + index_t bit_position = index % (8 * sizeof(bits_t)); if (((element >> bit_position) & 1) == 0) { element |= (static_cast(1) << bit_position); @@ -135,7 +138,27 @@ class MaskedMatmulTest return res; } - void cpu_convert_to_csr(std::vector& bitmap, + void repeat_cpu_bitset_inplace(std::vector& inout, size_t input_bits, size_t repeat) + { + size_t output_bit_index = input_bits; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bits_t) * 8); + size_t input_bit_offset = i % (sizeof(bits_t) * 8); + bool bit = (inout[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bits_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bits_t) * 8); + + inout[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } + } + + void cpu_convert_to_csr(std::vector& bits, index_t rows, index_t cols, std::vector& indices, @@ -146,14 +169,14 @@ class MaskedMatmulTest indptr[offset_indptr++] = 0; index_t index = 0; - bitmap_t element = 0; + bits_t element = 0; index_t bit_position = 0; for (index_t i = 0; i < rows; ++i) { for (index_t j = 0; j < cols; ++j) { index = i * cols + j; - element = bitmap[index / (8 * sizeof(bitmap_t))]; - bit_position = index % (8 * sizeof(bitmap_t)); + element = bits[index / (8 * sizeof(bits_t))]; + bit_position = index % (8 * sizeof(bits_t)); if (((element >> bit_position) & 1)) { indices[offset_values] = static_cast(j); @@ -201,15 +224,17 @@ class MaskedMatmulTest index_t b_size = params.k * params.n; index_t c_size = params.m * params.n; - index_t element = raft::ceildiv(params.m * params.n, index_t(sizeof(bitmap_t) * 8)); - std::vector bitmap_h(element); + index_t element = raft::ceildiv(params.m * params.n, index_t(sizeof(bits_t) * 8)); + std::vector bits_h(element); + + std::memset(bits_h.data(), 0, bits_h.size() * sizeof(bits_t)); std::vector a_data_h(a_size); std::vector b_data_h(b_size); a_data_d.resize(a_size, stream); b_data_d.resize(b_size, stream); - bitmap_d.resize(bitmap_h.size(), stream); + bits_d.resize(bits_h.size(), stream); auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); auto labels = raft::make_device_vector(handle, 1); @@ -262,18 +287,27 @@ class MaskedMatmulTest resource::sync_stream(handle); - index_t c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h); + index_t c_true_nnz = 0; + if constexpr (bits_layout == BitsLayout::Bitmap) { + c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h); + } else if constexpr (bits_layout == BitsLayout::Bitset) { + c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h); + repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1); + c_true_nnz *= params.m; + } else { + GTEST_SKIP() << "Unsupported BitsLayout!"; + } std::vector c_indptr_h(params.m + 1); std::vector c_indices_h(c_true_nnz); std::vector c_data_h(c_true_nnz); - cpu_convert_to_csr(bitmap_h, params.m, params.n, c_indices_h, c_indptr_h); + cpu_convert_to_csr(bits_h, params.m, params.n, c_indices_h, c_indptr_h); c_data_d.resize(c_data_h.size(), stream); update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); - update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + update_device(bits_d.data(), bits_h.data(), bits_h.size(), stream); resource::sync_stream(handle); cpu_sddmm(a_data_h, b_data_h, c_data_h, c_indices_h, c_indptr_h, true, true); @@ -304,9 +338,6 @@ class MaskedMatmulTest auto B = raft::make_device_matrix_view(b_data_d.data(), params.n, params.k); - auto mask = - raft::core::bitmap_view(bitmap_d.data(), params.m, params.n); - auto c_structure = raft::make_device_compressed_structure_view( c_indptr_d.data(), c_indices_d.data(), @@ -316,7 +347,15 @@ class MaskedMatmulTest auto C = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); - raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); + if constexpr (bits_layout == BitsLayout::Bitmap) { + auto mask = raft::core::bitmap_view(bits_d.data(), params.m, params.n); + raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); + } else if constexpr (bits_layout == BitsLayout::Bitset) { + auto mask = raft::core::bitset_view(bits_d.data(), params.n); + raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); + } else { + GTEST_SKIP() << "Unsupported BitsLayout!"; + } resource::sync_stream(handle); @@ -344,7 +383,7 @@ class MaskedMatmulTest rmm::device_uvector a_data_d; rmm::device_uvector b_data_d; - rmm::device_uvector bitmap_d; + rmm::device_uvector bits_d; rmm::device_uvector c_indptr_d; rmm::device_uvector c_indices_d; @@ -353,14 +392,23 @@ class MaskedMatmulTest rmm::device_uvector c_expected_data_d; }; -using MaskedMatmulTestF = MaskedMatmulTest; -TEST_P(MaskedMatmulTestF, Result) { Run(); } +using MaskedMatmulOnBitmapTestF = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitmapTestF, Result) { Run(); } + +using MaskedMatmulOnBitmapTestD = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitmapTestD, Result) { Run(); } -using MaskedMatmulTestD = MaskedMatmulTest; -TEST_P(MaskedMatmulTestD, Result) { Run(); } +using MaskedMatmulOnBitmapTestH = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitmapTestH, Result) { Run(); } -using MaskedMatmulTestH = MaskedMatmulTest; -TEST_P(MaskedMatmulTestH, Result) { Run(); } +using MaskedMatmulOnBitsetTestF = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitsetTestF, Result) { Run(); } + +using MaskedMatmulOnBitsetTestD = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitsetTestD, Result) { Run(); } + +using MaskedMatmulOnBitsetTestH = MaskedMatmulTest; +TEST_P(MaskedMatmulOnBitsetTestH, Result) { Run(); } const std::vector> sddmm_inputs_f = { {0.001f, 2, 255, 1023, 0.19, 1234ULL}, @@ -419,11 +467,29 @@ const std::vector> sddmm_inputs_h = { {0.0003f, 31, 1025, 1025, 0.19, 1234ULL}, {0.001f, 1024, 1024, 1024, 0.1, 1234ULL}}; -INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestF, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitmapTestF, + ::testing::ValuesIn(sddmm_inputs_f)); + +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitmapTestD, + ::testing::ValuesIn(sddmm_inputs_d)); + +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitmapTestH, + ::testing::ValuesIn(sddmm_inputs_h)); + +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitsetTestF, + ::testing::ValuesIn(sddmm_inputs_f)); -INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestD, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitsetTestD, + ::testing::ValuesIn(sddmm_inputs_d)); -INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestH, ::testing::ValuesIn(sddmm_inputs_h)); +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, + MaskedMatmulOnBitsetTestH, + ::testing::ValuesIn(sddmm_inputs_h)); } // namespace sparse } // namespace raft