Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a crash in FAISS benchmark wrapper introduced in #2021 (#2062)

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ class OmpSingleThreadScope {

namespace raft::bench::ann {

struct copyable_event {
copyable_event() { RAFT_CUDA_TRY(cudaEventCreate(&value_, cudaEventDisableTiming)); }
~copyable_event() { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(value_)); }
copyable_event(copyable_event&&) = default;
copyable_event& operator=(copyable_event&&) = default;
copyable_event(const copyable_event& res) : copyable_event{} {}
copyable_event& operator=(const copyable_event& other) = delete;
operator cudaEvent_t() const noexcept { return value_; }

private:
cudaEvent_t value_{nullptr};
};

template <typename T>
class FaissGpu : public ANN<T> {
public:
Expand All @@ -97,18 +110,15 @@ class FaissGpu : public ANN<T> {

FaissGpu(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
gpu_resource_{std::make_shared<faiss::gpu::StandardGpuResources>()},
metric_type_(parse_metric_type(metric)),
nlist_{param.nlist},
training_sample_fraction_{1.0 / double(param.ratio)}
{
static_assert(std::is_same_v<T, float>, "faiss support only float type");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
faiss_default_stream_ = gpu_resource_.getDefaultStream(device_);
}

virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); }

void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final;

virtual void set_search_param(const FaissGpu<T>::AnnSearchParam& param) {}
Expand Down Expand Up @@ -142,7 +152,7 @@ class FaissGpu : public ANN<T> {

void stream_wait(cudaStream_t stream) const
{
RAFT_CUDA_TRY(cudaEventRecord(sync_, faiss_default_stream_));
RAFT_CUDA_TRY(cudaEventRecord(sync_, gpu_resource_->getDefaultStream(device_)));
RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
}

Expand All @@ -162,14 +172,13 @@ class FaissGpu : public ANN<T> {
* faiss::gpu::StandardGpuResources are thread-safe.
*
*/
mutable faiss::gpu::StandardGpuResources gpu_resource_;
mutable std::shared_ptr<faiss::gpu::StandardGpuResources> gpu_resource_;
std::shared_ptr<faiss::gpu::GpuIndex> index_;
std::shared_ptr<faiss::IndexRefineFlat> index_refine_{nullptr};
faiss::MetricType metric_type_;
int nlist_;
int device_;
cudaEvent_t sync_{nullptr};
cudaStream_t faiss_default_stream_{nullptr};
copyable_event sync_{};
double training_sample_fraction_;
std::shared_ptr<faiss::SearchParameters> search_params_;
const T* dataset_;
Expand Down Expand Up @@ -278,7 +287,7 @@ class FaissGpuIVFFlat : public FaissGpu<T> {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>(
&(this->gpu_resource_), dim, param.nlist, this->metric_type_, config);
this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config);
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
Expand Down Expand Up @@ -321,7 +330,7 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
config.device = this->device_;

this->index_ =
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(&(this->gpu_resource_),
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(this->gpu_resource_.get(),
dim,
param.nlist,
param.M,
Expand Down Expand Up @@ -383,7 +392,7 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFScalarQuantizer>(
&(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config);
this->gpu_resource_.get(), dim, param.nlist, qtype, this->metric_type_, true, config);
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
Expand Down Expand Up @@ -426,7 +435,7 @@ class FaissGpuFlat : public FaissGpu<T> {
faiss::gpu::GpuIndexFlatConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexFlat>(
&(this->gpu_resource_), dim, this->metric_type_, config);
this->gpu_resource_.get(), dim, this->metric_type_, config);
}
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
Expand Down