Skip to content

Commit

Permalink
optimization for review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Dec 20, 2024
1 parent df9faf5 commit 94d90a7
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 116 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/masked_matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/linalg/masked_matmul.hpp>
#include <raft/sparse/linalg/masked_matmul.cuh>
#include <raft/util/itertools.hpp>

#include <cusparse_v2.h>
Expand Down
47 changes: 47 additions & 0 deletions cpp/include/raft/sparse/convert/csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,53 @@ void bitmap_to_csr(raft::resources const& handle,
* The bitset format inherently supports only a single-row matrix (rows=1). If the CSR matrix
* requires multiple rows, the data from the bitset will be repeated for each row in the output.
*
* Example usage:
*
* @code{.cpp}
* #include <raft/core/resource/cuda_stream.hpp>
* #include <raft/sparse/convert/csr.cuh>
* #include <rmm/device_uvector.hpp>
*
* #include <vector>
*
* using bitset_t = uint32_t;
* using index_t = int;
* using value_t = float;
* using nnz_t = index_t;
*
* raft::resources handle;
* auto stream = resource::get_cuda_stream(handle);
* index_t n_rows = 3;
* index_t n_cols = 30;
*
* nnz_t nnz_for_bitset = 4;
* nnz_t nnz_for_csr = nnz_for_bitset * n_rows;
*
* index_t bitset_size = (n_cols + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); // = 1
*
* rmm::device_uvector<bitset_t> bitset_d(bitset_size, stream);
* std::vector<bitset_t> bitset_h = {
* bitset_t(0b11001010),
* }; // nnz_for_bitset = 4;
*
* raft::copy(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream);
*
* auto bitset_view = raft::core::bitset_view<bitset_t, index_t>(bitset_d.data(), n_cols);
* auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, n_rows, n_cols, nnz_for_csr);
*
* raft::sparse::convert::bitset_to_csr(handle, bitset_view, csr);
* resource::sync_stream(handle);
*
* // Results:
* // csr.indptr = [0, 4, 8, 12];
* // csr.indices = [1, 3, 6, 7,
* // 1, 3, 6, 7,
* // 1, 3, 6, 7];
* // csr.values = [1, 1, 1, 1,
* // 1, 1, 1, 1,
* // 1, 1, 1, 1];
* @endcode
*
* @tparam bitset_t The data type of the elements in the bitset matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to
Expand Down
46 changes: 37 additions & 9 deletions cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,24 +118,52 @@ void bitset_to_csr(raft::resources const& handle,

RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream));

calc_nnz_by_rows(handle, bitset.data(), row_t(1), csr_view.get_n_cols(), indptr);
thrust::exclusive_scan(thrust_policy, indptr, indptr + 2, indptr);
size_t sub_nnz_size = 0;
index_t bits_per_sub_col = 0;

// Get buffer size and number of bits per each sub-columns
calc_nnz_by_rows(handle,
bitset.data(),
row_t(1),
csr_view.get_n_cols(),
static_cast<nnz_t*>(nullptr),
sub_nnz_size,
bits_per_sub_col);

rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle);
rmm::device_uvector<nnz_t> sub_nnz(sub_nnz_size + 1, stream, device_memory);

calc_nnz_by_rows(handle,
bitset.data(),
row_t(1),
csr_view.get_n_cols(),
sub_nnz.data(),
sub_nnz_size,
bits_per_sub_col);

thrust::exclusive_scan(
thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data());

index_t bitset_nnz = 0;

