diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 207fd3d6204b..d21805c58b79 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -1,20 +1,14 @@
+mod llama;
 mod utils;
 
-use std::{collections::HashMap, sync::Arc};
-
 use async_stream::stream;
 use async_trait::async_trait;
-use cxx::UniquePtr;
 use derive_builder::Builder;
 use ffi::create_engine;
-use futures::{lock::Mutex, stream::BoxStream};
+use futures::stream::BoxStream;
+use llama::LlamaService;
 use tabby_inference::{
-    decoding::{StopCondition, StopConditionFactory},
-    helpers, TextGeneration, TextGenerationOptions,
-};
-use tokio::{
-    sync::mpsc::{channel, Sender},
-    task::yield_now,
+    decoding::StopConditionFactory, helpers, TextGeneration, TextGenerationOptions,
 };
 
 #[cxx::bridge(namespace = "llama")]
@@ -45,66 +39,36 @@ mod ffi {
 unsafe impl Send for ffi::TextInferenceEngine {}
 unsafe impl Sync for ffi::TextInferenceEngine {}
 
-struct InferenceRequest {
-    tx: Sender<String>,
-    stop_condition: StopCondition,
+#[derive(Builder, Debug)]
+pub struct LlamaTextGenerationOptions {
+    model_path: String,
+    use_gpu: bool,
 }
 
-struct AsyncTextInferenceEngine {
-    engine: Mutex<UniquePtr<ffi::TextInferenceEngine>>,
+pub struct LlamaTextGeneration {
+    service: LlamaService,
     stop_condition_factory: StopConditionFactory,
-    requests: Mutex<HashMap<u32, InferenceRequest>>,
-
-    next_request_id: Mutex<u32>,
 }
 
-impl AsyncTextInferenceEngine {
-    fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+impl LlamaTextGeneration {
+    pub fn new(options: LlamaTextGenerationOptions) -> Self {
+        let engine = create_engine(options.use_gpu, &options.model_path);
+        if engine.is_null() {
+            fatal!("Unable to load model: {}", options.model_path);
+        }
+
         Self {
-            engine: Mutex::new(engine),
+            service: LlamaService::new(engine),
             stop_condition_factory: StopConditionFactory::default(),
-            requests: Mutex::new(HashMap::new()),
-            next_request_id: Mutex::new(0),
         }
     }
+}
 
-    async fn background_job(&self) {
-        let mut requests = self.requests.lock().await;
-        if requests.len() == 0 {
-            return;
-        }
-
-        let mut engine = self.engine.lock().await;
-
-        let result = match engine.as_mut().unwrap().step() {
-            Ok(result) => result,
-            Err(err) => {
-                fatal!("Failed to step: {}", err)
-            }
-        };
-
-        for ffi::StepOutput { request_id, text } in result {
-            let mut stopped = false;
-            let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
-
-            if tx.is_closed() || text.is_empty() {
-                // Cancelled by client side or hit eos.
-                stopped = true;
-            } else if !stop_condition.should_stop(&text) {
-                match tx.send(text).await {
-                    Ok(_) => (),
-                    Err(_) => stopped = true,
-                }
-            } else {
-                // Stoop words stopped
-                stopped = true;
-            }
-
-            if stopped {
-                requests.remove(&request_id);
-                engine.as_mut().unwrap().stop_request(request_id);
-            }
-        }
+#[async_trait]
+impl TextGeneration for LlamaTextGeneration {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
+        let s = self.generate_stream(prompt, options).await;
+        helpers::stream_to_string(s).await
     }
 
     async fn generate_stream(
@@ -114,23 +78,10 @@ impl AsyncTextInferenceEngine {
     ) -> BoxStream<String> {
         let stop_condition = self.stop_condition_factory.create(prompt, options.language);
 
-        let (tx, mut rx) = channel::<String>(4);
-        {
-            let mut engine = self.engine.lock().await;
-
-            let mut request_id = self.next_request_id.lock().await;
-            self.requests
-                .lock()
-                .await
-                .insert(*request_id, InferenceRequest { tx, stop_condition });
-            engine
-                .as_mut()
-                .unwrap()
-                .add_request(*request_id, prompt, options.max_input_length);
-
-            // 2048 should be large enough to avoid collision.
-            *request_id = (*request_id + 1) % 2048;
-        }
+        let mut rx = self
+            .service
+            .add_request(prompt, options.max_input_length, stop_condition)
+            .await;
 
         let s = stream! {
             let mut length = 0;
@@ -148,53 +99,3 @@ impl AsyncTextInferenceEngine {
         Box::pin(s)
     }
 }
-
-#[derive(Builder, Debug)]
-pub struct LlamaTextGenerationOptions {
-    model_path: String,
-    use_gpu: bool,
-}
-
-pub struct LlamaTextGeneration {
-    engine: Arc<AsyncTextInferenceEngine>,
-}
-
-impl LlamaTextGeneration {
-    pub fn create(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
-        if engine.is_null() {
-            fatal!("Unable to load model: {}", options.model_path);
-        }
-        let ret = LlamaTextGeneration {
-            engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
-        };
-        ret.start_background_job();
-        ret
-    }
-
-    pub fn start_background_job(&self) {
-        let engine = self.engine.clone();
-        tokio::spawn(async move {
-            loop {
-                engine.background_job().await;
-                yield_now().await;
-            }
-        });
-    }
-}
-
-#[async_trait]
-impl TextGeneration for LlamaTextGeneration {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        options: TextGenerationOptions,
-    ) -> BoxStream<String> {
-        self.engine.generate_stream(prompt, options).await
-    }
-}
diff --git a/crates/llama-cpp-bindings/src/llama.rs b/crates/llama-cpp-bindings/src/llama.rs
new file mode 100644
index 000000000000..975d4b11da28
--- /dev/null
+++ b/crates/llama-cpp-bindings/src/llama.rs
@@ -0,0 +1,155 @@
+use std::{collections::HashMap, thread::JoinHandle};
+
+use cxx::UniquePtr;
+use tabby_inference::decoding::StopCondition;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+
+use crate::ffi;
+
+struct LlamaInitRequest {
+    prompt: String,
+    max_input_length: usize,
+
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaRunningRequest {
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaServiceImpl {
+    next_request_id: u32,
+    engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+    requests: HashMap<u32, LlamaRunningRequest>,
+}
+
+impl LlamaServiceImpl {
+    fn new(engine: UniquePtr<ffi::TextInferenceEngine>, rx: Receiver<LlamaInitRequest>) -> Self {
+        Self {
+            next_request_id: 0,
+            engine,
+            rx,
+            requests: HashMap::new(),
+        }
+    }
+
+    fn alloc_request_id(&mut self) -> u32 {
+        let ret = self.next_request_id;
+        self.next_request_id += 1;
+        ret
+    }
+
+    async fn next_request(&mut self) -> Option<LlamaInitRequest> {
+        if self.requests.is_empty() {
+            self.rx.recv().await
+        } else {
+            self.rx.try_recv().ok()
+        }
+    }
+
+    async fn background_job(&mut self) {
+        while let Some(LlamaInitRequest {
+            prompt,
+            tx,
+            max_input_length,
+            stop_condition,
+        }) = self.next_request().await
+        {
+            let request_id = self.alloc_request_id();
+            self.requests
+                .insert(request_id, LlamaRunningRequest { tx, stop_condition });
+            self.engine
+                .as_mut()
+                .unwrap()
+                .add_request(request_id, &prompt, max_input_length);
+        }
+
+        let result = match self.engine.as_mut().unwrap().step() {
+            Ok(result) => result,
+            Err(err) => {
+                crate::fatal!("Failed to step: {}", err)
+            }
+        };
+
+        for ffi::StepOutput { request_id, text } in result {
+            let mut stopped = false;
+            let LlamaRunningRequest { tx, stop_condition } =
+                self.requests.get_mut(&request_id).unwrap();
+
+            if tx.is_closed() || text.is_empty() {
+                // Cancelled by client side or hit eos.
+                stopped = true;
+            } else if !stop_condition.should_stop(&text) {
+                match tx.send(text).await {
+                    Ok(_) => (),
+                    Err(_) => stopped = true,
+                }
+            } else {
+                // Stopped due to stop words.
+                stopped = true;
+            }
+
+            if stopped {
+                self.requests.remove(&request_id);
+                self.engine.as_mut().unwrap().stop_request(request_id);
+            }
+        }
+    }
+}
+
+fn start_llama_service_impl(
+    engine: UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+) -> JoinHandle<()> {
+    let mut service = LlamaServiceImpl::new(engine, rx);
+    let rt = tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap();
+
+    std::thread::spawn(move || {
+        let local = tokio::task::LocalSet::new();
+        local.spawn_local(async move {
+            loop {
+                service.background_job().await;
+            }
+        });
+
+        rt.block_on(local);
+    })
+}
+
+pub struct LlamaService {
+    tx: Sender<LlamaInitRequest>,
+}
+
+impl LlamaService {
+    pub fn new(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+        let (tx, rx) = channel(20);
+        start_llama_service_impl(engine, rx);
+        Self { tx }
+    }
+
+    pub async fn add_request(
+        &self,
+        prompt: &str,
+        max_input_length: usize,
+        stop_condition: StopCondition,
+    ) -> Receiver<String> {
+        let (tx, rx) = channel(8);
+        self.tx
+            .send(LlamaInitRequest {
+                prompt: prompt.to_owned(),
+                tx,
+                max_input_length,
+                stop_condition,
+            })
+            .await
+            .expect("Failed to add request");
+
+        rx
+    }
+}
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index 7ee767960d7f..db9f7eee13f3 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -64,5 +64,5 @@ fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
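
Usage sketch (reviewer note, not part of the patch): after this refactor the tokio-spawned background loop is replaced by LlamaService, which owns the engine on a dedicated thread running a current-thread runtime, and the public constructor is renamed from create to new. A caller might wire it up roughly as below. This is a minimal sketch under assumptions: LlamaTextGenerationOptionsBuilder is the builder generated by #[derive(Builder)], use_gpu(false) and the helper function complete are illustrative only, and construction of TextGenerationOptions is taken as given since its fields are not shown in this diff.

    use llama_cpp_bindings::{LlamaTextGeneration, LlamaTextGenerationOptionsBuilder};
    use tabby_inference::{TextGeneration, TextGenerationOptions};

    // Hypothetical helper, not part of the patch.
    async fn complete(model_path: &str, prompt: &str, options: TextGenerationOptions) -> String {
        // Build the engine options; the builder comes from #[derive(Builder)] on
        // LlamaTextGenerationOptions. use_gpu(false) is just an example value.
        let engine = LlamaTextGeneration::new(
            LlamaTextGenerationOptionsBuilder::default()
                .model_path(model_path.to_owned())
                .use_gpu(false)
                .build()
                .unwrap(),
        );

        // generate() collects the tokens streamed by generate_stream() into a String.
        engine.generate(prompt, options).await
    }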