Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor memory ownership svm #6073

Open
wants to merge 5 commits into
base: branch-24.10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions cpp/bench/sg/svc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ struct SvcParams {
BlobsParams blobs;
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D> model;
};

template <typename D>
class SVC : public BlobsFixture<D, D> {
public:
SVC(const std::string& name, const SvcParams<D>& p)
: BlobsFixture<D, D>(name, p.data, p.blobs),
kernel(p.kernel),
model(p.model),
svm_param(p.svm_param)
: BlobsFixture<D, D>(name, p.data, p.blobs), kernel(p.kernel), svm_param(p.svm_param)
{
std::vector<std::string> kernel_names{"linear", "poly", "rbf", "tanh"};
std::ostringstream oss;
Expand Down Expand Up @@ -101,7 +97,6 @@ std::vector<SvcParams<D>> getInputs()

// SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity})
p.svm_param = ML::SVM::SvmParameter{1, 200, 100, 100, 1e-3, CUML_LEVEL_INFO, 0, ML::SVM::C_SVC};
p.model = ML::SVM::SvmModel<D>{0, 0, 0, nullptr, {}, nullptr, 0, nullptr};

std::vector<Triplets> rowcols = {{50000, 2, 2}, {2048, 100000, 2}, {50000, 1000, 2}};

Expand Down
13 changes: 4 additions & 9 deletions cpp/bench/sg/svr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ struct SvrParams {
RegressionParams regression;
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D>* model;
};

template <typename D>
class SVR : public RegressionFixture<D> {
public:
SVR(const std::string& name, const SvrParams<D>& p)
: RegressionFixture<D>(name, p.data, p.regression),
kernel(p.kernel),
model(p.model),
svm_param(p.svm_param)
: RegressionFixture<D>(name, p.data, p.regression), kernel(p.kernel), svm_param(p.svm_param)
{
std::vector<std::string> kernel_names{"linear", "poly", "rbf", "tanh"};
std::ostringstream oss;
Expand All @@ -69,16 +65,16 @@ class SVR : public RegressionFixture<D> {
this->data.y.data(),
this->svm_param,
this->kernel,
*(this->model));
this->model);
this->handle->sync_stream(this->stream);
ML::SVM::svmFreeBuffers(*this->handle, *(this->model));
ML::SVM::svmFreeBuffers(*this->handle, this->model);
});
}

private:
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D>* model;
ML::SVM::SvmModel<D> model;
};

template <typename D>
Expand All @@ -103,7 +99,6 @@ std::vector<SvrParams<D>> getInputs()
// epsilon, svmType})
p.svm_param =
ML::SVM::SvmParameter{1, 200, 200, 100, 1e-3, CUML_LEVEL_INFO, 0.1, ML::SVM::EPSILON_SVR};
p.model = new ML::SVM::SvmModel<D>{0, 0, 0, 0};

std::vector<Triplets> rowcols = {{50000, 2, 2}, {1024, 10000, 10}, {3000, 200, 200}};

Expand Down
21 changes: 11 additions & 10 deletions cpp/include/cuml/svm/svm_model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,16 +15,17 @@
*/
#pragma once

#include <rmm/device_buffer.hpp>

