diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs
index 81600592a150..7db5c913fc9e 100644
--- a/crates/tabby-common/src/registry.rs
+++ b/crates/tabby-common/src/registry.rs
@@ -62,7 +62,7 @@ impl ModelRegistry {
         models_dir()
             .join(&self.name)
             .join(name)
-            .join("ggml/q8_0.v2.gguf")
+            .join(GGML_MODEL_RELATIVE_PATH)
     }
 
     pub fn get_model_info(&self, name: &str) -> &ModelInfo {
@@ -81,3 +81,5 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) {
 
     (parts[0], parts[1])
 }
+
+pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";
diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs
index 3dbd5b42ce9f..9f7b92cfd96c 100644
--- a/crates/tabby-download/src/lib.rs
+++ b/crates/tabby-download/src/lib.rs
@@ -89,11 +89,6 @@ async fn download_file(url: &str, path: &Path) -> Result<()> {
 }
 
 pub async fn download_model(model_id: &str, prefer_local_file: bool) {
-    // Local file path.
-    if fs::metadata(model_id).is_ok() {
-        return;
-    }
-
     let (registry, name) = parse_model_id(model_id);
     let registry = ModelRegistry::new(registry).await;
 
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index f24eff059bc5..7ee767960d7f 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -1,22 +1,23 @@
-use std::fs;
+use std::{fs, path::PathBuf};
 
-use tabby_common::registry::{parse_model_id, ModelRegistry};
+use serde::Deserialize;
+use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
 use tabby_inference::TextGeneration;
 
+use crate::fatal;
+
 pub async fn create_engine(
     model_id: &str,
     args: &crate::serve::ServeArgs,
 ) -> (Box<dyn TextGeneration>, EngineInfo) {
     if args.device != super::Device::ExperimentalHttp {
         if fs::metadata(model_id).is_ok() {
-            let engine = create_ggml_engine(&args.device, model_id);
-            (
-                engine,
-                EngineInfo {
-                    prompt_template: args.prompt_template.clone(),
-                    chat_template: args.chat_template.clone(),
-                },
-            )
+            let path = PathBuf::from(model_id);
+            let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+            let engine =
+                create_ggml_engine(&args.device, model_path.display().to_string().as_str());
+            let engine_info = EngineInfo::read(path.join("tabby.json"));
+            (engine, engine_info)
         } else {
             let (registry, name) = parse_model_id(model_id);
             let registry = ModelRegistry::new(registry).await;
@@ -43,11 +44,19 @@ pub async fn create_engine(
     }
 }
 
+#[derive(Deserialize)]
 pub struct EngineInfo {
     pub prompt_template: Option<String>,
    pub chat_template: Option<String>,
 }
 
+impl EngineInfo {
+    fn read(filepath: PathBuf) -> EngineInfo {
+        serdeconv::from_json_file(&filepath)
+            .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
+    }
+}
+
 fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index b89980189c2c..c7db4e3ebb6f 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -109,18 +109,10 @@ pub struct ServeArgs {
     #[clap(long)]
     model: String,
 
-    /// Prompt template to be used when `--model` is a local file.
-    #[clap(long)]
-    prompt_template: Option<String>,
-
     /// Model id for `/chat/completions` API endpoints.
     #[clap(long)]
     chat_model: Option<String>,
 
-    /// Chat prompt template to be used when `--chat-model` is a local file.
-    #[clap(long)]
-    chat_template: Option<String>,
-
     #[clap(long, default_value_t = 8080)]
     port: u16,
 
@@ -138,9 +130,13 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
     if args.device != Device::ExperimentalHttp {
-        download_model(&args.model, true).await;
-        if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model, true).await;
+        if fs::metadata(&args.model).is_ok() {
+            info!("Loading model from local path {}", &args.model);
+        } else {
+            download_model(&args.model, true).await;
+            if let Some(chat_model) = &args.chat_model {
+                download_model(chat_model, true).await;
+            }
         }
     } else {
         warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -259,16 +255,6 @@ fn valid_args(args: &ServeArgs) {
     if !args.device_indices.is_empty() {
         warn!("--device-indices is deprecated and will be removed in future release.");
     }
-
-    if fs::metadata(&args.model).is_ok() && args.prompt_template.is_none() {
-        fatal!("When passing a local file to --model, --prompt-template is required to set.")
-    }
-
-    if let Some(chat_model) = &args.chat_model {
-        if fs::metadata(chat_model).is_ok() && args.chat_template.is_none() {
-            fatal!("When passing a local file to --chat-model, --chat-template is required to set.")
-        }
-    }
 }
 
 fn start_heartbeat(args: &ServeArgs) {
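
Note on the resulting layout: with --prompt-template and --chat-template removed from ServeArgs, a local path passed to --model now names a model directory rather than a single GGUF file. The weights are resolved at <dir>/ggml/q8_0.v2.gguf (GGML_MODEL_RELATIVE_PATH) and the templates are deserialized from <dir>/tabby.json by EngineInfo::read. A sketch of such a directory, using a hypothetical path and an illustrative FIM-style template; both fields are Option<String>, so either may be null or omitted:

    /data/models/my-model/
    ├── tabby.json                  <- read into EngineInfo
    └── ggml/
        └── q8_0.v2.gguf            <- GGML_MODEL_RELATIVE_PATH

    tabby.json:
    {
      "prompt_template": "<PRE> {prefix} <SUF>{suffix} <MID>",
      "chat_template": null
    }

The server would then be started with, e.g., `tabby serve --model /data/models/my-model`; a missing or malformed tabby.json aborts startup via the fatal! macro in EngineInfo::read.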