diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index ae75c50298fa..fa8743cf1532 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -15,5 +15,5 @@ class TextInferenceEngine {
   virtual rust::Vec<StepOutput> step() = 0;
 };
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
+std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism);
 }  // namespace
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index fabec66cbeda..3c261aeb340b 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -14,17 +14,6 @@ namespace llama {
 TextInferenceEngine::~TextInferenceEngine() {}
 
 namespace {
-int get_parallelism() {
-  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
-  if (parallelism) {
-    return std::stoi(parallelism);
-  } else {
-    return 4;
-  }
-}
-
-static size_t N_CONCURRENT_REQUESTS = get_parallelism();
-
 constexpr size_t N_BATCH = 512;  // # per batch inference.
 constexpr size_t N_CTX = 4096;   // # max kv history.
 
@@ -95,10 +84,11 @@ using owned = std::unique_ptr<T, std::function<void(T *)>>;
 
 class TextInferenceEngineImpl : public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
+  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
     model_(std::move(model)),
-    ctx_(std::move(ctx)) {
-    batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
+    ctx_(std::move(ctx)),
+    parallelism_(parallelism) {
+    batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
     // warm up
     {
       batch_.n_tokens = 16;
@@ -155,7 +145,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     }
 
     // Add pending requests.
-    while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
+    while (pending_requests_.size() > 0 && requests_.size() < parallelism_) {
       Request request = std::move(pending_requests_.front());
       pending_requests_.pop_front();
 
@@ -283,6 +273,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   std::vector<Request> requests_;
   std::deque<Request> pending_requests_;
   std::unordered_set<uint32_t> stopped_requests_;
+
+  uint32_t parallelism_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -310,7 +302,7 @@ struct BackendInitializer {
 
 }  // namespace
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
+std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
   static BackendInitializer initializer;
 
   llama_model_params model_params = llama_model_default_params();
@@ -322,13 +314,14 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
   }
 
   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
+  ctx_params.n_ctx = N_CTX * parallelism;
   ctx_params.n_batch = N_BATCH;
 
   llama_context* ctx = llama_new_context_with_model(model, ctx_params);
 
   return std::make_unique<TextInferenceEngineImpl>(
       owned<llama_model>(model, llama_free_model),
-      owned<llama_context>(ctx, llama_free)
+      owned<llama_context>(ctx, llama_free),
+      parallelism
   );
 }
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index d21805c58b79..cd2938cf58c9 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -23,7 +23,11 @@ mod ffi {
 
         type TextInferenceEngine;
 
-        fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
+        fn create_engine(
+            use_gpu: bool,
+            model_path: &str,
+            parallelism: u8,
+        ) -> UniquePtr<TextInferenceEngine>;
 
         fn add_request(
             self: Pin<&mut TextInferenceEngine>,
@@ -43,6 +47,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
 pub struct LlamaTextGenerationOptions {
     model_path: String,
     use_gpu: bool,
+    parallelism: u8,
 }
 
 pub struct LlamaTextGeneration {
@@ -52,7 +57,7 @@ pub struct LlamaTextGeneration {
 
 impl LlamaTextGeneration {
     pub fn new(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
+        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
         if engine.is_null() {
             fatal!("Unable to load model: {}", options.model_path);
         }
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index db9f7eee13f3..9dc7a0a300df 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -14,8 +14,11 @@ pub async fn create_engine(
     if fs::metadata(model_id).is_ok() {
         let path = PathBuf::from(model_id);
         let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine =
-            create_ggml_engine(&args.device, model_path.display().to_string().as_str());
+        let engine = create_ggml_engine(
+            &args.device,
+            model_path.display().to_string().as_str(),
+            args.parallelism,
+        );
         let engine_info = EngineInfo::read(path.join("tabby.json"));
         (engine, engine_info)
     } else {
@@ -23,7 +26,7 @@ pub async fn create_engine(
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(&args.device, &model_path);
+        let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
         (
             engine,
             EngineInfo {
@@ -57,10 +60,15 @@ impl EngineInfo {
     }
 }
 
-fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
+fn create_ggml_engine(
+    device: &super::Device,
+    model_path: &str,
+    parallelism: u8,
+) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())
+        .parallelism(parallelism)
         .build()
         .unwrap();
 
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 9d2ff0fb45b5..d3c7942ba64e 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -121,15 +121,13 @@ pub struct ServeArgs {
     #[clap(long, default_value_t=Device::Cpu)]
     device: Device,
 
-    /// DEPRECATED: Do not use.
-    #[deprecated(since = "0.5.0")]
-    #[clap(long, hide(true))]
-    device_indices: Vec,
+    /// Parallelism for model serving - increasing this number will have a significant impact on the
+    /// memory requirement, e.g., GPU vRAM.
+    #[clap(long, default_value_t = 1)]
+    parallelism: u8,
 }
 
 pub async fn main(config: &Config, args: &ServeArgs) {
-    valid_args(args);
-
     if args.device != Device::ExperimentalHttp {
         if fs::metadata(&args.model).is_ok() {
             info!("Loading model from local path {}", &args.model);
@@ -252,12 +250,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
         .layer(opentelemetry_tracing_layer())
 }
 
-fn valid_args(args: &ServeArgs) {
-    if !args.device_indices.is_empty() {
-        warn!("--device-indices is deprecated and will be removed in future release.");
-    }
-}
-
 fn start_heartbeat(args: &ServeArgs) {
     let state = HealthState::new(args);
     tokio::spawn(async move {
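For reference, here is a minimal sketch (not part of the diff) of how a caller can construct the engine with the new option, using the `LlamaTextGenerationOptionsBuilder` and `LlamaTextGeneration::new` APIs shown above. The model path and the values passed below are placeholders, and `build_engine` is a hypothetical wrapper. Since the llama.cpp context is now allocated as `N_CTX * parallelism` tokens, memory use (including GPU vRAM) grows roughly linearly with the value of the new `--parallelism` flag (default `1`), which takes the place of the removed `--device-indices` option.

```rust
use llama_cpp_bindings::{LlamaTextGeneration, LlamaTextGenerationOptionsBuilder};

fn build_engine() -> LlamaTextGeneration {
    // Placeholder path; tabby normally resolves this through the model registry
    // (see create_ggml_engine in crates/tabby/src/serve/engine.rs).
    let model_path = "/path/to/ggml/model.gguf".to_owned();

    let options = LlamaTextGenerationOptionsBuilder::default()
        .model_path(model_path)
        .use_gpu(false)
        // New field introduced by this change: the number of requests the engine
        // decodes concurrently (stored as `parallelism_` on the C++ side).
        .parallelism(2)
        .build()
        .unwrap();

    LlamaTextGeneration::new(options)
}
```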