Skip to content

Commit

Permalink
update llama3 chat server
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Dec 23, 2024
1 parent dd4f95c commit 46088e8
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 57 deletions.
142 changes: 92 additions & 50 deletions src/models/llm/rag/wiki_index_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,28 @@ class WikiIndexBuilder::Impl {

/***
*
* @param index_f_path
* @param index_file_dir
* @return
*/
StatusCode load_index(const std::string& index_f_path);
StatusCode load_index(const std::string& index_file_dir);

/***
*
* @param corpus_segment_path
* @param corpus_segment_dir
* @return
*/
StatusCode load_corpus_segment(const std::string& corpus_segment_path);
StatusCode load_corpus_segment(const std::string& corpus_segment_dir);

/***
*
* @param input_prompt
* @param referenced_corpus
* @param out_referenced_corpus
* @param top_k
* @param apply_chat_template
* @return
*/
StatusCode search(
const std::string& input_prompt, std::string& referenced_corpus, int top_k=1, bool apply_chat_template=true);
const std::string& input_prompt, std::string& out_referenced_corpus, int top_k=1, bool apply_chat_template=true);

/***
*
Expand Down Expand Up @@ -455,55 +455,97 @@ StatusCode WikiIndexBuilder::Impl::build_index(const std::string &source_wiki_co

/***
 * Load every "*.index" shard found under a directory and merge them, in ascending
 * shard-id order, into one in-memory flat L2 index stored in _m_index.
 *
 * The shard id is the integer that sits between the last '_' of the file name and
 * the ".index" suffix. All shards must be flat L2 indexes with the same vector
 * dimension. _m_index is only replaced after every shard has been read successfully.
 *
 * @param index_file_dir directory searched (recursively) for *.index files
 * @return StatusCode::OK on success, StatusCode::RAG_LOAD_INDEX_FAILED otherwise
 */
StatusCode WikiIndexBuilder::Impl::load_index(const std::string &index_file_dir) {
    // gather all *.index files
    std::vector<std::string> index_file_paths;
    cv::glob(fmt::format("{}/*.index", index_file_dir), index_file_paths, true);
    if (index_file_paths.empty()) {
        LOG(ERROR) << fmt::format("no *.index index file found in {}", index_file_dir);
        return StatusCode::RAG_LOAD_INDEX_FAILED;
    }

    // Parse the numeric shard id of each file exactly once, before sorting: std::stoi
    // throws on a malformed name, and letting that escape from a std::sort comparator
    // would abort the whole load with an exception instead of a status code.
    struct IndexShard {
        int id;
        std::string path;
    };
    std::vector<IndexShard> shards;
    shards.reserve(index_file_paths.size());
    for (const auto& f_path : index_file_paths) {
        const auto name = FilePathUtil::get_file_name(f_path);
        auto prefix_id = name.substr(0, name.find(".index"));
        prefix_id = prefix_id.substr(prefix_id.rfind('_') + 1);
        try {
            shards.push_back({std::stoi(prefix_id), f_path});
        } catch (const std::exception&) {
            LOG(ERROR) << fmt::format("can not parse shard id from index file name: {}", f_path);
            return StatusCode::RAG_LOAD_INDEX_FAILED;
        }
    }
    std::sort(shards.begin(), shards.end(), [](const IndexShard& lhs, const IndexShard& rhs) -> bool {
        return lhs.id < rhs.id;
    });

    // load and merge index files
    std::vector<float> merged_vectors;
    int vec_dims = 0;
    for (const auto& shard : shards) {
        // faiss::read_index returns an owning raw pointer; hold it in a unique_ptr so the
        // shard is released once its vectors are copied and on every early return.
        std::unique_ptr<faiss::Index> loaded(faiss::read_index(shard.path.c_str(), 0));
        auto* flat_index = dynamic_cast<faiss::IndexFlatL2*>(loaded.get());
        if (flat_index == nullptr) {
            LOG(ERROR) << fmt::format("read index file failed: {}", shard.path);
            return StatusCode::RAG_LOAD_INDEX_FAILED;
        }
        if (flat_index->d <= 0) {
            LOG(ERROR) << fmt::format("index file: {} has invalid vector dims: {}", shard.path, flat_index->d);
            return StatusCode::RAG_LOAD_INDEX_FAILED;
        }
        if (vec_dims == 0) {
            vec_dims = flat_index->d;
        } else if (flat_index->d != vec_dims) {
            // concatenating vectors of different widths would silently corrupt the merged index
            LOG(ERROR) << fmt::format(
                "index file: {} has vector dims {} but previous index files have dims {}",
                shard.path, flat_index->d, vec_dims);
            return StatusCode::RAG_LOAD_INDEX_FAILED;
        }

        const float* raw_data = flat_index->get_xb();
        const auto value_count = static_cast<size_t>(flat_index->ntotal) * static_cast<size_t>(flat_index->d);
        merged_vectors.insert(merged_vectors.end(), raw_data, raw_data + value_count);
    }

    // assigning to the unique_ptr releases any previously loaded index
    _m_index = std::make_unique<faiss::IndexFlatL2>(vec_dims);
    const auto merged_index_ntotal = static_cast<long>(merged_vectors.size() / static_cast<size_t>(vec_dims));
    _m_index->add(merged_index_ntotal, merged_vectors.data());

    return StatusCode::OK;
}

