implement find_candidate_pred_tokens #914

Closed · wants to merge 1 commit
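This change refactors `Request` in `crates/llama-cpp-bindings/src/engine.cc` so that tokens still waiting to be decoded live in `pending_tokens`, while already-decoded tokens accumulate in a `tokens` history buffer, and it adds `find_candidate_pred_tokens`: a prompt-lookup helper that matches the most recent n-gram of the history against earlier positions and returns the tokens that followed the match, as draft predictions for speculative decoding.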
73 changes: 54 additions & 19 deletions crates/llama-cpp-bindings/src/engine.cc
@@ -20,20 +20,58 @@ constexpr size_t N_CTX = 4096; // # max kv history.
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
     id(request_id),
-    tokens(input_token_ids.begin(), input_token_ids.end()) {
+    pending_tokens(input_token_ids.begin(), input_token_ids.end()) {
   }
 
   uint32_t id = -1;
   llama_seq_id seq_id = -1;
 
-  std::vector<llama_token> tokens;
+  std::vector<llama_token> pending_tokens;
   size_t i_batch = -1;
   size_t n_past = 0;
 
   int32_t multibyte_pending = 0;
   std::string generated_text;
 
+  // Tokens that have already been decoded into the KV cache.
+  std::vector<llama_token> tokens;
+
+  // Commit the just-decoded pending tokens to the history and queue the
+  // newly sampled token as the only pending token for the next step.
+  void step(llama_token id) {
+    n_past += pending_tokens.size();
+    tokens.insert(tokens.end(), pending_tokens.begin(), pending_tokens.end());
+
+    pending_tokens.clear();
+    pending_tokens.push_back(id);
+  }
+
+  // Prompt-lookup candidate generation: search the decoded history for the
+  // most recent n-gram and, on a match, propose the tokens that followed it.
+  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
+    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
+      if (tokens.size() < ngram_size) continue;
+      std::vector<llama_token> ngram(tokens.end() - ngram_size, tokens.end());
+
+      const int matched = find_ngram(ngram, n_pred_tokens);
+      if (matched < 0) continue;
+
+      const int offset = matched + ngram_size;
+      return std::vector<llama_token>(tokens.begin() + offset, tokens.begin() + offset + n_pred_tokens);
+    }
+
+    return std::vector<llama_token>();
+  }
+
+ private:
+  // Return the first index whose window matches `ngram` and still leaves
+  // n_pred_tokens of history after it, or -1 if there is none.
+  int find_ngram(const std::vector<llama_token>& ngram, size_t n_pred_tokens) {
+    const int max = static_cast<int>(tokens.size()) -
+                    static_cast<int>(ngram.size()) -
+                    static_cast<int>(n_pred_tokens);
+    for (int i = 0; i <= max; ++i) {
+      const auto mismatch = std::mismatch(tokens.begin() + i, tokens.begin() + i + ngram.size(), ngram.begin());
+      if (mismatch.second == ngram.end()) {
+        return i;  // Matched.
+      }
+    }
+
+    return -1;
+  }
 };

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
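For reviewers, here is a minimal standalone sketch of the lookup the new method performs. The free function `lookup`, the `llama_token = int` alias, and the toy token values are illustrative only and not part of this PR; it mirrors `find_candidate_pred_tokens` with `max_ngram_size = 2` and `n_pred_tokens = 3`.

```cpp
// Illustrative only: a free-function version of the same n-gram lookup,
// run on a toy token history.
#include <algorithm>
#include <cstdio>
#include <vector>

using llama_token = int;

std::vector<llama_token> lookup(const std::vector<llama_token>& tokens,
                                size_t max_ngram_size, size_t n_pred_tokens) {
  for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
    if (tokens.size() < ngram_size) continue;
    // The query is the most recent `ngram_size` tokens.
    std::vector<llama_token> ngram(tokens.end() - ngram_size, tokens.end());
    const int max = static_cast<int>(tokens.size()) -
                    static_cast<int>(ngram_size) -
                    static_cast<int>(n_pred_tokens);
    for (int i = 0; i <= max; ++i) {
      if (std::equal(ngram.begin(), ngram.end(), tokens.begin() + i)) {
        // Propose the tokens that followed the earlier occurrence.
        const size_t offset = i + ngram_size;
        return std::vector<llama_token>(tokens.begin() + offset,
                                        tokens.begin() + offset + n_pred_tokens);
      }
    }
  }
  return {};
}

int main() {
  // The suffix "1 2" also appears at index 0, followed by 7 8 9.
  std::vector<llama_token> history = {1, 2, 7, 8, 9, 4, 1, 2};
  for (llama_token t : lookup(history, /*max_ngram_size=*/2, /*n_pred_tokens=*/3))
    std::printf("%d ", t);  // Prints: 7 8 9
}
```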
@@ -54,7 +92,7 @@ std::vector<llama_token> llama_tokenize(
     const rust::Str & text,
     bool add_bos,
     bool special) {
-  // upper limit for the number of tokens
+  // upper limit for the number of pending_tokens
   int n_tokens = text.length() + add_bos;
   std::vector<llama_token> result(n_tokens);
   n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
@@ -113,12 +151,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   }
 
   virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
-    auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
-    if (tokens.size() > max_input_length) {
-      int start = tokens.size() - max_input_length;
-      tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
+    auto pending_tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
+    if (pending_tokens.size() > max_input_length) {
+      int start = pending_tokens.size() - max_input_length;
+      pending_tokens = std::vector<llama_token>(pending_tokens.begin() + start, pending_tokens.end());
     }
-    pending_requests_.push_back(Request(request_id, tokens));
+    pending_requests_.push_back(Request(request_id, pending_tokens));
   }
 
   void stop_request(uint32_t request_id) override {
@@ -168,17 +206,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     // Clear the batch.
     batch_.n_tokens = 0;
 
-    // Insert tokens from ongoing requests to batch.
+    // Insert pending_tokens from ongoing requests to batch.
     for (auto& request : requests_) {
       const size_t n_tokens = batch_.n_tokens;
-      for (size_t i = 0; i < request.tokens.size(); ++i) {
-        batch_.token[n_tokens + i] = request.tokens[i];
+      for (size_t i = 0; i < request.pending_tokens.size(); ++i) {
+        batch_.token[n_tokens + i] = request.pending_tokens[i];
         batch_.pos[n_tokens + i] = request.n_past + i;
         batch_.n_seq_id[n_tokens + i] = 1;
         batch_.seq_id[n_tokens + i][0] = request.id;
         batch_.logits[n_tokens + i] = false;
       }
-      batch_.n_tokens += request.tokens.size();
+      batch_.n_tokens += request.pending_tokens.size();
 
       batch_.logits[batch_.n_tokens - 1] = true;
       request.i_batch = batch_.n_tokens - 1;
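Note that in this commit the batch still requests logits only for the last pending token of each request, and the candidates produced by `find_candidate_pred_tokens` are not yet fed into the batch. A hypothetical wiring, not part of this PR and only a sketch of where speculative verification could hook in, might append the draft tokens after the pending ones:

```cpp
// Hypothetical follow-up (not in this commit): append draft tokens and
// request logits for each, so every speculated position can be verified
// against the model's own prediction in a later step.
auto draft = request.find_candidate_pred_tokens();
for (size_t j = 0; j < draft.size(); ++j) {
  const size_t k = batch_.n_tokens + j;
  batch_.token[k] = draft[j];
  // Drafts follow immediately after the pending tokens in the sequence.
  batch_.pos[k] = request.n_past + request.pending_tokens.size() + j;
  batch_.n_seq_id[k] = 1;
  batch_.seq_id[k][0] = request.id;
  batch_.logits[k] = true;  // verify each speculated position
}
batch_.n_tokens += draft.size();
```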
@@ -187,7 +225,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     rust::Vec<StepOutput> result;
     result.reserve(requests_.size());
 
-    // Decode tokens in chunks
+    // Decode pending_tokens in chunks
     for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
       const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
       llama_batch batch_view = {
@@ -216,10 +254,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       auto logits = llama_get_logits_ith(ctx, i_batch);
       auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
 
-      request.n_past += request.tokens.size();
-
-      request.tokens.clear();
-      request.tokens.push_back(next_token);
+      request.step(next_token);
 
       const auto token_str = llama_token_to_piece(ctx, next_token);
       request.generated_text += token_str;
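With this change the per-request bookkeeping lives in one place: after each decode, `step(next_token)` advances `n_past` past everything just decoded, moves the pending tokens into the `tokens` history (making them visible to future n-gram lookups), and leaves the newly sampled token as the sole pending token for the next iteration.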