Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement prompt lookup based speculative decoding #916

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/llama-cpp-bindings/include/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,10 @@ class TextInferenceEngine {
virtual rust::Vec<StepOutput> step() = 0;
};

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
// Factory for the llama.cpp-backed text inference engine.
//
//   use_gpu              - forwarded to the implementation to select GPU usage.
//   model_path           - path of the model file to load.
//   parallelism          - forwarded to the engine; the implementation sizes its batch with it.
//   enable_prompt_lookup - turns on prompt-lookup based speculative decoding.
std::unique_ptr<TextInferenceEngine> create_engine(
    bool use_gpu,
    rust::Str model_path,
    uint8_t parallelism,
    bool enable_prompt_lookup
);
} // namespace
180 changes: 131 additions & 49 deletions crates/llama-cpp-bindings/src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,80 @@ TextInferenceEngine::~TextInferenceEngine() {}
namespace {
constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.

// Prompt-lookup drafting: length of the trailing n-gram that Request::draft_tokens
// asks find_candidate_pred_tokens to look for in the token history.
constexpr int DRAFT_N_GRAM_SIZE = 3;
// Prompt-lookup drafting: number of tokens requested as a draft after a match.
// Also the minimum quota Request::draft_tokens requires before it drafts at all.
constexpr int DRAFT_N_PRED_TOKENS = 10;

struct Request {
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
id(request_id),
tokens(input_token_ids.begin(), input_token_ids.end()) {
}
}

uint32_t id = -1;
llama_seq_id seq_id = -1;

std::vector<llama_token> tokens;
size_t i_batch = -1;
size_t n_past = 0;
size_t n_draft = 0;

int32_t multibyte_pending = 0;
std::string generated_text;


void draft_tokens(int n_draft_quota) {
if (n_draft_quota < DRAFT_N_PRED_TOKENS) {
n_draft = 0;
return;
}

auto draft = find_candidate_pred_tokens(DRAFT_N_GRAM_SIZE, DRAFT_N_PRED_TOKENS);
n_draft = draft.size();
tokens.insert(tokens.end(), draft.begin(), draft.end());
}

size_t n_past() {
return past_tokens.size();
}

// Commits the current round and starts the next one.
//
// next_token: becomes the sole pending entry of `tokens` afterwards.
// n_dropped:  number of trailing entries of `tokens` that are NOT committed
//             to `past_tokens`. It is not validated here and must not
//             exceed tokens.size().
//
// `n_draft` is reset to 0.
void step(llama_token next_token, size_t n_dropped) {
    const auto committed_end = tokens.end() - n_dropped;
    past_tokens.insert(past_tokens.end(), tokens.begin(), committed_end);

    // Replace the pending tokens with just the newly produced one.
    tokens.assign(size_t{1}, next_token);
    n_draft = 0;
}

private:
// Prompt-lookup drafting: looks for an earlier occurrence of the most recent
// n-gram inside `past_tokens` and returns the tokens that followed it.
//
// ngram_size:    length of the n-gram built from the newest tokens (see build_ngram).
// n_pred_tokens: number of tokens to return after the matched n-gram.
//
// Returns exactly n_pred_tokens tokens on a match, or an empty vector when
//  - fewer than ngram_size tokens are available to build the n-gram,
//  - the history is too short to hold a match plus its continuation, or
//  - the n-gram does not occur in the searched part of the history.
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
    const auto ngram = build_ngram(ngram_size);
    if (ngram.size() < ngram_size) return {};

    // Without this guard the iterator arithmetic below would step before
    // past_tokens.begin() (undefined behavior) whenever the history is short,
    // e.g. on the first step of a request, when past_tokens is still empty.
    if (past_tokens.size() < ngram_size + n_pred_tokens) return {};

    // Stop the search early so that n_pred_tokens tokens are guaranteed to
    // exist after any match.
    const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
    const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
    if (matched == end) return {};

    const auto begin = matched + ngram_size;
    return std::vector<llama_token>(begin, begin + n_pred_tokens);
}

std::vector<llama_token> build_ngram(size_t ngram_size) {
GGML_ASSERT(n_draft == 0);
std::deque<llama_token> ret;
for (int i = tokens.size() - 1; i >= 0; --i) {
if (ret.size() == ngram_size) break;
ret.push_front(tokens[i]);
}

for (int i = past_tokens.size() - 1; i >= 0; --i) {
if (ret.size() == ngram_size) break;
ret.push_front(past_tokens[i]);
}

return std::vector<llama_token>(ret.begin(), ret.end());
}

std::vector<llama_token> past_tokens;
};


Expand Down Expand Up @@ -84,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;

