
Commit

Search JARVIS's quotes
vietanhdev committed Aug 2, 2023
1 parent 859067c commit 2fbd4d3
Showing 5 changed files with 138 additions and 34 deletions.
11 changes: 6 additions & 5 deletions CMakeLists.txt
@@ -43,11 +43,6 @@ target_link_libraries(
  pthread
  dl
)
-add_executable(
-  search_doc
-  examples/search_doc.cpp
-)
-target_link_libraries(search_doc embeddb)

# Build CustomChar-core
set(TARGET customchar-core)
@@ -71,6 +66,12 @@ target_include_directories(
)
target_link_libraries(${TARGET} PUBLIC ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} ${OpenCV_LIBS} whisper subprocess embeddb)

+add_executable(
+  search_doc
+  examples/search_doc.cpp
+)
+target_link_libraries(search_doc PUBLIC customchar-core)

# CustomChar - cli
add_executable(
customchar-cli
26 changes: 26 additions & 0 deletions customchar/llm/llm.cpp
@@ -20,6 +20,7 @@ LLM::LLM(const std::string& model_path, const std::string& path_session,
  lparams_.n_ctx = 2048;
  lparams_.seed = 1;
  lparams_.f16_kv = true;
+  lparams_.embedding = true;  // expose embeddings through llama_get_embeddings()

  // Load model into RAM
  ctx_llama_ = llama_init_from_file(model_path_.c_str(), lparams_);
@@ -349,3 +350,28 @@ std::string LLM::get_answer(const std::string& user_input) {

return output_text;
}

+std::vector<float> LLM::get_embedding(const std::string& text) {
+  std::vector<llama_token> embd(text.size());
+  // llama_tokenize returns the number of tokens actually written
+  int n_tokens = llama_tokenize(ctx_llama_, text.c_str(), embd.data(), embd.size(), true);
+  llama_eval(ctx_llama_, embd.data(), n_tokens, n_past_, n_threads_);
+  const int n_embd = llama_n_embd(ctx_llama_);
+  const auto embeddings = llama_get_embeddings(ctx_llama_);
+  std::vector<float> result;
+  result.reserve(n_embd);
+  for (int i = 0; i < n_embd; ++i) {
+    result.push_back(embeddings[i]);
+  }
+
+  // L2-normalize so that dot products equal cosine similarity
+  float norm = 0;
+  for (int i = 0; i < n_embd; ++i) {
+    norm += result[i] * result[i];
+  }
+  norm = sqrt(norm);
+  for (int i = 0; i < n_embd; ++i) {
+    result[i] /= norm;
+  }
+
+  return result;
+}
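Note on the normalization step above: once get_embedding() returns unit-length vectors, the dot product of two embeddings equals their cosine similarity, and squared Euclidean distance becomes a simple function of it: ||a - b||^2 = 2(1 - cos). A minimal, self-contained sketch with toy vectors (not real model output) illustrating the identity that embedding search relies on:

  #include <iostream>
  #include <vector>

  // Cosine similarity; for unit-length vectors this is just the dot product.
  float cosine(const std::vector<float>& a, const std::vector<float>& b) {
    float dot = 0;
    for (size_t i = 0; i < a.size(); ++i) dot += a[i] * b[i];
    return dot;  // assumes a and b are already L2-normalized
  }

  int main() {
    // Two toy unit vectors standing in for LLM embeddings
    std::vector<float> a = {1.0f, 0.0f};
    std::vector<float> b = {0.6f, 0.8f};
    float cos = cosine(a, b);
    float sq_dist = 0;
    for (size_t i = 0; i < a.size(); ++i) {
      float d = a[i] - b[i];
      sq_dist += d * d;
    }
    // For unit vectors, ||a - b||^2 == 2 * (1 - cos(a, b))
    std::cout << sq_dist << " == " << 2 * (1 - cos) << std::endl;  // 0.8 == 0.8
    return 0;
  }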
29 changes: 28 additions & 1 deletion customchar/llm/llm.h
@@ -7,6 +7,7 @@
#include <cassert>
#include <cstdio>
#include <fstream>
+#include <iostream>
#include <regex>
#include <string>
#include <thread>
@@ -65,7 +66,7 @@ class LLM {
void init_prompt();

public:
-  /// @brief Constructor
+  // @brief Constructor
/// @param model_path Path to the model
/// @param path_session Path to the session
LLM(const std::string& model_path, const std::string& path_session = "",
@@ -88,6 +89,32 @@

/// @brief Get answer from LLM
std::string get_answer(const std::string& user_input);

+  /// @brief Get embedding from LLM
+  std::vector<float> get_embedding(const std::string& text);  // defined in llm.cpp
};

} // namespace llm
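For context, a minimal sketch of how the new embedding API might be exercised. The model path follows the one used in examples/search_doc.cpp, and the quote is a placeholder; LLaMA-2 7B produces 4096-dimensional embeddings, which is why dim is 4096 in the example below:

  #include "customchar/llm/llm.h"

  #include <iostream>

  int main() {
    // Load the model once; it serves both indexing and query embedding
    CC::llm::LLM embedding_model("../models/llama-2-7b-chat.ggmlv3.q4_0.bin");
    embedding_model.eval_model();

    std::vector<float> v = embedding_model.get_embedding("Wake up. Daddy's home.");
    std::cout << "dim = " << v.size() << std::endl;  // 4096 for LLaMA-2 7B
    return 0;
  }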
6 changes: 6 additions & 0 deletions data/jarvis.txt
@@ -0,0 +1,6 @@
+Good morning. It's 7 A.M. The weather in Malibu is 72 degrees with scattered clouds. The surf conditions are fair with waist to shoulder highlines, high tide will be at 10:52 a.m. ―J.A.R.V.I.S
+We are now running on emergency backup power. ―J.A.R.V.I.S
+I don't want this winding up in the wrong hands. Maybe in mine, it could actually do some good.
+Wake up. Daddy's home.
+Welcome home, sir. Congratulations on the opening ceremonies. They were such a success, as was your senate hearing. And may I say how refreshing it is to finally see you in a video with your clothing on, sir.
+Sir, we will lose power before we penetrate that shell.
100 changes: 72 additions & 28 deletions examples/search_doc.cpp
@@ -5,44 +5,88 @@
#include "customchar/embeddb/document.h"
#include "customchar/embeddb/embed_search.h"
#include "customchar/embeddb/types.h"
#include "customchar/llm/llm.h"

using namespace CC;
using namespace CC::embeddb;

+std::vector<std::string> read_lines(const std::string& file_path) {
+  std::vector<std::string> lines;
+
+  // Open the file
+  std::ifstream file(file_path);
+
+  // Check if the file was opened successfully
+  if (!file.is_open()) {
+    std::cerr << "Error: Unable to open the file." << std::endl;
+    return lines;
+  }
+
+  // Read line by line and add non-empty lines to the vector
+  std::string line;
+  while (std::getline(file, line)) {
+    if (!line.empty()) {
+      lines.push_back(line);
+    }
+  }
+
+  // Close the file after reading
+  file.close();
+
+  return lines;
+}
+
int main() {
+  std::string model_path = "../models/llama-2-7b-chat.ggmlv3.q4_0.bin";
+  llm::LLM embedding_model(model_path);
+  embedding_model.eval_model();
+
std::string connection_name = "test_collection";
std::string path = "test_collection";
-  int dim = 10;
+  int dim = 4096;  // embedding size of LLaMA-2 7B; must match get_embedding()
int max_size = 1000;
Collection* collection = new Collection(connection_name, path, dim, max_size);

-  // Test Insert
-  std::vector<float> embed;
-  for (int i = 0; i < 10; i++) {
-    embed.push_back(i);
-  }
-  std::string content = "test content";
-  std::string meta = "test meta";
-  u_int32_t id = collection->insert_doc(embed, content, meta, 1, 1, 1);
-  std::cout << "Inserted document id: " << id << std::endl;
-
-  // Test Get Doc From Ids
-  std::vector<u_int32_t> ids{0, 1};
-  std::vector<Document> docs = collection->get_docs_by_ids(ids, 2);
-  std::cout << docs.size() << std::endl;
-
-  // Test Search
-  std::vector<float> query;
-  for (int i = 0; i < 10; i++) {
-    query.push_back(i);
-  }
-  std::vector<u_int32_t> doc_ids;
-  std::vector<float> distances;
-  int top_k = 2;
-  float threshold = 100;
-  collection->search(query, top_k, threshold, doc_ids, distances);
-  for (int i = 0; i < doc_ids.size(); i++) {
-    std::cout << doc_ids[i] << " " << distances[i] << std::endl;
+  // Read the document from file
+  std::string file_path = "../data/jarvis.txt";
+  std::vector<std::string> lines = read_lines(file_path);
+
+  // Insert all documents
+  for (int i = 0; i < lines.size(); i++) {
+    std::string content = lines[i];
+    std::vector<float> embed = embedding_model.get_embedding(lines[i]);
+    std::string meta = "test meta";
+    std::cout << "Inserting document " << i << std::endl;
+    std::cout << "Embedding size " << embed.size() << std::endl;
+    std::cout << "Content size " << content.size() << std::endl;
+    std::cout << "Content: " << content << std::endl;
+    u_int32_t id = collection->insert_doc(embed, content, meta, 1, 1, 1);
}

+  while (true) {
+    // Test Search
+    std::string query_str;
+    std::cout << "Enter query: ";
+    std::getline(std::cin, query_str);
+    std::vector<float> query = embedding_model.get_embedding(query_str);
+
+    std::vector<u_int32_t> doc_ids;
+    std::vector<float> distances;
+    int top_k = 2;
+    float threshold = 100000000;  // effectively unbounded: accept any distance
+    collection->search(query, top_k, threshold, doc_ids, distances);
+    for (int i = 0; i < 10; ++i) {  // debug: peek at the first query components
+      std::cout << query[i] << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "Search result: " << std::endl;
+    for (int i = 0; i < doc_ids.size(); i++) {
+      std::cout << "Doc id: " << doc_ids[i] << std::endl;
+      std::cout << "Distance: " << distances[i] << std::endl;
+      Document doc = collection->get_doc(doc_ids[i]);
+      std::cout << "Content: " << doc.get_content() << std::endl;
+    }
+  }
return 0;
}
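One more note on the search parameters: since get_embedding() returns unit-length vectors, a meaningful distance threshold can be derived from a minimum cosine similarity instead of the effectively unbounded 100000000 above. A small sketch, under the assumption that the distances reported by Collection::search are squared Euclidean (if the library returns plain Euclidean distance, take the square root of the bound):

  #include <iostream>

  // For unit vectors: ||a - b||^2 = 2 * (1 - cos(a, b)).
  // So requiring cos >= min_cos is the same as requiring
  // squared distance <= 2 * (1 - min_cos).
  float threshold_for_min_cosine(float min_cos) {
    return 2.0f * (1.0f - min_cos);
  }

  int main() {
    // e.g. accept only documents with cosine similarity >= 0.7
    std::cout << threshold_for_min_cosine(0.7f) << std::endl;  // 0.6
    return 0;
  }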
