Skip to content

Commit

Permalink
enhance: knowhere support data view index node
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Jan 14, 2025
1 parent 2a2f5ef commit dbbc39e
Show file tree
Hide file tree
Showing 21 changed files with 2,017 additions and 104 deletions.
3 changes: 2 additions & 1 deletion include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ constexpr const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC";
constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
constexpr const char* INDEX_FAISS_SCANN = "SCANN";
constexpr const char* INDEX_FAISS_SCANN_DVR = "SCANN_DVR";
constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";
constexpr const char* INDEX_FAISS_IVFSQ_CC = "IVF_SQ_CC";

Expand Down Expand Up @@ -119,7 +120,7 @@ constexpr const char* WITH_RAW_DATA = "with_raw_data";
constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full";
constexpr const char* CODE_SIZE = "code_size";
constexpr const char* RAW_DATA_STORE_PREFIX = "raw_data_store_prefix";

constexpr const char* SUB_DIM = "sub_dim";
// RAFT Params
constexpr const char* REFINE_RATIO = "refine_ratio";
constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device";
Expand Down
92 changes: 92 additions & 0 deletions include/knowhere/comp/rw_lock.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#ifndef KNOWHERE_RW_LOCK_H
#define KNOWHERE_RW_LOCK_H
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>
/*
FairRWLock is a fair MultiRead-SingleWrite lock
*/
namespace knowhere {
class FairRWLock {
public:
void
LockRead() {
std::unique_lock lock(mtx_);
auto cur_id = task_id_counter_++;
task_id_q.push(cur_id);
task_cv_.wait(lock, [&] { return task_id_q.front() == cur_id && !have_writer_task_; });
reader_task_counter_++;
task_id_q.pop();
}
void
UnLockRead() {
std::unique_lock lock(mtx_);
if (--reader_task_counter_ == 0) {
task_cv_.notify_all();
}
}
void
LockWrite() {
std::unique_lock lock(mtx_);
auto cur_id = task_id_counter_++;
task_id_q.push(cur_id);
task_cv_.wait(lock,
[&] { return task_id_q.front() == cur_id && !have_writer_task_ && reader_task_counter_ == 0; });
have_writer_task_ = true;
task_id_q.pop();
}
void
UnLockWrite() {
std::unique_lock lock(mtx_);
have_writer_task_ = false;
task_cv_.notify_all();
}

private:
uint64_t task_id_counter_ = 0;
uint64_t reader_task_counter_ = 0;
bool have_writer_task_ = false;
std::mutex mtx_;
std::condition_variable task_cv_;
std::queue<uint64_t> task_id_q;
};

class FairReadLockGuard {
public:
explicit FairReadLockGuard(FairRWLock& lock) : lock_(lock) {
lock_.LockRead();
}

~FairReadLockGuard() {
lock_.UnLockRead();
}

private:
FairRWLock& lock_;
};

class FairWriteLockGuard {
public:
explicit FairWriteLockGuard(FairRWLock& lock) : lock_(lock) {
lock_.LockWrite();
}

~FairWriteLockGuard() {
lock_.UnLockWrite();
}

private:
FairRWLock& lock_;
};
} // namespace knowhere
#endif
26 changes: 14 additions & 12 deletions include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,8 @@ class IndexIterator : public IndexNode::iterator {
: refine_ratio_(refine_ratio),
refine_(refine_ratio != 0.0f),
retain_iterator_order_(retain_iterator_order),
use_knowhere_search_pool_(use_knowhere_search_pool),
sign_(larger_is_closer ? -1 : 1) {
sign_(larger_is_closer ? -1 : 1),
use_knowhere_search_pool_(use_knowhere_search_pool) {
}

std::pair<int64_t, float>
Expand Down Expand Up @@ -524,8 +524,16 @@ class IndexIterator : public IndexNode::iterator {
}

protected:
inline size_t
min_refine_size() const {
// TODO: maybe make this configurable
return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_));
}

virtual void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) = 0;
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) {
throw std::runtime_error("next_batch not implemented");
}
// will be called only if refine_ratio_ is not 0.
virtual float
raw_distance(int64_t) {
Expand All @@ -537,18 +545,15 @@ class IndexIterator : public IndexNode::iterator {

const float refine_ratio_;
const bool refine_;
bool initialized_ = false;
bool retain_iterator_order_ = false;
const int64_t sign_;

std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> res_;
// unused if refine_ is false
std::priority_queue<DistId, std::vector<DistId>, std::greater<DistId>> refined_res_;

private:
inline size_t
min_refine_size() const {
// TODO: maybe make this configurable
return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_));
}