/***
*
* @param corpus_segment_path
* @param corpus_segment_dir
* @return
*/
StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus_segment_path) {
if (!FilePathUtil::is_file_exist(corpus_segment_path)) {
LOG(ERROR) << fmt::format("segment corpus file: {} not exist", corpus_segment_path);
return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED;
StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus_segment_dir) {
// gather all *.jsonl file
std::vector<std::string> corpus_file_paths;
cv::glob(fmt::format("{}/*.jsonl", corpus_segment_dir), corpus_file_paths, true);
if (corpus_file_paths.empty()) {
LOG(ERROR) << fmt::format("no *.jsonl corpus file found in {}", corpus_segment_dir);
return StatusCode::RAG_LOAD_INDEX_FAILED;
}

std::ifstream f_in(corpus_segment_path, std::ios::in);
if (!f_in.is_open() || f_in.bad()) {
LOG(ERROR) << fmt::format("read segment corpus file: {} failed", corpus_segment_path);
return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED;
}
// sort corpus files by the numeric id that follows the last '_' in the file name
std::sort(corpus_file_paths.begin(), corpus_file_paths.end(), [](const std::string& path_a, const std::string& path_b) -> bool {
auto name_a = FilePathUtil::get_file_name(path_a);
auto prefix_id_a = name_a.substr(0, name_a.find(".jsonl"));
prefix_id_a = prefix_id_a.substr(prefix_id_a.rfind('_') + 1);
auto name_b = FilePathUtil::get_file_name(path_b);
auto prefix_id_b = name_b.substr(0, name_b.find(".jsonl"));
prefix_id_b = prefix_id_b.substr(prefix_id_b.rfind('_') + 1);
return std::stoi(prefix_id_a) < std::stoi(prefix_id_b);
});

std::string line;
while (std::getline(f_in, line)) {
rapidjson::Document doc;
doc.Parse(line.c_str());
wiki_corpus_segment segment;
segment.id = std::to_string(doc["id"].GetInt64());
segment.title = doc["title"].GetString();
segment.text = doc["text"].GetString();
_m_segment_wiki_corpus.push_back(segment);
for (auto& corpus_segment_path : corpus_file_paths) {
std::ifstream f_in(corpus_segment_path, std::ios::in);
if (!f_in.is_open() || f_in.bad()) {
LOG(ERROR) << fmt::format("read segment corpus file: {} failed", corpus_segment_path);
return StatusCode::RAG_LOAD_SEGMENT_CORPUS_FAILED;
}

std::string line;
while (std::getline(f_in, line)) {
rapidjson::Document doc;
doc.Parse(line.c_str());
wiki_corpus_segment segment;
segment.id = std::to_string(doc["id"].GetInt64());
segment.title = doc["title"].GetString();
segment.text = doc["text"].GetString();
_m_segment_wiki_corpus.push_back(segment);
}
}

return StatusCode::OK;
Expand All @@ -512,13 +554,13 @@ StatusCode WikiIndexBuilder::Impl::load_corpus_segment(const std::string &corpus
/***
*
* @param input_prompt
* @param referenced_corpus
* @param out_referenced_corpus
* @param top_k
* @param apply_chat_template
* @return
*/
StatusCode WikiIndexBuilder::Impl::search(
const std::string &input_prompt, std::string &referenced_corpus, int top_k, bool apply_chat_template) {
const std::string &input_prompt, std::string &out_referenced_corpus, int top_k, bool apply_chat_template) {
// check validation of index and corpus
if (_m_segment_wiki_corpus.empty()) {
LOG(ERROR) << "empty segmented corpus load them first";
Expand Down Expand Up @@ -549,12 +591,12 @@ StatusCode WikiIndexBuilder::Impl::search(
auto title = _m_segment_wiki_corpus[ids[i]].title;
auto text = _m_segment_wiki_corpus[ids[i]].text;
auto fmt_str = fmt::format("Doc {} (Title: {}) {}\n", i + 1, title, text);
referenced_corpus += fmt_str;
out_referenced_corpus += fmt_str;
}

} else {
for (auto& i : ids) {
referenced_corpus += fmt::format("title: {}\n text: {}\n",
out_referenced_corpus += fmt::format("title: {}\n text: {}\n",
_m_segment_wiki_corpus[i].title, _m_segment_wiki_corpus[i].text);
}
}
Expand Down Expand Up @@ -1012,33 +1054,33 @@ StatusCode WikiIndexBuilder::build_index(const std::string& source_wiki_corpus_d

/***
*
* @param index_f_path
* @param index_file_dir
* @return
*/
StatusCode WikiIndexBuilder::load_index(const std::string &index_f_path) {
return _m_pimpl-> load_index(index_f_path);
// Thin public wrapper: passes index_file_dir to the implementation object held in
// _m_pimpl and returns its status code unchanged.
StatusCode WikiIndexBuilder::load_index(const std::string &index_file_dir) {
return _m_pimpl-> load_index(index_file_dir);
}

/***
*
* @param corpus_segment_path
* @param corpus_segment_dir
* @return
*/
StatusCode WikiIndexBuilder::load_corpus_segment(const std::string &corpus_segment_path) {
return _m_pimpl-> load_corpus_segment(corpus_segment_path);
// Thin public wrapper: passes corpus_segment_dir to the implementation object held in
// _m_pimpl and returns its status code unchanged.
StatusCode WikiIndexBuilder::load_corpus_segment(const std::string &corpus_segment_dir) {
return _m_pimpl-> load_corpus_segment(corpus_segment_dir);
}

/***
*
* @param input_prompt
* @param referenced_corpus
* @param out_referenced_corpus
* @param top_k
* @param apply_chat_template
* @return
*/
StatusCode WikiIndexBuilder::search(
const std::string &input_prompt, std::string &referenced_corpus, int top_k, bool apply_chat_template) {
return _m_pimpl->search(input_prompt, referenced_corpus, top_k, apply_chat_template);
const std::string &input_prompt, std::string &out_referenced_corpus, int top_k, bool apply_chat_template) {
return _m_pimpl->search(input_prompt, out_referenced_corpus, top_k, apply_chat_template);
}

}
Expand Down
13 changes: 6 additions & 7 deletions src/models/llm/rag/wiki_index_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,29 @@ class WikiIndexBuilder {

/***
*
* @param index_f_path
* @param index_file_dir
* @param index
* @return
*/
jinq::common::StatusCode load_index(const std::string& index_f_path);
jinq::common::StatusCode load_index(const std::string& index_file_dir);

/***
*
* @param corpus_segment_path
* @param segments
* @param corpus_segment_dir
* @return
*/
jinq::common::StatusCode load_corpus_segment(const std::string& corpus_segment_path);
jinq::common::StatusCode load_corpus_segment(const std::string& corpus_segment_dir);

/***
*
* @param input_prompt
* @param referenced_corpus
* @param out_referenced_corpus
* @param top_k
* @param apply_chat_template
* @return
*/
jinq::common::StatusCode search(
const std::string& input_prompt, std::string& referenced_corpus, int top_k=1, bool apply_chat_template=true);
const std::string& input_prompt, std::string& out_referenced_corpus, int top_k=1, bool apply_chat_template=true);

private:
class Impl;
Expand Down

0 comments on commit 46088e8

Please sign in to comment.