From 7bd99d14c0d346bde2a78c4eacc0ed5d53bb666d Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Sat, 28 Oct 2023 23:37:05 -0700
Subject: [PATCH] feat: support continuous batching in llama.cpp backend (#659)

* refactor: switch back to llama batch interface

* feat: support continuous batching
---
 crates/llama-cpp-bindings/include/engine.h |   8 +-
 crates/llama-cpp-bindings/src/engine.cc    | 178 +++++++++++++++----
 crates/llama-cpp-bindings/src/lib.rs       | 195 ++++++++++++++++-----
 crates/tabby/src/download.rs               |   2 +-
 crates/tabby/src/serve/engine.rs           |   4 +-
 5 files changed, 295 insertions(+), 92 deletions(-)

diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index fffc0a25586f..e6a9c4dcc7b5 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -9,11 +9,11 @@ class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
 
-  virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
-  virtual uint32_t step() = 0;
-  virtual void end() = 0;
+  virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
+  virtual void stop_request(uint32_t request_id) = 0;
+  virtual rust::Vec<uint32_t> step() = 0;
 
-  virtual uint32_t eos_token() const = 0;
+  virtual uint32_t eos_token_id() const = 0;
 };
 
 std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 3b5caaa939b1..9a93f36e62c1 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -2,6 +2,8 @@
 
 #include <functional>
 #include <vector>
+#include <deque>
+#include <unordered_set>
 
 #include <ggml.h>
 #include <llama.h>
@@ -10,8 +12,34 @@ namespace llama {
 TextInferenceEngine::~TextInferenceEngine() {}
 
 namespace {
-static size_t N_BATCH = 512;  // # per batch inference.
-static size_t N_CTX = 4096;   // # max kv history.
+int get_parallelism() {
+  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
+  if (parallelism) {
+    return std::stoi(parallelism);
+  } else {
+    return 4;
+  }
+}
+
+static size_t N_CONCURRENT_REQUESTS = get_parallelism();
+
+constexpr size_t N_BATCH = 512;  // # per batch inference.
+constexpr size_t N_CTX = 4096;   // # max kv history.
+
+struct Request {
+  Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
+    id(request_id),
+    tokens(input_token_ids.begin(), input_token_ids.end()) {
+  }
+
+  size_t id = -1;
+  llama_seq_id seq_id = -1;
+
+  std::vector<llama_token> tokens;
+  size_t i_batch = -1;
+  size_t n_past = 0;
+};
+
 template<class T>
 using owned = std::unique_ptr<T, std::function<void(T*)>>;
@@ -21,61 +49,136 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
     model_(std::move(model)),
     ctx_(std::move(ctx)) {
+    batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
   }
 
-  void start(rust::Slice<const uint32_t> input_token_ids) override {
-    auto* ctx = ctx_.get();
-    llama_reset_timings(ctx);
-    std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
-
-    for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
-      const size_t size = std::min(N_BATCH, tokens_list.size() - i);
-      eval(tokens_list.data() + i, size, /* reset = */ i == 0);
-    }
-  }
-
-  uint32_t step() override {
-    const llama_token id = sample();
-    eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
-    return id;
+  ~TextInferenceEngineImpl() {
+    llama_batch_free(batch_);
   }
 
-  void end() override {
-    llama_print_timings(ctx_.get());
+  void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
+    pending_requests_.push_back(Request(request_id, input_token_ids));
   }
 
-  uint32_t eos_token() const override {
-    return llama_token_eos(llama_get_model(ctx_.get()));
+  void stop_request(uint32_t request_id) override {
+    stopped_requests_.insert(request_id);
   }
 
- private:
-  uint32_t sample() const {
+  rust::Vec<uint32_t> step() override {
     auto* ctx = ctx_.get();
-
-    auto logits = llama_get_logits_ith(ctx, 0);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    // Greedy sampling (always select the highest logit).
-    return std::distance(logits, std::max_element(logits, logits + n_vocab));
-  }
+    // Remove stopped requests.
+    if (!stopped_requests_.empty()) {
+      std::vector<Request> requests;
+      for (auto& request : requests_) {
+        if (stopped_requests_.count(request.id) > 0) {
+          // Release KV cache.
+          llama_kv_cache_seq_rm(ctx_.get(), request.id, -1, -1);
+        } else {
+          requests.emplace_back(request);
+        }
+      }
+
+      requests_ = requests;
+    }
+
+    // Add pending requests.
+    while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
+      Request request = std::move(pending_requests_.front());
+      pending_requests_.pop_front();
 
-  void eval(llama_token* data, size_t size, bool reset) {
-    if (reset) {
-      n_past_ = 0;
+      // Ignore stopped pending requests.
+      if (stopped_requests_.count(request.id) > 0) {
+        continue;
+      }
+
+      requests_.push_back(request);
     }
 
-    auto* ctx = ctx_.get();
-    llama_kv_cache_tokens_rm(ctx, n_past_, -1);
-    if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) {
-      throw std::runtime_error("Failed to eval");
+    // Clear stopped requests.
+    stopped_requests_.clear();
+
+    if (requests_.size() == 0) {
+      return {};
+    }
+
+    // Clear the batch.
+    batch_.n_tokens = 0;
+
+    // Insert tokens from ongoing requests into the batch.
+    for (auto& request : requests_) {
+      const size_t n_tokens = batch_.n_tokens;
+      for (size_t i = 0; i < request.tokens.size(); ++i) {
+        batch_.token[n_tokens + i] = request.tokens[i];
+        batch_.pos[n_tokens + i] = request.n_past + i;
+        batch_.n_seq_id[n_tokens + i] = 1;
+        batch_.seq_id[n_tokens + i][0] = request.id;
+        batch_.logits[n_tokens + i] = false;
+      }
+      batch_.n_tokens += request.tokens.size();
+
+      batch_.logits[batch_.n_tokens - 1] = true;
+      request.i_batch = batch_.n_tokens - 1;
     }
 
-    n_past_ += size;
+    rust::Vec<uint32_t> result;
+    result.reserve(requests_.size() * 2);
+
+    // Decode tokens in chunks.
+    for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
+      const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
+      llama_batch batch_view = {
+        n_tokens,
+        batch_.token + i,
+        nullptr,
+        batch_.pos + i,
+        batch_.n_seq_id + i,
+        batch_.seq_id + i,
+        batch_.logits + i,
+        0, 0, 0, // unused
+      };
+
+      const int ret = llama_decode(ctx, batch_view);
+      if (ret != 0) {
+        throw std::runtime_error("Failed to eval");
+      }
+
+      for (auto& request : requests_) {
+        if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
+          continue;
+        }
+
+        int32_t i_batch = request.i_batch - i;
+        auto logits = llama_get_logits_ith(ctx, i_batch);
+        auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
+
+        request.n_past += request.tokens.size();
+
+        request.tokens.clear();
+        request.tokens.push_back(next_token);
+
+        result.push_back(request.id);
+        result.push_back(next_token);
+      }
+    }
+
+    return result;
+  }
+
+  uint32_t eos_token_id() const override {
+    return llama_token_eos(llama_get_model(ctx_.get()));
   }
 
-  size_t n_past_;
+ private:
   owned<llama_model> model_;
   owned<llama_context> ctx_;
+
+  llama_batch batch_;
+
+  std::vector<Request> requests_;
+  std::deque<Request> pending_requests_;
+  std::unordered_set<uint32_t> stopped_requests_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -100,6 +203,7 @@ struct BackendInitializer {
     llama_backend_free();
   }
 };
+
 }  // namespace
 
 std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 53870fc5abd3..00e5879c6e44 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -1,12 +1,20 @@
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use async_stream::stream;
 use async_trait::async_trait;
+use cxx::UniquePtr;
 use derive_builder::Builder;
 use ffi::create_engine;
 use futures::{lock::Mutex, stream::BoxStream};
-use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
+use tabby_inference::{
+    decoding::{DecodingFactory, IncrementalDecoding},
+    helpers, TextGeneration, TextGenerationOptions,
+};
 use tokenizers::tokenizer::Tokenizer;
+use tokio::{
+    sync::mpsc::{channel, Sender},
+    task::yield_now,
+};
 
 #[cxx::bridge(namespace = "llama")]
 mod ffi {
@@ -17,49 +25,82 @@ mod ffi {
 
         fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
 
-        fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
-        fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
-        fn end(self: Pin<&mut TextInferenceEngine>);
+        fn add_request(
+            self: Pin<&mut TextInferenceEngine>,
+            request_id: u32,
+            input_token_ids: &[u32],
+        );
+        fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
+        fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>;
 
-        fn eos_token(&self) -> u32;
+        fn eos_token_id(&self) -> u32;
     }
 }
 
 unsafe impl Send for ffi::TextInferenceEngine {}
 unsafe impl Sync for ffi::TextInferenceEngine {}
 
-#[derive(Builder, Debug)]
-pub struct LlamaEngineOptions {
-    model_path: String,
-    tokenizer_path: String,
-    use_gpu: bool,
+struct InferenceRequest {
+    tx: Sender<String>,
+    decoding: IncrementalDecoding,
 }
 
-pub struct LlamaEngine {
+struct AsyncTextInferenceEngine {
     engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
     tokenizer: Arc<Tokenizer>,
     decoding_factory: DecodingFactory,
+    requests: Mutex<HashMap<u32, InferenceRequest>>,
+
+    next_request_id: Mutex<u32>,
+    eos_token_id: u32,
 }
 
-impl LlamaEngine {
-    pub fn create(options: LlamaEngineOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
-        if engine.is_null() {
-            panic!("Unable to load model: {}", options.model_path);
-        }
-        LlamaEngine {
+impl AsyncTextInferenceEngine {
+    fn create(engine: UniquePtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer) -> Self {
+        Self {
+            eos_token_id: engine.eos_token_id(),
             engine: Mutex::new(engine),
-            tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
+            tokenizer: Arc::new(tokenizer),
             decoding_factory: DecodingFactory::default(),
+            requests: Mutex::new(HashMap::new()),
+            next_request_id: Mutex::new(0),
         }
     }
-}
 
-#[async_trait]
-impl TextGeneration for LlamaEngine {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
+    async fn background_job(&self) {
+        let mut requests = self.requests.lock().await;
+        if requests.len() == 0 {
+            return;
+        }
+
+        let mut engine = self.engine.lock().await;
+
+        let Ok(result) = engine.as_mut().unwrap().step() else {
+            panic!("Failed to evaluate");
+        };
+
+        for i in (0..result.len()).step_by(2) {
+            let request_id = result[i];
+            let token_id = result[i + 1];
+
+            let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap();
+            let mut stopped = false;
+
+            if tx.is_closed() || token_id == self.eos_token_id {
+                // Cancelled by the client side or hit eos.
+                stopped = true;
+            } else if let Some(new_text) = decoding.next_token(token_id) {
+                tx.send(new_text).await.expect("send failed");
+            } else {
+                // Stopped by a stop word.
+                stopped = true;
+            }
+
+            if stopped {
+                requests.remove(&request_id);
+                engine.as_mut().unwrap().stop_request(request_id);
+            }
+        }
     }
 
     async fn generate_stream(
@@ -68,40 +109,98 @@ impl TextGeneration for LlamaEngine {
         options: TextGenerationOptions,
     ) -> BoxStream<String> {
         let encoding = self.tokenizer.encode(prompt, true).unwrap();
-
-        let s = stream! {
+        let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
+        let decoding = self.decoding_factory.create_incremental_decoding(
+            self.tokenizer.clone(),
+            input_token_ids,
+            options.language,
+        );
+
+        let (tx, mut rx) = channel::<String>(4);
+        {
             let mut engine = self.engine.lock().await;
-            let mut engine = engine.as_mut().unwrap();
-            let eos_token = engine.eos_token();
-
-            let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
-            engine.as_mut().start(input_token_ids);
-            let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
-            let mut n_remains = options.max_decoding_length;
-            while n_remains > 0 {
-                let Ok(next_token_id) = engine.as_mut().step() else {
-                    panic!("Failed to eval");
-                };
-                if next_token_id == eos_token {
-                    break;
-                }
+            let engine = engine.as_mut().unwrap();
+
+            let mut request_id = self.next_request_id.lock().await;
+            self.requests
+                .lock()
+                .await
+                .insert(*request_id, InferenceRequest { tx, decoding });
+            engine.add_request(*request_id, input_token_ids);
+
+            // 2048 should be large enough to avoid collision.
+            *request_id = (*request_id + 1) % 2048;
+        }
 
-                if let Some(new_text) = decoding.next_token(next_token_id) {
-                    yield new_text;
-                } else {
+        let s = stream! {
+            let mut length = 0;
+            while let Some(new_text) = rx.recv().await {
+                yield new_text;
+                length += 1;
+                if length >= options.max_decoding_length {
                     break;
                 }
-
-                n_remains -= 1;
             }
 
-            engine.end();
+            rx.close();
         };
 
         Box::pin(s)
     }
 }
 
+#[derive(Builder, Debug)]
+pub struct LlamaTextGenerationOptions {
+    model_path: String,
+    tokenizer_path: String,
+    use_gpu: bool,
+}
+
+pub struct LlamaTextGeneration {
+    engine: Arc<AsyncTextInferenceEngine>,
+}
+
+impl LlamaTextGeneration {
+    pub fn create(options: LlamaTextGenerationOptions) -> Self {
+        let engine = create_engine(options.use_gpu, &options.model_path);
+        if engine.is_null() {
+            panic!("Unable to load model: {}", options.model_path);
+        }
+        let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap();
+        let ret = LlamaTextGeneration {
+            engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)),
+        };
+        ret.start_background_job();
+        ret
+    }
+
+    pub fn start_background_job(&self) {
+        let engine = self.engine.clone();
+        tokio::spawn(async move {
+            loop {
+                engine.background_job().await;
+                yield_now().await;
+            }
+        });
+    }
+}
+
+#[async_trait]
+impl TextGeneration for LlamaTextGeneration {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
+        let s = self.generate_stream(prompt, options).await;
+        helpers::stream_to_string(s).await
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        options: TextGenerationOptions,
+    ) -> BoxStream<String> {
+        self.engine.generate_stream(prompt, options).await
+    }
+}
+
 fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
     if max_length < tokens.len() {
         let start = tokens.len() - max_length;
diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs
index 6ee52f031a22..bfea8b8ed94e 100644
--- a/crates/tabby/src/download.rs
+++ b/crates/tabby/src/download.rs
@@ -1,6 +1,6 @@
 use clap::Args;
 use tabby_download::Downloader;
-use tracing::{info, log::warn};
+use tracing::info;
 
 use crate::fatal;
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index 4fb9dd1282c3..0b89ea5b09ef 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -39,14 +39,14 @@ pub struct EngineInfo {
 }
 
 fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
-    let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
+    let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_dir.ggml_q8_0_v2_file())
         .tokenizer_path(model_dir.tokenizer_file())
         .use_gpu(device.ggml_use_gpu())
         .build()
         .unwrap();
 
-    Box::new(llama_cpp_bindings::LlamaEngine::create(options))
+    Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
 }
 
 fn get_model_dir(model: &str) -> ModelDir {
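
A note on the step() protocol introduced by this patch: the engine keeps at most N_CONCURRENT_REQUESTS requests in flight (read from the LLAMA_CPP_PARALLELISM environment variable, defaulting to 4), and each call to step() decodes one token for every in-flight request, returning a flat rust::Vec<uint32_t> of alternating (request_id, token_id) pairs that lib.rs walks with step_by(2). The short, self-contained Rust sketch below only illustrates that pairing convention; demux_step_result and the sample values are hypothetical and not part of the crate.

    use std::collections::HashMap;

    // Demultiplex a flat [request_id, token_id, request_id, token_id, ...]
    // vector, as produced by one step(), into per-request token lists.
    fn demux_step_result(flat: &[u32]) -> HashMap<u32, Vec<u32>> {
        let mut per_request: HashMap<u32, Vec<u32>> = HashMap::new();
        for pair in flat.chunks_exact(2) {
            per_request.entry(pair[0]).or_default().push(pair[1]);
        }
        per_request
    }

    fn main() {
        // Two in-flight requests (ids 3 and 7), one freshly decoded token each.
        let flat = vec![3, 42, 7, 9];
        let demuxed = demux_step_result(&flat);
        assert_eq!(demuxed[&3], vec![42]);
        assert_eq!(demuxed[&7], vec![9]);
    }

In lib.rs each such pair is fed to IncrementalDecoding::next_token and forwarded over the request's mpsc channel; when a request's channel is closed, it emits eos_token_id, or a stop word fires, the request is removed from the map and stop_request is called, which lets the next step() release its KV cache via llama_kv_cache_seq_rm.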