namespace ML {
namespace SVM {

// Contains array(s) for matrix storage
template <typename math_t>
struct SupportStorage {
int nnz = -1;
int* indptr = nullptr;
int* indices = nullptr;
math_t* data = nullptr;
int nnz = -1;
rmm::device_buffer indptr;
rmm::device_buffer indices;
rmm::device_buffer data;
};

/**
Expand All @@ -39,17 +40,17 @@ struct SvmModel {

//! Non-zero dual coefficients ( dual_coef[i] = \f$ y_i \alpha_i \f$).
//! Size [n_support].
math_t* dual_coefs;
rmm::device_buffer dual_coefs;

//! Support vector storage - can contain either CSR or dense
SupportStorage<math_t> support_matrix;
SupportStorage support_matrix;

//! Indices (from the training set) of the support vectors, size [n_support].
int* support_idx;
rmm::device_buffer support_idx;

int n_classes; //!< Number of classes found in the input labels
//! Device pointer for the unique classes. Size [n_classes]
math_t* unique_labels;
rmm::device_buffer unique_labels;
};

}; // namespace SVM
Expand Down
70 changes: 34 additions & 36 deletions cpp/src/svm/results.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,56 +117,53 @@ class Results {
*/
void Get(const math_t* alpha,
const math_t* f,
math_t** dual_coefs,
rmm::device_buffer& dual_coefs,
int* n_support,
int** idx,
SupportStorage<math_t>* support_matrix,
rmm::device_buffer& idx,
SupportStorage& support_matrix,
math_t* b)
{
CombineCoefs(alpha, val_tmp.data());
GetDualCoefs(val_tmp.data(), dual_coefs, n_support);
*b = CalcB(alpha, f, *n_support);
if (*n_support > 0) {
*idx = GetSupportVectorIndices(val_tmp.data(), *n_support);
*support_matrix = CollectSupportVectorMatrix(*idx, *n_support);
} else {
*dual_coefs = nullptr;
*idx = nullptr;
*support_matrix = {};
}
GetSupportVectorIndices(idx, val_tmp.data(), *n_support);
CollectSupportVectorMatrix(support_matrix, idx, *n_support);
// Make sure that all pending GPU calculations finished before we return
handle.sync_stream(stream);
}

/**
* Collect support vectors into a matrix storage
*
* @param [out] support_matrix containing the support vectors, size [n_suppor*n_cols]
* @param [in] idx indices of support vectors, size [n_support]
* @param [in] n_support number of support vectors
* @return pointer to a newly allocated device buffer that stores the support
* vectors, size [n_suppor*n_cols]
*/
SupportStorage<math_t> CollectSupportVectorMatrix(const int* idx, int n_support)
void CollectSupportVectorMatrix(SupportStorage& support_matrix,
rmm::device_buffer& idx,
int n_support)
{
SupportStorage<math_t> support_matrix;
// allow ~1GB dense support matrix
if (isDenseType<MatrixViewType>() ||
((size_t)n_support * n_cols * sizeof(math_t) < (1 << 30))) {
support_matrix.data = (math_t*)rmm_alloc.allocate_async(
n_support * n_cols * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
ML::SVM::extractRows<math_t>(matrix, support_matrix.data, idx, n_support, handle);
support_matrix.nnz = -1;
support_matrix.indptr.resize(0, stream);
support_matrix.indices.resize(0, stream);
support_matrix.data.resize(n_support * n_cols * sizeof(math_t), stream);
if (n_support > 0) {
ML::SVM::extractRows<math_t>(
matrix, (math_t*)support_matrix.data.data(), (int*)idx.data(), n_support, handle);
}
} else {
ML::SVM::extractRows<math_t>(matrix,
&(support_matrix.indptr),
&(support_matrix.indices),
&(support_matrix.data),
support_matrix.indptr,
support_matrix.indices,
support_matrix.data,
&(support_matrix.nnz),
idx,
(int*)idx.data(),
mfoerste4 marked this conversation as resolved.
Show resolved Hide resolved
n_support,
handle);
}

return support_matrix;
}

/**
Expand Down Expand Up @@ -205,33 +202,34 @@ class Results {
* unallocated on entry, on exit size [n_support]
* @param [out] n_support number of support vectors
*/
void GetDualCoefs(const math_t* val_tmp, math_t** dual_coefs, int* n_support)
void GetDualCoefs(const math_t* val_tmp, rmm::device_buffer& dual_coefs, int* n_support)
{
// Return only the non-zero coefficients
auto select_op = [] __device__(math_t a) { return 0 != a; };
*n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data());
*dual_coefs = (math_t*)rmm_alloc.allocate_async(
*n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
raft::copy(*dual_coefs, val_selected.data(), *n_support, stream);
dual_coefs.resize(*n_support * sizeof(math_t), stream);
raft::copy((math_t*)dual_coefs.data(), val_selected.data(), *n_support, stream);
handle.sync_stream(stream);
}

/**
* Flag support vectors and also collect their indices.
* Support vectors are the vectors where alpha > 0.
*
* @param [out] idx the training set indices of the support vectors, size [n_support]
* @param [in] coef dual coefficients, size [n_rows]
* @param [in] n_support number of support vectors
* @return indices of the support vectors, size [n_support]
*/
int* GetSupportVectorIndices(const math_t* coef, int n_support)
void GetSupportVectorIndices(rmm::device_buffer& idx, const math_t* coef, int n_support)
{
auto select_op = [] __device__(math_t a) -> bool { return 0 != a; };
SelectByCoef(coef, n_rows, f_idx.data(), select_op, idx_selected.data());
int* idx = (int*)rmm_alloc.allocate_async(
n_support * sizeof(int), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
raft::copy(idx, idx_selected.data(), n_support, stream);
return idx;
if (n_support > 0) {
auto select_op = [] __device__(math_t a) -> bool { return 0 != a; };
SelectByCoef(coef, n_rows, f_idx.data(), select_op, idx_selected.data());
idx.resize(n_support * sizeof(int), stream);
raft::copy((int*)idx.data(), idx_selected.data(), n_support, stream);
} else {
idx.resize(0, stream);
}
}

/**
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/svm/smosolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ void SmoSolver<math_t>::Solve(MatrixViewType matrix,
int n_cols,
math_t* y,
const math_t* sample_weight,
math_t** dual_coefs,
rmm::device_buffer* dual_coefs,
mfoerste4 marked this conversation as resolved.
Show resolved Hide resolved
int* n_support,
SupportStorage<math_t>* support_matrix,
int** idx,
SupportStorage* support_matrix,
rmm::device_buffer* idx,
math_t* b,
int max_outer_iter,
int max_inner_iter)
Expand Down Expand Up @@ -210,7 +210,7 @@ void SmoSolver<math_t>::Solve(MatrixViewType matrix,
diff_prev);

Results<math_t, MatrixViewType> res(handle, matrix, n_rows, n_cols, y, C_vec.data(), svmType);
res.Get(alpha.data(), f.data(), dual_coefs, n_support, idx, support_matrix, b);
res.Get(alpha.data(), f.data(), *dual_coefs, n_support, *idx, *support_matrix, b);

ReleaseBuffers();
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/svm/smosolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ class SmoSolver {
int n_cols,
math_t* y,
const math_t* sample_weight,
math_t** dual_coefs,
rmm::device_buffer* dual_coefs,
int* n_support,
SupportStorage<math_t>* support_matrix,
int** idx,
SupportStorage* support_matrix,
rmm::device_buffer* idx,
math_t* b,
int max_outer_iter = -1,
int max_inner_iter = 10000);
Expand Down
34 changes: 19 additions & 15 deletions cpp/src/svm/sparse_util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,9 @@ void extractRows(raft::device_csr_matrix_view<math_t, int, int, int> matrix_in,
*/
template <typename math_t, typename LayoutPolicyIn>
void extractRows(raft::device_matrix_view<math_t, int, LayoutPolicyIn> matrix_in,
int** indptr_out,
int** indices_out,
math_t** data_out,
rmm::device_buffer& indptr_out,
rmm::device_buffer& indices_out,
rmm::device_buffer& data_out,
int* nnz,
const int* row_indices,
int num_indices,
Expand All @@ -734,8 +734,6 @@ void extractRows(raft::device_matrix_view<math_t, int, LayoutPolicyIn> matrix_in
* This is the specialized version for
* 'CSR -> CSR (raw pointers)'
*
* Warning: this specialization will allocate the the required arrays in device memory.
*
* @param [in] matrix_in matrix input in CSR [i, j]
* @param [out] indptr_out row index pointer of CSR output [num_indices + 1]
* @param [out] indices_out column indices of CSR output [nnz = indptr_out[num_indices + 1]]
Expand All @@ -747,9 +745,9 @@ void extractRows(raft::device_matrix_view<math_t, int, LayoutPolicyIn> matrix_in
*/
template <typename math_t>
void extractRows(raft::device_csr_matrix_view<math_t, int, int, int> matrix_in,
int** indptr_out,
int** indices_out,
math_t** data_out,
rmm::device_buffer& indptr_out,
rmm::device_buffer& indices_out,
rmm::device_buffer& data_out,
int* nnz,
const int* row_indices,
int num_indices,
Expand All @@ -762,20 +760,26 @@ void extractRows(raft::device_csr_matrix_view<math_t, int, int, int> matrix_in,
math_t* data_in = matrix_in.get_elements().data();

// allocate indptr
auto* rmm_alloc = rmm::mr::get_current_device_resource();
*indptr_out = (int*)rmm_alloc->allocate((num_indices + 1) * sizeof(int), stream);
indptr_out.resize((num_indices + 1) * sizeof(int), stream);

*nnz = computeIndptrForSubset(indptr_in, *indptr_out, row_indices, num_indices, stream);
*nnz =
computeIndptrForSubset(indptr_in, (int*)indptr_out.data(), row_indices, num_indices, stream);

// allocate indices, data
*indices_out = (int*)rmm_alloc->allocate(*nnz * sizeof(int), stream);
*data_out = (math_t*)rmm_alloc->allocate(*nnz * sizeof(math_t), stream);
indices_out.resize(*nnz * sizeof(int), stream);
data_out.resize(*nnz * sizeof(math_t), stream);

// copy with 1 warp per row for now, blocksize 256
const dim3 bs(32, 8, 1);
const dim3 gs(raft::ceildiv(num_indices, (int)bs.y), 1, 1);
extractCSRRowsFromCSR<math_t><<<gs, bs, 0, stream>>>(
*indptr_out, *indices_out, *data_out, indptr_in, indices_in, data_in, row_indices, num_indices);
extractCSRRowsFromCSR<math_t><<<gs, bs, 0, stream>>>((int*)indptr_out.data(),
(int*)indices_out.data(),
(math_t*)data_out.data(),
indptr_in,
indices_in,
data_in,
row_indices,
num_indices);

RAFT_CUDA_TRY(cudaPeekAtLastError());
}
Expand Down
8 changes: 2 additions & 6 deletions cpp/src/svm/svc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ SVC<math_t>::SVC(raft::handle_t& handle,
param(SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}),
kernel_params(kernel_params)
{
model.n_support = 0;
model.dual_coefs = nullptr;
model.support_matrix = {};
model.support_idx = nullptr;
model.unique_labels = nullptr;
svmFreeBuffers(handle, model);
}

template <typename math_t>
Expand All @@ -162,7 +158,7 @@ void SVC<math_t>::fit(
math_t* input, int n_rows, int n_cols, math_t* labels, const math_t* sample_weight)
{
model.n_cols = n_cols;
if (model.dual_coefs) svmFreeBuffers(handle, model);
svmFreeBuffers(handle, model);
svcFit(handle, input, n_rows, n_cols, labels, param, kernel_params, model, sample_weight);
}

Expand Down
Loading
Loading