class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
model_(std::move(model)),
ctx_(std::move(ctx)),
parallelism_(parallelism) {
parallelism_(parallelism),
enable_prompt_lookup_(enable_prompt_lookup) {
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
// warm up
{
Expand Down Expand Up @@ -171,17 +230,26 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
// Insert tokens from ongoing requests to batch.
for (auto& request : requests_) {
const size_t n_tokens = batch_.n_tokens;

// Ensure the draft logits always fall into the same batch.
if (enable_prompt_lookup_) {
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
request.draft_tokens(n_draft_quota);
}

for (size_t i = 0; i < request.tokens.size(); ++i) {
batch_.token[n_tokens + i] = request.tokens[i];
batch_.pos[n_tokens + i] = request.n_past + i;
batch_.pos[n_tokens + i] = request.n_past() + i;
batch_.n_seq_id[n_tokens + i] = 1;
batch_.seq_id[n_tokens + i][0] = request.id;
batch_.logits[n_tokens + i] = false;
}
batch_.n_tokens += request.tokens.size();

batch_.logits[batch_.n_tokens - 1] = true;
request.i_batch = batch_.n_tokens - 1;
for (int k = batch_.n_tokens - request.n_draft - 1; k <= batch_.n_tokens - 1; ++ k) {
batch_.logits[k] = true;
}
request.i_batch = batch_.n_tokens;
}

rust::Vec<StepOutput> result;
Expand All @@ -208,55 +276,62 @@ class TextInferenceEngineImpl : public TextInferenceEngine {

const auto eos_id = llama_token_eos(llama_get_model(ctx));
for (auto& request : requests_) {
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
int32_t i_batch = request.i_batch - i - 1;
if ((i_batch < 0) || (i_batch >= n_tokens)) {
continue;
}

int32_t i_batch = request.i_batch - i;
auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));

request.n_past += request.tokens.size();

request.tokens.clear();
request.tokens.push_back(next_token);

const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str;
for (int k = -request.n_draft; k < 1; ++k) {
auto logits = llama_get_logits_ith(ctx, i_batch + k);
llama_token next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));

const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str;

// FIXME: Hack for codellama to simplify tabby's implementation.
const bool is_eos = next_token == eos_id || token_str == " <EOT>";

if (request.multibyte_pending > 0) {
request.multibyte_pending -= token_str.size();
} else if (token_str.size() == 1) {
const char c = token_str[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0) {
request.multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF0) == 0xE0) {
request.multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF8) == 0xF0) {
request.multibyte_pending = 3;
}
else {
request.multibyte_pending = 0;
}
}

// FIXME: Hack for codellama to simplify tabby's implementation.
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
if (request.multibyte_pending == 0) {
rust::String generated_text;
try {
generated_text = is_eos ? "" : request.generated_text;
} catch (const std::invalid_argument& e) {
fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__);
}

if (request.multibyte_pending > 0) {
request.multibyte_pending -= token_str.size();
} else if (token_str.size() == 1) {
const char c = token_str[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0) {
request.multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
result.push_back({request.id, generated_text});
request.generated_text.clear();
}
else if ((c & 0xF0) == 0xE0) {
request.multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF8) == 0xF0) {
request.multibyte_pending = 3;
}
else {
request.multibyte_pending = 0;
}
}

if (request.multibyte_pending == 0) {
rust::String generated_text;
try {
generated_text = is_eos ? "" : request.generated_text;
} catch (const std::invalid_argument& e) {
fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__);
if (is_eos) {
break;
}

result.push_back({request.id, generated_text});
request.generated_text.clear();
if ((k == 0) || ((k < 0 && next_token != request.tokens[request.tokens.size() + k]))) {
request.step(next_token, -k);
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past(), -1);
break;
}
}
}
}
Expand All @@ -275,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::unordered_set<uint32_t> stopped_requests_;

uint32_t parallelism_;
bool enable_prompt_lookup_;
};

static int g_llama_cpp_log_level = 0;
Expand Down Expand Up @@ -302,7 +378,12 @@ struct BackendInitializer {

} // namespace

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
std::unique_ptr<TextInferenceEngine> create_engine(
bool use_gpu,
rust::Str model_path,
uint8_t parallelism,
bool enable_prompt_lookup
) {
static BackendInitializer initializer;

llama_model_params model_params = llama_model_default_params();
Expand All @@ -325,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free),
parallelism
parallelism,
enable_prompt_lookup
);
}

Expand Down
4 changes: 3 additions & 1 deletion crates/llama-cpp-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod ffi {
use_gpu: bool,
model_path: &str,
parallelism: u8,
enable_prompt_lookup: bool,
) -> UniquePtr<TextInferenceEngine>;

fn add_request(
Expand All @@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
model_path: String,
use_gpu: bool,
parallelism: u8,
enable_prompt_lookup: bool,
}

pub struct LlamaTextGeneration {
Expand All @@ -57,7 +59,7 @@ pub struct LlamaTextGeneration {

impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/services/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl ChatService {

pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
let (engine, model::PromptInfo { chat_template, .. }) =
model::load_text_generation(model, device, parallelism).await;
model::load_text_generation(model, device, parallelism, true).await;

let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template");
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/services/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ pub async fn create_completion_service(
model::PromptInfo {
prompt_template, ..
},
) = model::load_text_generation(model, device, parallelism).await;
) = model::load_text_generation(model, device, parallelism, false).await;

CompletionService::new(engine.clone(), code, logger, prompt_template)
}
19 changes: 16 additions & 3 deletions crates/tabby/src/services/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub async fn load_text_generation(
model_id: &str,
device: &Device,
parallelism: u8,
enable_prompt_lookup: bool,
) -> (Arc<dyn TextGeneration>, PromptInfo) {
#[cfg(feature = "experimental-http")]
if device == &Device::ExperimentalHttp {
Expand All @@ -28,19 +29,25 @@ pub async fn load_text_generation(
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine_info = PromptInfo::read(path.join("tabby.json"));
let engine = create_ggml_engine(
device,
model_path.display().to_string().as_str(),
parallelism,
enable_prompt_lookup,
);
let engine_info = PromptInfo::read(path.join("tabby.json"));
(Arc::new(engine), engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(device, &model_path, parallelism);
let engine = create_ggml_engine(
device,
&model_path,
parallelism,
enable_prompt_lookup,
);
(
Arc::new(engine),
PromptInfo {
Expand All @@ -64,11 +71,17 @@ impl PromptInfo {
}
}

fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
fn create_ggml_engine(
device: &Device,
model_path: &str,
parallelism: u8,
enable_prompt_lookup: bool,
) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu())
.parallelism(parallelism)
.enable_prompt_lookup(enable_prompt_lookup)
.build()
.unwrap();

Expand Down