if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
RAFT_CUDA_TRY(
cudaMemcpyAsync(&bitset_nnz, indptr + 1, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaMemcpyAsync(
&bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
resource::sync_stream(handle);
csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows());
} else {
bitset_nnz = csr_view.get_nnz() / csr_view.get_n_rows();
}

constexpr bool check_nnz = is_device_csr_sparsity_preserving_v<csr_matrix_t>;
fill_indices_by_rows<bitset_t, index_t, nnz_t, check_nnz>(
handle, bitset.data(), indptr, 1, csr_view.get_n_cols(), bitset_nnz, indices);

fill_indices_by_rows<bitset_t, index_t, nnz_t, check_nnz>(handle,
bitset.data(),
indptr,
1,
csr_view.get_n_cols(),
csr_view.get_nnz(),
indices,
sub_nnz.data(),
bits_per_sub_col,
sub_nnz_size);
if (csr_view.get_n_rows() > 1) {
gpu_repeat_csr<index_t, nnz_t>(handle,
indptr,
Expand Down
117 changes: 117 additions & 0 deletions cpp/include/raft/sparse/linalg/masked_matmul.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain A copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/sparse/linalg/detail/masked_matmul.cuh>

namespace raft {
namespace sparse {
namespace linalg {

/**
* @defgroup masked_matmul Masked Matrix Multiplication
* @{
*/

/**
* @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise
* multiplication with the sparsity pattern defined by the mask, resulting in the computation
* C = alpha * ((A * B) ∘ spy(mask)) + beta * C.
*
* This function multiplies two dense matrices A and B, and then applies an element-wise
* multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha
* and added to beta times the original matrix C.
*
* @tparam value_t Data type of elements in the input matrices (e.g., half, float, double)
* @tparam output_t Data type of elements in the output matrices (e.g., float, double)
* @tparam index_t Type used for matrix indices
* @tparam nnz_t Type used for the number of non-zero entries in CSR format
* @tparam bitmap_t Type of the bitmap used for the mask
*
* @param[in] handle RAFT handle for resource management
* @param[in] A Input dense matrix (device_matrix_view) with shape [m, k]
* @param[in] B Input dense matrix (device_matrix_view) with shape [n, k]
* @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape
* [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is
* included (1) or masked out (0).
* @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m,
* n]
* @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if
* std::nullopt)
* @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt)
*/
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major> A,
raft::device_matrix_view<const value_t, index_t, raft::row_major> B,
raft::core::bitmap_view<const bitmap_t, index_t> mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t> C,
std::optional<raft::host_scalar_view<output_t>> alpha = std::nullopt,
std::optional<raft::host_scalar_view<output_t>> beta = std::nullopt)
{
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}

/**
* @brief Computes a sparse matrix product with a masked sparsity pattern and scaling.
*
* This function computes the result of:
* C = alpha * ((A * B) ∘ spy(mask)) + beta * C
* where:
* - A and B are dense input matrices.
* - "mask" defines the sparsity pattern for element-wise multiplication.
* - The result is scaled by alpha and added to beta times the original C.
*
* **Special behavior of the mask**:
* - The `bitset` mask represents a single row of data, with its bits indicating whether
* each corresponding element in (A * B) is included (1) or masked out (0).
* - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated
* across all rows of `C`. For example, if `C` has `n_rows` rows, the same `bitset`
* pattern is applied to all rows.
*
* @tparam value_t Data type of input matrix elements (e.g., half, float, double).
* @tparam output_t Data type of output matrix elements (e.g., float, double).
* @tparam index_t Type for matrix indices.
* @tparam nnz_t Type for non-zero entries in CSR format.
* @tparam bitmap_t Type for the bitmap mask.
*
* @param[in] handle RAFT handle for managing resources.
* @param[in] A Dense input matrix [m, k] (row-major).
* @param[in] B Dense input matrix [n, k] (row-major).
* @param[in] mask Bitmap view representing a single row [1, n], where each bit
* indicates if the corresponding element in (A * B) is included (1)
* or masked out (0). The pattern is repeated for all rows of `C`.
* @param[inout] C Output sparse matrix in CSR format [m, n].
* @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt).
* @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt).
*/
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major> A,
raft::device_matrix_view<const value_t, index_t, raft::row_major> B,
raft::core::bitset_view<const bitmap_t, index_t> mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t> C,
std::optional<raft::host_scalar_view<output_t>> alpha = std::nullopt,
std::optional<raft::host_scalar_view<output_t>> beta = std::nullopt)
{
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}

/** @} */ // end of masked_matmul

} // end namespace linalg
} // end namespace sparse
} // end namespace raft
104 changes: 10 additions & 94 deletions cpp/include/raft/sparse/linalg/masked_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,105 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/sparse/linalg/detail/masked_matmul.cuh>

namespace raft {
namespace sparse {
namespace linalg {

/**
* @defgroup masked_matmul Masked Matrix Multiplication
* @{
* This file is deprecated and will be removed in future release.
* Please use the cuh version instead.
*/

/**
* @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise
* multiplication with the sparsity pattern defined by the mask, resulting in the computation
* C = alpha * ((A * B) ∘ spy(mask)) + beta * C.
*
* This function multiplies two dense matrices A and B, and then applies an element-wise
* multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha
* and added to beta times the original matrix C.
*
* @tparam value_t Data type of elements in the input matrices (e.g., half, float, double)
* @tparam output_t Data type of elements in the output matrices (e.g., float, double)
* @tparam index_t Type used for matrix indices
* @tparam nnz_t Type used for the number of non-zero entries in CSR format
* @tparam bitmap_t Type of the bitmap used for the mask
*
* @param[in] handle RAFT handle for resource management
* @param[in] A Input dense matrix (device_matrix_view) with shape [m, k]
* @param[in] B Input dense matrix (device_matrix_view) with shape [n, k]
* @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape
* [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is
* included (1) or masked out (0).
* @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m,
* n]
* @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if
* std::nullopt)
* @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt)
* DISCLAIMER: this file is deprecated: use masked_matmul.cuh instead
*/
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major> A,
raft::device_matrix_view<const value_t, index_t, raft::row_major> B,
raft::core::bitmap_view<const bitmap_t, index_t> mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t> C,
std::optional<raft::host_scalar_view<output_t>> alpha = std::nullopt,
std::optional<raft::host_scalar_view<output_t>> beta = std::nullopt)
{
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}

/**
* @brief Computes a sparse matrix product with a masked sparsity pattern and scaling.
*
* This function computes the result of:
* C = alpha * ((A * B) ∘ spy(mask)) + beta * C
* where:
* - A and B are dense input matrices.
* - "mask" defines the sparsity pattern for element-wise multiplication.
* - The result is scaled by alpha and added to beta times the original C.
*
* **Special behavior of the mask**:
* - The `bitset` mask represents a single row of data, with its bits indicating whether
* each corresponding element in (A * B) is included (1) or masked out (0).
* - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated
* across all rows of `C`. For example, if `C` has `n_rows` rows, the same `bitset`
* pattern is applied to all rows.
*
* @tparam value_t Data type of input matrix elements (e.g., half, float, double).
* @tparam output_t Data type of output matrix elements (e.g., float, double).
* @tparam index_t Type for matrix indices.
* @tparam nnz_t Type for non-zero entries in CSR format.
* @tparam bitmap_t Type for the bitmap mask.
*
* @param[in] handle RAFT handle for managing resources.
* @param[in] A Dense input matrix [m, k] (row-major).
* @param[in] B Dense input matrix [n, k] (row-major).
* @param[in] mask Bitmap view representing a single row [1, n], where each bit
* indicates if the corresponding element in (A * B) is included (1)
* or masked out (0). The pattern is repeated for all rows of `C`.
* @param[inout] C Output sparse matrix in CSR format [m, n].
* @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt).
* @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt).
*/
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major> A,
raft::device_matrix_view<const value_t, index_t, raft::row_major> B,
raft::core::bitset_view<const bitmap_t, index_t> mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t> C,
std::optional<raft::host_scalar_view<output_t>> alpha = std::nullopt,
std::optional<raft::host_scalar_view<output_t>> beta = std::nullopt)
{
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}
#pragma once

/** @} */ // end of masked_matmul
#ifndef RAFT_HIDE_DEPRECATION_WARNINGS
#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the cuh version instead.")
#endif

} // end namespace linalg
} // end namespace sparse
} // end namespace raft
#include <raft/sparse/linalg/masked_matmul.cuh>
Loading

0 comments on commit 94d90a7

Please sign in to comment.