diff --git a/src/models/llm/rag/wiki_index_builder.cpp b/src/models/llm/rag/wiki_index_builder.cpp index 0da3294..da95399 100644 --- a/src/models/llm/rag/wiki_index_builder.cpp +++ b/src/models/llm/rag/wiki_index_builder.cpp @@ -121,28 +121,28 @@ class WikiIndexBuilder::Impl { /*** * - * @param index_f_path + * @param index_file_dir directory searched recursively for *.index shard files * @return */ - StatusCode load_index(const std::string& index_f_path); + StatusCode load_index(const std::string& index_file_dir); /*** * - * @param corpus_segment_path + * @param corpus_segment_dir directory searched recursively for *.jsonl corpus shard files * @return */ - StatusCode load_corpus_segment(const std::string& corpus_segment_path); + StatusCode load_corpus_segment(const std::string& corpus_segment_dir); /*** * * @param input_prompt - * @param referenced_corpus + * @param out_referenced_corpus output string; the formatted text of the retrieved documents is appended to it * @param top_k * @param apply_chat_template * @return */ StatusCode search( - const std::string& input_prompt, std::string& referenced_corpus, int top_k=1, bool apply_chat_template=true); + const std::string& input_prompt, std::string& out_referenced_corpus, int top_k=1, bool apply_chat_template=true); /*** * @@ -455,55 +455,97 @@ StatusCode WikiIndexBuilder::Impl::build_index(const std::string &source_wiki_co /*** * - * @param index_f_path + * @param index_file_dir directory searched recursively for *.index shard files * @return */ -StatusCode WikiIndexBuilder::Impl::load_index(const std::string &index_f_path) { - if (!FilePathUtil::is_file_exist(index_f_path)) { - LOG(ERROR) << fmt::format("index file: {} not exist", index_f_path); +StatusCode WikiIndexBuilder::Impl::load_index(const std::string &index_file_dir) { + // gather all *.index files (recursive glob) + std::vector index_file_paths; + cv::glob(fmt::format("{}/*.index", index_file_dir), index_file_paths, true); + if (index_file_paths.empty()) { + LOG(ERROR) << fmt::format("no *.index index file found in {}", index_file_dir); return StatusCode::RAG_LOAD_INDEX_FAILED; } - auto* index = dynamic_cast(faiss::read_index(index_f_path.c_str(), 0)); - if (index == nullptr) { - LOG(ERROR) << fmt::format("load index 
file: {} failed", index_f_path); - return StatusCode::RAG_LOAD_INDEX_FAILED; + // sort index shards by the numeric id that follows the last '_' in the file name + std::sort(index_file_paths.begin(), index_file_paths.end(), [](const std::string& path_a, const std::string& path_b) -> bool { + auto name_a = FilePathUtil::get_file_name(path_a); + auto prefix_id_a = name_a.substr(0, name_a.find(".index")); + prefix_id_a = prefix_id_a.substr(prefix_id_a.rfind('_') + 1); + auto name_b = FilePathUtil::get_file_name(path_b); + auto prefix_id_b = name_b.substr(0, name_b.find(".index")); + prefix_id_b = prefix_id_b.substr(prefix_id_b.rfind('_') + 1); + return std::stoi(prefix_id_a) < std::stoi(prefix_id_b); + }); + + // load every shard in sorted order and merge its vectors into one flat buffer + std::vector merged_vectors; + int vec_dims = 0; + for (auto& f_path : index_file_paths) { + auto* index = dynamic_cast(faiss::read_index(f_path.c_str(), 0)); + if (index == nullptr) { + LOG(ERROR) << fmt::format("read index file failed: {}", f_path); + return StatusCode::RAG_LOAD_INDEX_FAILED; + } + const float* raw_data = index->get_xb(); + vec_dims = index->d; + // copy the shard's vectors in one go, then free the shard (read_index hands ownership to the caller) + merged_vectors.insert(merged_vectors.end(), raw_data, raw_data + index->ntotal * index->d); + delete index; } - if (_m_index == nullptr) { - _m_index = std::unique_ptr(index); - } else { - _m_index.reset(index); + + if (_m_index != nullptr) { + _m_index.reset(); } + _m_index = std::make_unique(vec_dims); + auto merged_index_ntotal = static_cast(merged_vectors.size() / vec_dims); + _m_index->add(merged_index_ntotal, merged_vectors.data()); return StatusCode::OK; } /*** * - * @param corpus_segment_path + * @param corpus_segment_dir directory searched recursively for *.jsonl corpus shard files * @return */ -StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus_segment_path) { - if (!FilePathUtil::is_file_exist(corpus_segment_path)) { - LOG(ERROR) << fmt::format("segment corpus file: {} not exist", corpus_segment_path); - return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED; +StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus_segment_dir) { + // gather 
all *.jsonl files (recursive glob) + std::vector corpus_file_paths; + cv::glob(fmt::format("{}/*.jsonl", corpus_segment_dir), corpus_file_paths, true); + if (corpus_file_paths.empty()) { + LOG(ERROR) << fmt::format("no *.jsonl corpus file found in {}", corpus_segment_dir); + return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED; } - std::ifstream f_in(corpus_segment_path, std::ios::in); - if (!f_in.is_open() || f_in.bad()) { - LOG(ERROR) << fmt::format("read segment corpus file: {} failed", corpus_segment_path); - return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED; - } + // sort corpus shards by the numeric id that follows the last '_' in the file name + std::sort(corpus_file_paths.begin(), corpus_file_paths.end(), [](const std::string& path_a, const std::string& path_b) -> bool { + auto name_a = FilePathUtil::get_file_name(path_a); + auto prefix_id_a = name_a.substr(0, name_a.find(".jsonl")); + prefix_id_a = prefix_id_a.substr(prefix_id_a.rfind('_') + 1); + auto name_b = FilePathUtil::get_file_name(path_b); + auto prefix_id_b = name_b.substr(0, name_b.find(".jsonl")); + prefix_id_b = prefix_id_b.substr(prefix_id_b.rfind('_') + 1); + return std::stoi(prefix_id_a) < std::stoi(prefix_id_b); + }); - std::string line; - while (std::getline(f_in, line)) { - rapidjson::Document doc; - doc.Parse(line.c_str()); - wiki_corpus_segment segment; - segment.id = std::to_string(doc["id"].GetInt64()); - segment.title = doc["title"].GetString(); - segment.text = doc["text"].GetString(); - _m_segment_wiki_corpus.push_back(segment); + for (auto& corpus_segment_path : corpus_file_paths) { + std::ifstream f_in(corpus_segment_path, std::ios::in); + if (!f_in.is_open() || f_in.bad()) { + LOG(ERROR) << fmt::format("read segment corpus file: {} failed", corpus_segment_path); + return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED; + } + + std::string line; + while (std::getline(f_in, line)) { + rapidjson::Document doc; + doc.Parse(line.c_str()); + wiki_corpus_segment segment; + segment.id = std::to_string(doc["id"].GetInt64()); + segment.title = doc["title"].GetString(); + segment.text = 
doc["text"].GetString(); + _m_segment_wiki_corpus.push_back(segment); + } } return StatusCode::OK; @@ -512,13 +554,13 @@ StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus /*** * * @param input_prompt - * @param referenced_corpus + * @param out_referenced_corpus output string; the formatted text of the retrieved documents is appended to it * @param top_k * @param apply_chat_template * @return */ StatusCode WikiIndexBuilder::Impl::search( - const std::string &input_prompt, std::string &referenced_corpus, int top_k, bool apply_chat_template) { + const std::string &input_prompt, std::string &out_referenced_corpus, int top_k, bool apply_chat_template) { // check validation of index and corpus if (_m_segment_wiki_corpus.empty()) { LOG(ERROR) << "empty segmented corpus load them first"; @@ -549,12 +591,12 @@ StatusCode WikiIndexBuilder::Impl::search( auto title = _m_segment_wiki_corpus[ids[i]].title; auto text = _m_segment_wiki_corpus[ids[i]].text; auto fmt_str = fmt::format("Doc {} (Title: {}) {}\n", i + 1, title, text); - referenced_corpus += fmt_str; + out_referenced_corpus += fmt_str; } } else { for (auto& i : ids) { - referenced_corpus += fmt::format("title: {}\n text: {}\n", + out_referenced_corpus += fmt::format("title: {}\n text: {}\n", _m_segment_wiki_corpus[i].title, _m_segment_wiki_corpus[i].text); } } @@ -1012,33 +1054,33 @@ StatusCode WikiIndexBuilder::build_index(const std::string& source_wiki_corpus_d /*** * - * @param index_f_path + * @param index_file_dir directory searched recursively for *.index shard files * @return */ -StatusCode WikiIndexBuilder::load_index(const std::string &index_f_path) { - return _m_pimpl-> load_index(index_f_path); +StatusCode WikiIndexBuilder::load_index(const std::string &index_file_dir) { + return _m_pimpl-> load_index(index_file_dir); } /*** * - * @param corpus_segment_path + * @param corpus_segment_dir directory searched recursively for *.jsonl corpus shard files * @return */ -StatusCode WikiIndexBuilder::load_corpus_segment(const std::string &corpus_segment_path) { - return _m_pimpl-> load_corpus_segment(corpus_segment_path); +StatusCode WikiIndexBuilder::load_corpus_segment(const 
std::string &corpus_segment_dir) { + return _m_pimpl-> load_corpus_segment(corpus_segment_dir); } /*** * * @param input_prompt - * @param referenced_corpus + * @param out_referenced_corpus output string; the formatted text of the retrieved documents is appended to it * @param top_k * @param apply_chat_template * @return */ StatusCode WikiIndexBuilder::search( - const std::string &input_prompt, std::string &referenced_corpus, int top_k, bool apply_chat_template) { - return _m_pimpl->search(input_prompt, referenced_corpus, top_k, apply_chat_template); + const std::string &input_prompt, std::string &out_referenced_corpus, int top_k, bool apply_chat_template) { + return _m_pimpl->search(input_prompt, out_referenced_corpus, top_k, apply_chat_template); } } diff --git a/src/models/llm/rag/wiki_index_builder.h b/src/models/llm/rag/wiki_index_builder.h index fa1f2a7..dfb1fff 100644 --- a/src/models/llm/rag/wiki_index_builder.h +++ b/src/models/llm/rag/wiki_index_builder.h @@ -69,30 +69,29 @@ class WikiIndexBuilder { /*** * - * @param index_f_path + * @param index_file_dir directory searched recursively for *.index shard files * @param index * @return */ - jinq::common::StatusCode load_index(const std::string& index_f_path); + jinq::common::StatusCode load_index(const std::string& index_file_dir); /*** * - * @param corpus_segment_path - * @param segments + * @param corpus_segment_dir directory searched recursively for *.jsonl corpus shard files * @return */ - jinq::common::StatusCode load_corpus_segment(const std::string& corpus_segment_path); + jinq::common::StatusCode load_corpus_segment(const std::string& corpus_segment_dir); /*** * * @param input_prompt - * @param referenced_corpus + * @param out_referenced_corpus output string; the formatted text of the retrieved documents is appended to it * @param top_k * @param aapply_chat_template * @return */ jinq::common::StatusCode search( - const std::string& input_prompt, std::string& referenced_corpus, int top_k=1, bool apply_chat_template=true); + const std::string& input_prompt, std::string& out_referenced_corpus, int top_k=1, bool apply_chat_template=true); private: class Impl;