void
UpdateNext() {
auto batch_handler = [this](const std::vector<DistId>& batch) {
Expand All @@ -569,10 +574,7 @@ class IndexIterator : public IndexNode::iterator {
next_batch(batch_handler);
}

bool initialized_ = false;
bool retain_iterator_order_ = false;
bool use_knowhere_search_pool_ = true;
const int64_t sign_;
};

// An iterator implementation that accepts a function to get distances and ids list and returns them in order.
Expand Down
10 changes: 8 additions & 2 deletions include/knowhere/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <atomic>
#include <cassert>
#include <functional>
#include <iostream>
#include <memory>

Expand Down Expand Up @@ -73,10 +74,15 @@ class Object {
mutable std::atomic_uint32_t ref_counts_ = 1;
};

using ViewDataOp = std::function<const void*(size_t)>;

template <typename T>
class Pack : public Object {
static_assert(std::is_same_v<T, std::shared_ptr<knowhere::FileManager>>,
"IndexPack only support std::shared_ptr<knowhere::FileManager> by far.");
// Currently, DataViewIndex and DiskIndex are mutually exclusive, they can share one object.
// todo: pack can hold more object
static_assert(std::is_same_v<T, std::shared_ptr<knowhere::FileManager>> || std::is_same_v<T, knowhere::ViewDataOp>,
"IndexPack only support std::shared_ptr<knowhere::FileManager> or ViewDataOp == std::function<const "
"void*(size_t)> by far.");

public:
Pack() {
Expand Down
31 changes: 31 additions & 0 deletions include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,36 @@ template <>
struct MockData<knowhere::int8> {
using type = knowhere::fp32;
};

//
enum class DataFormatEnum { fp32, fp16, bf16, int8, bin1 };

template <typename T>
struct DataType2EnumHelper {};

template <>
struct DataType2EnumHelper<knowhere::fp32> {
static constexpr DataFormatEnum value = DataFormatEnum::fp32;
};
template <>
struct DataType2EnumHelper<knowhere::fp16> {
static constexpr DataFormatEnum value = DataFormatEnum::fp16;
};
template <>
struct DataType2EnumHelper<knowhere::bf16> {
static constexpr DataFormatEnum value = DataFormatEnum::bf16;
};
template <>
struct DataType2EnumHelper<knowhere::int8> {
static constexpr DataFormatEnum value = DataFormatEnum::int8;
};
template <>
struct DataType2EnumHelper<knowhere::bin1> {
static constexpr DataFormatEnum value = DataFormatEnum::bin1;
};

template <typename T>
static constexpr DataFormatEnum datatype_v = DataType2EnumHelper<T>::value;

} // namespace knowhere
#endif /* OPERANDS_H */
44 changes: 32 additions & 12 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ IsFlatIndex(const knowhere::IndexType& index_type) {
return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end();
}

template <typename DataType>
float
GetL2Norm(const DataType* x, int32_t d);

template <typename DataType>
std::vector<float>
GetL2Norms(const DataType* x, int32_t d, int32_t n);

template <typename DataType>
extern float
NormalizeVec(DataType* x, int32_t d);
Expand All @@ -52,6 +60,10 @@ template <typename DataType>
extern void
NormalizeDataset(const DataSetPtr dataset);

template <typename DataType>
extern std::tuple<DataSetPtr, std::vector<float>>
CopyAndNormalizeDataset(const DataSetPtr dataset);

constexpr inline uint64_t seed = 0xc70f6907UL;

inline uint64_t
Expand Down Expand Up @@ -112,8 +124,10 @@ GetKey(const std::string& name) {
template <typename InType, typename OutType>
inline DataSetPtr
data_type_conversion(const DataSet& src, const std::optional<int64_t> start = std::nullopt,
const std::optional<int64_t> count = std::nullopt) {
auto dim = src.GetDim();
const std::optional<int64_t> count = std::nullopt,
const std::optional<int64_t> count_dim = std::nullopt) {
auto in_dim = src.GetDim();
auto out_dim = count_dim.value_or(in_dim);
auto rows = src.GetRows();

// check the acceptable range
Expand All @@ -128,17 +142,21 @@ data_type_conversion(const DataSet& src, const std::optional<int64_t> start = st
}

// map
auto* des_data = new OutType[dim * count_rows];
auto* des_data = new OutType[out_dim * count_rows];
std::memset(des_data, 0, sizeof(OutType) * out_dim * count_rows);
auto* src_data = (const InType*)src.GetTensor();
for (auto i = 0; i < dim * count_rows; i++) {
des_data[i] = (OutType)src_data[i + start_row * dim];
for (auto i = 0; i < count_rows; i++) {
for (auto d = 0; d < in_dim; d++) {
des_data[i * out_dim + d] = (OutType)src_data[(start_row + i) * in_dim + d];
}
}

auto des = std::make_shared<DataSet>();
des->SetRows(count_rows);
des->SetDim(dim);
des->SetDim(out_dim);
des->SetTensor(des_data);
des->SetIsOwner(true);
des->SetTensorBeginId(src.GetTensorBeginId() + start_row);
return des;
}

Expand All @@ -152,28 +170,30 @@ data_type_conversion(const DataSet& src, const std::optional<int64_t> start = st
template <typename DataType>
inline DataSetPtr
ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional<int64_t> start = std::nullopt,
const std::optional<int64_t> count = std::nullopt) {
const std::optional<int64_t> count = std::nullopt,
const std::optional<int64_t> count_dim = std::nullopt) {
if constexpr (std::is_same_v<DataType, typename MockData<DataType>::type>) {
if (!start.has_value() && !count.has_value()) {
if (!start.has_value() && !count.has_value() && (!count_dim.has_value() || ds->GetDim() == count_dim.value())) {
return ds;
}
}

return data_type_conversion<DataType, typename MockData<DataType>::type>(*ds, start, count);
return data_type_conversion<DataType, typename MockData<DataType>::type>(*ds, start, count, count_dim);
}

// Convert DataSet from float to DataType
template <typename DataType>
inline DataSetPtr
ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional<int64_t> start = std::nullopt,
const std::optional<int64_t> count = std::nullopt) {
const std::optional<int64_t> count = std::nullopt,
const std::optional<int64_t> count_dim = std::nullopt) {
if constexpr (std::is_same_v<DataType, typename MockData<DataType>::type>) {
if (!start.has_value() && !count.has_value()) {
if (!start.has_value() && !count.has_value() && (!count_dim.has_value() || ds->GetDim() == count_dim.value())) {
return ds;
}
}

return data_type_conversion<typename MockData<DataType>::type, DataType>(*ds, start, count);
return data_type_conversion<typename MockData<DataType>::type, DataType>(*ds, start, count, count_dim);
}

template <typename T>
Expand Down
Loading

0 comments on commit dbbc39e

Please sign in to comment.