
Add mdspan API to raft IVF functions (#810)
* Adding mdspan calls to ivf raft

Signed-off-by: Mickael Ide <[email protected]>

* Add pylibraft

Signed-off-by: Mickael Ide <[email protected]>

* Remove pylibraft

Signed-off-by: Mickael Ide <[email protected]>

* Fix style

Signed-off-by: Mickael Ide <[email protected]>

---------

Signed-off-by: Mickael Ide <[email protected]>
lowener authored May 5, 2023
1 parent 9727030 commit 743bb3e
Showing 5 changed files with 83 additions and 61 deletions.
10 changes: 9 additions & 1 deletion CMakeLists.txt
@@ -147,7 +147,15 @@ list(APPEND KNOWHERE_LINKER_LIBS prometheus-cpp::core prometheus-cpp::push)
add_library(knowhere SHARED ${KNOWHERE_SRCS})
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
if(WITH_RAFT)
- list(APPEND KNOWHERE_LINKER_LIBS raft::raft)
+ list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled)
+ find_library(LIBRAFT_FOUND raft)
+ if (NOT LIBRAFT_FOUND)
+ message(WARNING "libraft not found")
+ else()
+ message(STATUS "libraft found")
+ list(APPEND KNOWHERE_LINKER_LIBS ${LIBRAFT_FOUND})
+ add_definitions(-DRAFT_COMPILED)
+ endif()
endif()
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
target_include_directories(
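Note: the net effect of this hunk is that, when a prebuilt libraft is found, knowhere links against it and defines RAFT_COMPILED, which the CUDA sources use to pull in RAFT's precompiled template specializations instead of re-instantiating them locally (the ivf_raft.cuh hunk below adds exactly this guard). A minimal consumer-side sketch; the include is the one added by this commit, the surrounding comments are illustrative:

```cpp
// Sketch: how a translation unit reacts to the RAFT_COMPILED definition added
// by the CMake hunk above. With the macro set, this header declares RAFT's
// pre-instantiated neighbor kernels, so they are resolved from libraft at link
// time rather than rebuilt from templates in every translation unit.
#ifdef RAFT_COMPILED
#include <raft/neighbors/specializations.cuh>
#endif
```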
2 changes: 1 addition & 1 deletion README.md
@@ -48,7 +48,7 @@ $ conan install .. --build=missing -o with_ut=True -o with_raft=True -s compiler
#DISKANN SUPPORT
$ conan install .. --build=missing -o with_ut=True -o with_diskann=True -s compiler.libcxx=libstdc++11 -s build_type=Debug/Release
#build with conan
- $conan build ..
+ $ conan build ..
#verbose
export VERBOSE=1
```
26 changes: 15 additions & 11 deletions cmake/libs/libraft.cmake
@@ -32,10 +32,14 @@ set(RAFT_FORK "rapidsai")
set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}")

function(find_and_configure_raft)
- set(oneValueArgs VERSION FORK PINNED_TAG)
+ set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}"
${ARGN})

+ set(RAFT_COMPONENTS "")
+ if(PKG_COMPILE_LIBRARY)
+ string(APPEND RAFT_COMPONENTS " compiled")
+ endif()
# -----------------------------------------------------
# Invoke CPM find_package()
# -----------------------------------------------------
@@ -44,12 +48,8 @@ function(find_and_configure_raft)
${PKG_VERSION}
GLOBAL_TARGETS
raft::raft
- BUILD_EXPORT_SET
- faiss-exports
- INSTALL_EXPORT_SET
- faiss-exports
COMPONENTS
- "distance nn"
+ ${RAFT_COMPONENTS}
CPM_ARGS
GIT_REPOSITORY
https://github.com/${PKG_FORK}/raft.git
@@ -60,13 +60,17 @@ function(find_and_configure_raft)
OPTIONS
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
- "RAFT_COMPILE_LIBRARIES OFF"
- "RAFT_COMPILE_NN_LIBRARY OFF"
- "RAFT_USE_FAISS_STATIC OFF" # Turn this on to build FAISS into your binary
- "RAFT_ENABLE_NN_DEPENDENCIES OFF")
+ "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}"
+ "RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary

+ if(raft_ADDED)
+ message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_SOURCE_DIR}")
+ else()
+ message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_DIR}")
+ endif()
endfunction()

# Change pinned tag here to test a commit in CI To use a different RAFT locally,
# set the CMake variable CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00 FORK ${RAFT_FORK} PINNED_TAG
- ${RAFT_PINNED_TAG})
+ ${RAFT_PINNED_TAG} COMPILE_LIBRARY OFF)
8 changes: 4 additions & 4 deletions cmake/utils/fetch_rapids.cmake
@@ -13,12 +13,12 @@
# License for the specific language governing permissions and limitations under
# the License.

- set(RAPIDS_VERSION "23.02")
+ set(RAPIDS_VERSION "23.04")

- if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
+ if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
file(
DOWNLOAD
https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake
- ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
+ ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
endif()
- include(${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
+ include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
98 changes: 54 additions & 44 deletions src/index/ivf_raft/ivf_raft.cuh
@@ -38,6 +38,10 @@
#include "thrust/execution_policy.h"
#include "thrust/sequence.h"

+ #ifdef RAFT_COMPILED
+ #include <raft/neighbors/specializations.cuh>
+ #endif

namespace knowhere {

__global__ void
@@ -303,18 +307,18 @@ class RaftIvfIndexNode : public IndexNode {
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

auto stream = res_->get_stream();
- auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
- RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
-     stream.value()));
+ auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
+ RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
+     cudaMemcpyDefault, stream.value()));
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto build_params = raft::neighbors::ivf_flat::index_params{};
build_params.metric = metric.value();
build_params.n_lists = ivf_raft_cfg.nlist;
build_params.kmeans_n_iters = ivf_raft_cfg.kmeans_n_iters;
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction;
build_params.adaptive_centers = ivf_raft_cfg.adaptive_centers;
- gpu_index_ = raft::neighbors::ivf_flat::build<float, std::int64_t>(*res_, build_params,
-     data_gpu.data(), rows, dim);
+ gpu_index_ =
+     raft::neighbors::ivf_flat::build<float, std::int64_t>(*res_, build_params, data_gpu.view());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto build_params = raft::neighbors::ivf_pq::index_params{};
build_params.metric = metric.value();
@@ -330,8 +334,8 @@ class RaftIvfIndexNode : public IndexNode {
}
build_params.codebook_kind = codebook_kind.value();
build_params.force_random_rotation = ivf_raft_cfg.force_random_rotation;
- gpu_index_ = raft::neighbors::ivf_pq::build<float, std::int64_t>(*res_, build_params,
-     data_gpu.data(), rows, dim);
+ gpu_index_ =
+     raft::neighbors::ivf_pq::build<float, std::int64_t>(*res_, build_params, data_gpu.view());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
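Note: both build hunks above follow one pattern, shown in isolation below. A rows-by-dim device mdarray replaces the bare rmm::device_uvector, and its view() bundles the pointer with its extents, so build() no longer takes (data, rows, dim) as separate arguments. A minimal sketch against the RAFT 23.04-era API, not knowhere code; `res`, `host_data`, and the helper name are illustrative:

```cpp
#include <cstdint>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/util/cudart_utils.hpp>  // RAFT_CUDA_TRY

// Sketch: build an IVF-Flat index through the mdspan overload.
auto build_ivf_flat(raft::device_resources& res, const float* host_data,
                    std::int64_t rows, std::int64_t dim) {
    // Owning device matrix; view() carries (rows, dim) alongside the pointer.
    auto data_gpu = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
    RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), host_data,
                                  data_gpu.size() * sizeof(float),
                                  cudaMemcpyDefault, res.get_stream().value()));

    auto params = raft::neighbors::ivf_flat::index_params{};  // defaults; tune n_lists etc.
    return raft::neighbors::ivf_flat::build<float, std::int64_t>(res, params,
                                                                 data_gpu.view());
}
```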
@@ -366,19 +370,25 @@ class RaftIvfIndexNode : public IndexNode {
auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
- auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
- RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
-     stream.value()));
+ auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
+ RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
+     cudaMemcpyDefault, stream.value()));

auto indices = rmm::device_uvector<std::int64_t>(rows, stream);
thrust::sequence(thrust::device, indices.begin(), indices.end(), gpu_index_->size());

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
- raft::neighbors::ivf_flat::extend<float, std::int64_t>(*res_, *gpu_index_, data_gpu.data(),
-     indices.data(), rows);
+ raft::neighbors::ivf_flat::extend<float, std::int64_t>(
+     *res_, raft::make_const_mdspan(data_gpu.view()),
+     std::make_optional(
+         raft::make_device_vector_view<const std::int64_t, std::int64_t>(indices.data(), rows)),
+     gpu_index_.value());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
- raft::neighbors::ivf_pq::extend<float, std::int64_t>(*res_, *gpu_index_, data_gpu.data(),
-     indices.data(), rows);
+ raft::neighbors::ivf_pq::extend<float, std::int64_t>(
+     *res_, raft::make_const_mdspan(data_gpu.view()),
+     std::make_optional(
+         raft::make_device_matrix_view<const std::int64_t, std::int64_t>(indices.data(), rows, 1)),
+     gpu_index_.value());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
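Note: extend() changes shape as well as argument order here: new vectors arrive as a const matrix view, the optional new IDs as a device vector view (a rows-by-1 matrix view in the PQ branch), and the index being extended moves to the last position. A minimal sketch of the flat variant mirroring the call above; the exact overload set varies across RAFT releases, so treat the signature as an assumption:

```cpp
#include <cstdint>
#include <optional>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#include <rmm/device_uvector.hpp>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>

// Sketch: append device-resident vectors to an IVF-Flat index with
// sequential IDs that continue from the index's current size.
void append_to_index(raft::device_resources& res,
                     raft::neighbors::ivf_flat::index<float, std::int64_t>& index,
                     raft::device_matrix_view<const float, std::int64_t> new_vecs,
                     rmm::device_uvector<std::int64_t>& ids) {
    thrust::sequence(thrust::device, ids.begin(), ids.end(), index.size());
    raft::neighbors::ivf_flat::extend<float, std::int64_t>(
        res, new_vecs,
        std::make_optional(raft::make_device_vector_view<const std::int64_t, std::int64_t>(
            ids.data(), new_vecs.extent(0))),
        index);  // as in the hunk above; some releases instead return the extended index
}
```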
@@ -410,20 +420,20 @@ class RaftIvfIndexNode : public IndexNode {
auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
- auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
- RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
-     stream.value()));
+ auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
+ RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
+     cudaMemcpyDefault, stream.value()));

- auto ids_gpu = rmm::device_uvector<std::int64_t>(output_size, stream);
- auto dis_gpu = rmm::device_uvector<float>(output_size, stream);
+ auto ids_gpu = raft::make_device_matrix<std::int64_t, std::int64_t>(*res_, rows, ivf_raft_cfg.k);
+ auto dis_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, ivf_raft_cfg.k);

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto search_params = raft::neighbors::ivf_flat::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
if (bitset.empty()) {
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
-     data_gpu.data(), rows, ivf_raft_cfg.k,
-     ids_gpu.data(), dis_gpu.data());
+     raft::make_const_mdspan(data_gpu.view()),
+     ids_gpu.view(), dis_gpu.view());
} else {
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
@@ -434,21 +444,21 @@ class RaftIvfIndexNode : public IndexNode {
k2 |= k2 >> 14;
k2 += 1;
while (k2 <= 1024) {
- auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
- auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
- auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
- RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
+ auto ids_gpu_before = raft::make_device_matrix<std::int64_t, std::int64_t>(*res_, rows, k2);
+ auto dis_gpu_before = raft::make_device_matrix<float, std::int64_t>(*res_, rows, k2);
+ auto bs_gpu = raft::make_device_vector<uint8_t, std::int64_t>(*res_, bitset.byte_size());
+ RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data_handle(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_flat::search<float, std::int64_t>(
-     *res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
-     dis_gpu_before.data());
+     *res_, search_params, *gpu_index_, raft::make_const_mdspan(data_gpu.view()),
+     ids_gpu_before.view(), dis_gpu_before.view());
filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
-     k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
-     dis_gpu.data());
+     k1, k2, rows, bs_gpu.data_handle(), ids_gpu_before.data_handle(),
+     dis_gpu_before.data_handle(), ids_gpu.data_handle(), dis_gpu.data_handle());

std::int64_t is_fine = 0;
- RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
+ RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data_handle(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
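Note: the k2 arithmetic above is the tail of the usual bit-smearing idiom (the earlier shifts fall in the collapsed part of this diff): it rounds the requested k up to the next power of two, and the loop retries the search with that many candidates, growing up to 1024, until enough results survive the bitset filter (the filter kernel writes the -1 sentinel checked via is_fine). A self-contained sketch of the rounding step, with the shift amounts assumed from the common form of the idiom:

```cpp
#include <cstdint>

// Smear the highest set bit into every lower position, then add one:
// 10 -> 15 -> 16, 16 -> 31 -> 32. Starting from k (not k - 1) yields the
// next power of two strictly greater than k, matching the retry loop above.
std::int64_t next_pow2_above(std::int64_t k) {
    k |= k >> 1;
    k |= k >> 2;
    k |= k >> 4;
    k |= k >> 8;  // the visible line in the diff shows a final `>> 14` step
    return k + 1;
}
```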
@@ -485,8 +495,8 @@ class RaftIvfIndexNode : public IndexNode {
search_params.preferred_shmem_carveout = search_params.preferred_shmem_carveout;
if (bitset.empty()) {
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
-     data_gpu.data(), rows, ivf_raft_cfg.k,
-     ids_gpu.data(), dis_gpu.data());
+     raft::make_const_mdspan(data_gpu.view()),
+     ids_gpu.view(), dis_gpu.view());
} else {
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
@@ -497,22 +507,22 @@ class RaftIvfIndexNode : public IndexNode {
k2 |= k2 >> 14;
k2 += 1;
while (k2 <= 1024) {
- auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
- auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
- auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
- RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
+ auto ids_gpu_before = raft::make_device_matrix<std::int64_t, std::int64_t>(*res_, rows, k2);
+ auto dis_gpu_before = raft::make_device_matrix<float, std::int64_t>(*res_, rows, k2);
+ auto bs_gpu = raft::make_device_vector<uint8_t, std::int64_t>(*res_, bitset.byte_size());
+ RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data_handle(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_pq::search<float, std::int64_t>(
-     *res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
-     dis_gpu_before.data());
+     *res_, search_params, *gpu_index_, raft::make_const_mdspan(data_gpu.view()),
+     ids_gpu_before.view(), dis_gpu_before.view());

filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
-     k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
-     dis_gpu.data());
+     k1, k2, rows, bs_gpu.data_handle(), ids_gpu_before.data_handle(),
+     dis_gpu_before.data_handle(), ids_gpu.data_handle(), dis_gpu.data_handle());

std::int64_t is_fine = 0;
- RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
+ RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data_handle(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
@@ -524,10 +534,10 @@ class RaftIvfIndexNode : public IndexNode {
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
- RAFT_CUDA_TRY(cudaMemcpyAsync(ids.get(), ids_gpu.data(), ids_gpu.size() * sizeof(std::int64_t),
+ RAFT_CUDA_TRY(cudaMemcpyAsync(ids.get(), ids_gpu.data_handle(), ids_gpu.size() * sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
- RAFT_CUDA_TRY(cudaMemcpyAsync(dis.get(), dis_gpu.data(), dis_gpu.size() * sizeof(float), cudaMemcpyDefault,
-     stream.value()));
+ RAFT_CUDA_TRY(cudaMemcpyAsync(dis.get(), dis_gpu.data_handle(), dis_gpu.size() * sizeof(float),
+     cudaMemcpyDefault, stream.value()));
stream.synchronize();
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
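Note: condensing the search-side hunks, the pattern is: queries, IDs, and distances become device matrices; search() reads k from the views' extents instead of an explicit argument; and every copy-back uses data_handle() (the mdarray owns the memory, the mdspan views it) in place of the old uvector data(). A minimal end-to-end sketch of the unfiltered path against the RAFT 23.04-era API; names and the free-function shape are illustrative:

```cpp
#include <cstdint>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/util/cudart_utils.hpp>  // RAFT_CUDA_TRY

// Sketch: run an IVF-Flat search and copy results back to host buffers.
void search_ivf_flat(raft::device_resources& res,
                     const raft::neighbors::ivf_flat::index<float, std::int64_t>& index,
                     const float* host_queries, std::int64_t rows, std::int64_t dim,
                     std::int64_t k, std::int64_t* host_ids, float* host_dis) {
    auto stream = res.get_stream();
    auto queries = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
    RAFT_CUDA_TRY(cudaMemcpyAsync(queries.data_handle(), host_queries,
                                  queries.size() * sizeof(float),
                                  cudaMemcpyDefault, stream.value()));

    // rows x k output matrices; their extents replace the old explicit k argument.
    auto ids = raft::make_device_matrix<std::int64_t, std::int64_t>(res, rows, k);
    auto dis = raft::make_device_matrix<float, std::int64_t>(res, rows, k);

    auto params = raft::neighbors::ivf_flat::search_params{};  // defaults; tune n_probes
    raft::neighbors::ivf_flat::search<float, std::int64_t>(
        res, params, index, raft::make_const_mdspan(queries.view()),
        ids.view(), dis.view());

    RAFT_CUDA_TRY(cudaMemcpyAsync(host_ids, ids.data_handle(),
                                  ids.size() * sizeof(std::int64_t),
                                  cudaMemcpyDefault, stream.value()));
    RAFT_CUDA_TRY(cudaMemcpyAsync(host_dis, dis.data_handle(),
                                  dis.size() * sizeof(float),
                                  cudaMemcpyDefault, stream.value()));
    stream.synchronize();
}
```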
