Skip to content

Commit

Permalink
Fix a crash in FAISS benchmark wrapper introduced in #2021 (#2062)
Browse files Browse the repository at this point in the history
With the changes introduced by #2021, the copied FAISS benchmark wrapper contains a cuda event that is used for synchronizing between streams during search. The lifetime of the event is the same as of the wrapper, but the event handle itself is copied between the wrappers; this leads to illegal memory accesses and crashes.
This PR fixes the bug by creating a new cuda event on each wrapper copy, so that the wrappers do not share their synchronization events.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2062
  • Loading branch information
achirkin authored Dec 13, 2023
1 parent d9a7290 commit 6c95f9c
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ class OmpSingleThreadScope {

namespace raft::bench::ann {

struct copyable_event {
copyable_event() { RAFT_CUDA_TRY(cudaEventCreate(&value_, cudaEventDisableTiming)); }
~copyable_event() { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(value_)); }
copyable_event(copyable_event&&) = default;
copyable_event& operator=(copyable_event&&) = default;
copyable_event(const copyable_event& res) : copyable_event{} {}
copyable_event& operator=(const copyable_event& other) = delete;
operator cudaEvent_t() const noexcept { return value_; }

private:
cudaEvent_t value_{nullptr};
};

template <typename T>
class FaissGpu : public ANN<T> {
public:
Expand All @@ -97,18 +110,15 @@ class FaissGpu : public ANN<T> {

FaissGpu(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
gpu_resource_{std::make_shared<faiss::gpu::StandardGpuResources>()},
metric_type_(parse_metric_type(metric)),
nlist_{param.nlist},
training_sample_fraction_{1.0 / double(param.ratio)}
{
static_assert(std::is_same_v<T, float>, "faiss support only float type");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
faiss_default_stream_ = gpu_resource_.getDefaultStream(device_);
}

virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); }

void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final;

virtual void set_search_param(const FaissGpu<T>::AnnSearchParam& param) {}
Expand Down Expand Up @@ -142,7 +152,7 @@ class FaissGpu : public ANN<T> {

void stream_wait(cudaStream_t stream) const
{
RAFT_CUDA_TRY(cudaEventRecord(sync_, faiss_default_stream_));
RAFT_CUDA_TRY(cudaEventRecord(sync_, gpu_resource_->getDefaultStream(device_)));
RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
}

Expand All @@ -162,14 +172,13 @@ class FaissGpu : public ANN<T> {
* faiss::gpu::StandardGpuResources are thread-safe.
*
*/
mutable faiss::gpu::StandardGpuResources gpu_resource_;
mutable std::shared_ptr<faiss::gpu::StandardGpuResources> gpu_resource_;
std::shared_ptr<faiss::gpu::GpuIndex> index_;
std::shared_ptr<faiss::IndexRefineFlat> index_refine_{nullptr};
faiss::MetricType metric_type_;
int nlist_;
int device_;
cudaEvent_t sync_{nullptr};
cudaStream_t faiss_default_stream_{nullptr};
copyable_event sync_{};
double training_sample_fraction_;
std::shared_ptr<faiss::SearchParameters> search_params_;
const T* dataset_;
Expand Down Expand Up @@ -278,7 +287,7 @@ class FaissGpuIVFFlat : public FaissGpu<T> {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>(
&(this->gpu_resource_), dim, param.nlist, this->metric_type_, config);
this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config);
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
Expand Down Expand Up @@ -321,7 +330,7 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
config.device = this->device_;

this->index_ =
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(&(this->gpu_resource_),
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(this->gpu_resource_.get(),
dim,
param.nlist,
param.M,
Expand Down Expand Up @@ -383,7 +392,7 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFScalarQuantizer>(
&(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config);
this->gpu_resource_.get(), dim, param.nlist, qtype, this->metric_type_, true, config);
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
Expand Down Expand Up @@ -426,7 +435,7 @@ class FaissGpuFlat : public FaissGpu<T> {
faiss::gpu::GpuIndexFlatConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexFlat>(
&(this->gpu_resource_), dim, this->metric_type_, config);
this->gpu_resource_.get(), dim, this->metric_type_, config);
}
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
Expand Down

0 comments on commit 6c95f9c

Please sign in to comment.