diff --git a/Cargo.lock b/Cargo.lock
index d0e897693382..269c161e07c8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1151,6 +1151,12 @@ version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
 
+[[package]]
+name = "hex"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
+
 [[package]]
 name = "htmlescape"
 version = "0.3.1"
@@ -2646,6 +2652,19 @@ dependencies = [
  "digest",
 ]
 
+[[package]]
+name = "sha256"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7895c8ae88588ccead14ff438b939b0c569cd619116f14b4d13fdff7b8333386"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "hex",
+ "sha2",
+ "tokio",
+]
+
 [[package]]
 name = "sharded-slab"
 version = "0.1.4"
@@ -2855,6 +2874,7 @@ dependencies = [
 name = "tabby-common"
 version = "0.5.0-dev"
 dependencies = [
+ "anyhow",
  "chrono",
  "filenamify",
  "lazy_static",
@@ -2880,6 +2900,7 @@ dependencies = [
  "serde",
  "serde_json",
  "serdeconv",
+ "sha256",
  "tabby-common",
  "tokio-retry",
  "tracing",
diff --git a/MODEL_SPEC.md b/MODEL_SPEC.md
index d214d61e2365..e1f5daf3282f 100644
--- a/MODEL_SPEC.md
+++ b/MODEL_SPEC.md
@@ -1,5 +1,7 @@
 # Tabby Model Specification (Unstable)
 
+> [!WARNING] **Since v0.5.0**: this document applies exclusively to local models. For remote models, Tabby relies on the `registry-tabby` repository under each organization or user account; see https://github.com/TabbyML/registry-tabby/blob/main/models.json for an example.
+
 Tabby organizes the model within a directory. This document provides an explanation of the necessary contents for supporting model serving. An example model directory can be found at https://huggingface.co/TabbyML/StarCoder-1B
 
 The minimal Tabby model directory should include the following contents:
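The `models.json` fetched for remote models is a JSON array whose entries mirror the `ModelInfo` struct introduced in `crates/tabby-common/src/registry.rs` below. A minimal parsing sketch, assuming the `serde`, `serde_json`, and `anyhow` crates; the model name, template, URL, and checksum values are illustrative placeholders:

```rust
use serde::Deserialize;

// Mirrors the ModelInfo struct added in crates/tabby-common/src/registry.rs.
#[derive(Deserialize, Debug)]
struct ModelInfo {
    name: String,
    prompt_template: Option<String>,
    chat_template: Option<String>,
    urls: Vec<String>,
    sha256: String,
}

fn main() -> anyhow::Result<()> {
    // Placeholder entry; real registries live at
    // https://github.com/<org>/registry-tabby/blob/main/models.json.
    let raw = r#"[{
        "name": "StarCoder-1B",
        "prompt_template": "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>",
        "urls": ["https://huggingface.co/TabbyML/StarCoder-1B/resolve/main/ggml/q8_0.v2.gguf"],
        "sha256": "<hex digest of the gguf file>"
    }]"#;
    let models: Vec<ModelInfo> = serde_json::from_str(raw)?;
    // A missing optional field (chat_template here) deserializes to None.
    assert!(models[0].chat_template.is_none());
    Ok(())
}
```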
diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml
index 936ce39b9124..02ac08b50c23 100644
--- a/crates/tabby-common/Cargo.toml
+++ b/crates/tabby-common/Cargo.toml
@@ -14,6 +14,7 @@ reqwest = { workspace = true, features = [ "json" ] }
 tokio = { workspace = true, features = ["rt", "macros"] }
 uuid = { version = "1.4.1", features = ["v4"] }
 tantivy.workspace = true
+anyhow.workspace = true
 
 [features]
 testutils = []
diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs
index 3bbbd5b8b658..2458fefd86ff 100644
--- a/crates/tabby-common/src/lib.rs
+++ b/crates/tabby-common/src/lib.rs
@@ -3,6 +3,7 @@ pub mod events;
 pub mod index;
 pub mod languages;
 pub mod path;
+pub mod registry;
 pub mod usage;
 
 use std::{
diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs
index 4a143921951a..125eaa34e41f 100644
--- a/crates/tabby-common/src/path.rs
+++ b/crates/tabby-common/src/path.rs
@@ -51,38 +51,4 @@ pub fn events_dir() -> PathBuf {
     tabby_root().join("events")
 }
 
-pub struct ModelDir(PathBuf);
-
-impl ModelDir {
-    pub fn new(model: &str) -> Self {
-        Self(models_dir().join(model))
-    }
-
-    pub fn from(path: &str) -> Self {
-        Self(PathBuf::from(path))
-    }
-
-    pub fn path(&self) -> &PathBuf {
-        &self.0
-    }
-
-    pub fn path_string(&self, name: &str) -> String {
-        self.0.join(name).display().to_string()
-    }
-
-    pub fn cache_info_file(&self) -> String {
-        self.path_string(".cache_info.json")
-    }
-
-    pub fn metadata_file(&self) -> String {
-        self.path_string("tabby.json")
-    }
-
-    pub fn ggml_q8_0_file(&self) -> String {
-        self.path_string("ggml/q8_0.gguf")
-    }
-
-    pub fn ggml_q8_0_v2_file(&self) -> String {
-        self.path_string("ggml/q8_0.v2.gguf")
-    }
-}
+mod registry {}
diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs
new file mode 100644
index 000000000000..7db5c913fc9e
--- /dev/null
+++ b/crates/tabby-common/src/registry.rs
@@ -0,0 +1,85 @@
+use std::{fs, path::PathBuf};
+
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+
+use crate::path::models_dir;
+
+#[derive(Serialize, Deserialize)]
+pub struct ModelInfo {
+    pub name: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_template: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub chat_template: Option<String>,
+    pub urls: Vec<String>,
+    pub sha256: String,
+}
+
+fn models_json_file(registry: &str) -> PathBuf {
+    models_dir().join(registry).join("models.json")
+}
+
+async fn load_remote_registry(registry: &str) -> Result<Vec<ModelInfo>> {
+    let value = reqwest::get(format!(
+        "https://raw.githubusercontent.com/{}/registry-tabby/main/models.json",
+        registry
+    ))
+    .await?
+    .json()
+    .await?;
+    fs::create_dir_all(models_dir().join(registry))?;
+    serdeconv::to_json_file(&value, models_json_file(registry))?;
+    Ok(value)
+}
+
+fn load_local_registry(registry: &str) -> Result<Vec<ModelInfo>> {
+    Ok(serdeconv::from_json_file(models_json_file(registry))?)
+}
+
+#[derive(Default)]
+pub struct ModelRegistry {
+    pub name: String,
+    pub models: Vec<ModelInfo>,
+}
+
+impl ModelRegistry {
+    pub async fn new(registry: &str) -> Self {
+        Self {
+            name: registry.to_owned(),
+            models: load_remote_registry(registry).await.unwrap_or_else(|err| {
+                load_local_registry(registry).unwrap_or_else(|_| {
+                    panic!(
+                        "Failed to fetch model organization <{}>: {:?}",
+                        registry, err
+                    )
+                })
+            }),
+        }
+    }
+
+    pub fn get_model_path(&self, name: &str) -> PathBuf {
+        models_dir()
+            .join(&self.name)
+            .join(name)
+            .join(GGML_MODEL_RELATIVE_PATH)
+    }
+
+    pub fn get_model_info(&self, name: &str) -> &ModelInfo {
+        self.models
+            .iter()
+            .find(|x| x.name == name)
+            .unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name))
+    }
+}
+
+pub fn parse_model_id(model_id: &str) -> (&str, &str) {
+    let parts: Vec<_> = model_id.split('/').collect();
+    if parts.len() != 2 {
+        panic!("Invalid model id {}", model_id);
+    }
+
+    (parts[0], parts[1])
+}
+
+pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";
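A minimal consumer sketch for the new registry API, assuming a tokio runtime with network access; `TabbyML/StarCoder-1B` is only an example id:

```rust
use tabby_common::registry::{parse_model_id, ModelRegistry};

#[tokio::main]
async fn main() {
    // Splits "<org>/<name>"; any other shape panics inside parse_model_id.
    let (org, name) = parse_model_id("TabbyML/StarCoder-1B");

    // Fetches https://raw.githubusercontent.com/<org>/registry-tabby/main/models.json,
    // caches it under models_dir()/<org>/models.json, and falls back to that
    // cached copy when the fetch fails; panics if neither source is available.
    let registry = ModelRegistry::new(org).await;

    let info = registry.get_model_info(name);
    println!("prompt template: {:?}", info.prompt_template);
    println!("gguf path: {}", registry.get_model_path(name).display());
}
```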
diff --git a/crates/tabby-download/Cargo.toml b/crates/tabby-download/Cargo.toml
index 5fe8c9d5f8d8..aaead140c25e 100644
--- a/crates/tabby-download/Cargo.toml
+++ b/crates/tabby-download/Cargo.toml
@@ -17,3 +17,4 @@ urlencoding = "2.1.3"
 serde_json = { workspace = true }
 cached = { version = "0.46.0", features = ["async", "proc_macro"] }
 async-trait = { workspace = true }
+sha256 = "1.4.0"
diff --git a/crates/tabby-download/src/cache_info.rs b/crates/tabby-download/src/cache_info.rs
deleted file mode 100644
index c82ffef6fa57..000000000000
--- a/crates/tabby-download/src/cache_info.rs
+++ /dev/null
@@ -1,46 +0,0 @@
-use std::{collections::HashMap, fs, path::Path};
-
-use anyhow::Result;
-use serde::{Deserialize, Serialize};
-use tabby_common::path::ModelDir;
-
-#[derive(Serialize, Deserialize)]
-pub struct CacheInfo {
-    etags: HashMap<String, String>,
-}
-
-impl CacheInfo {
-    pub async fn from(model_id: &str) -> CacheInfo {
-        if let Some(cache_info) = Self::from_local(model_id) {
-            cache_info
-        } else {
-            CacheInfo {
-                etags: HashMap::new(),
-            }
-        }
-    }
-
-    fn from_local(model_id: &str) -> Option<CacheInfo> {
-        let cache_info_file = ModelDir::new(model_id).cache_info_file();
-        if fs::metadata(&cache_info_file).is_ok() {
-            serdeconv::from_json_file(cache_info_file).ok()
-        } else {
-            None
-        }
-    }
-
-    pub fn local_cache_key(&self, path: &str) -> Option<&str> {
-        self.etags.get(path).map(|x| x.as_str())
-    }
-
-    pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
-        self.etags.insert(path.to_string(), cache_key.to_string());
-    }
-
-    pub fn save(&self, model_id: &str) -> Result<()> {
-        let cache_info_file = ModelDir::new(model_id).cache_info_file();
-        let cache_info_file_path = Path::new(&cache_info_file);
-        serdeconv::to_json_file(self, cache_info_file_path)?;
-        Ok(())
-    }
-}
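With the etag cache deleted, file integrity is established by content hash instead. A sketch of the check as the `sha256` crate exposes it; the `is_intact` helper name is ours, not part of the crate:

```rust
use std::path::Path;

// Returns true when the on-disk file matches the checksum published in
// models.json. sha256::try_digest hashes the file contents and returns a
// lowercase hex digest string.
fn is_intact(model_path: &Path, expected_sha256: &str) -> anyhow::Result<bool> {
    let checksum = sha256::try_digest(model_path)?;
    Ok(checksum == expected_sha256)
}
```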
diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs
index 99ebc3b8d6c0..9f7b92cfd96c 100644
--- a/crates/tabby-download/src/lib.rs
+++ b/crates/tabby-download/src/lib.rs
@@ -1,126 +1,65 @@
-mod cache_info;
-mod registry;
-
 use std::{cmp, fs, io::Write, path::Path};
 
 use anyhow::{anyhow, Result};
-use cache_info::CacheInfo;
 use futures_util::StreamExt;
 use indicatif::{ProgressBar, ProgressStyle};
-use registry::{create_registry, Registry};
-use tabby_common::path::ModelDir;
+use tabby_common::registry::{parse_model_id, ModelRegistry};
 use tokio_retry::{
     strategy::{jitter, ExponentialBackoff},
     Retry,
 };
+use tracing::{info, warn};
 
-pub struct Downloader {
-    model_id: String,
+async fn download_model_impl(
+    registry: &ModelRegistry,
+    name: &str,
     prefer_local_file: bool,
-    registry: Box<dyn Registry>,
-}
-
-impl Downloader {
-    pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
-        Self {
-            model_id: model_id.to_owned(),
-            prefer_local_file,
-            registry: create_registry(),
-        }
-    }
-
-    pub async fn download_ggml_files(&self) -> Result<()> {
-        let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)];
-        self.download_files(&files).await
-    }
-
-    async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
-        // Local path, no need for downloading.
-        if fs::metadata(&self.model_id).is_ok() {
+) -> Result<()> {
+    let model_info = registry.get_model_info(name);
+    let model_path = registry.get_model_path(name);
+    if model_path.exists() {
+        if !prefer_local_file {
+            info!("Checking model integrity...");
+            let checksum = sha256::try_digest(&model_path)?;
+            if checksum == model_info.sha256 {
+                return Ok(());
+            }
+
+            warn!(
+                "Checksum doesn't match for <{}/{}>, re-downloading...",
+                registry.name, name
+            );
+            fs::remove_file(&model_path)?;
+        } else {
             return Ok(());
         }
-
-        let mut cache_info = CacheInfo::from(&self.model_id).await;
-        for (path, required) in files {
-            download_model_file(
-                self.registry.as_ref(),
-                &mut cache_info,
-                &self.model_id,
-                path,
-                self.prefer_local_file,
-                *required,
-            )
-            .await?;
-        }
-        Ok(())
     }
-}
-
-async fn download_model_file(
-    registry: &dyn Registry,
-    cache_info: &mut CacheInfo,
-    model_id: &str,
-    path: &str,
-    prefer_local_file: bool,
-    required: bool,
-) -> Result<()> {
-    // Create url.
-    let url = registry.build_url(model_id, path);
-
-    // Create destination path.
-    let filepath = ModelDir::new(model_id).path_string(path);
 
-    // Get cache key.
-    let local_cache_key = cache_info.local_cache_key(path);
-
-    // Check local file ready.
-    let local_cache_key = local_cache_key
-        // local cache key is only valid if == 404 or local file exists.
-        // FIXME(meng): use sha256 to validate file is ready.
-        .filter(|&local_cache_key| local_cache_key == "404" || fs::metadata(&filepath).is_ok());
+    let download_host = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
+    let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&download_host)) else {
+        return Err(anyhow!(
+            "Invalid mirror <{}> for model urls: {:?}",
+            download_host,
+            model_info.urls
+        ));
+    };
 
     let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
-    let download_job = Retry::spawn(strategy, || {
-        download_file(registry, &url, &filepath, local_cache_key, !required)
-    });
-    if let Ok(etag) = download_job.await {
-        cache_info.set_local_cache_key(path, &etag).await;
-    } else if prefer_local_file && local_cache_key.is_some() {
-        // Do nothing.
-    } else {
-        return Err(anyhow!("Failed to fetch url {}", url));
-    }
-
-    cache_info.save(model_id)?;
+    let download_job = Retry::spawn(strategy, || download_file(model_url, model_path.as_path()));
+    download_job.await?;
     Ok(())
 }
 
-async fn download_file(
-    registry: &dyn Registry,
-    url: &str,
-    path: &str,
-    local_cache_key: Option<&str>,
-    is_optional: bool,
-) -> Result<String> {
-    fs::create_dir_all(Path::new(path).parent().unwrap())?;
+async fn download_file(url: &str, path: &Path) -> Result<()> {
+    fs::create_dir_all(path.parent().unwrap())?;
 
     // Reqwest setup
     let res = reqwest::get(url).await?;
 
-    if is_optional && res.status() == 404 {
-        // Cache 404 for optional file.
-        return Ok("404".to_owned());
-    }
-
     if !res.status().is_success() {
         return Err(anyhow!(format!("Invalid url: {}", url)));
     }
 
-    let remote_cache_key = registry.build_cache_key(url).await?;
-    if local_cache_key == Some(remote_cache_key.as_str()) {
-        return Ok(remote_cache_key);
-    }
-
     let total_size = res
         .content_length()
         .ok_or(anyhow!("No content length in headers"))?;
@@ -130,7 +69,7 @@ async fn download_file(
     pb.set_style(ProgressStyle::default_bar()
         .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
         .progress_chars("#>-"));
-    pb.set_message(format!("Downloading {}", path));
+    pb.set_message(format!("Downloading {}", path.display()));
 
     // download chunks
     let mut file = fs::File::create(path)?;
@@ -145,6 +84,17 @@ async fn download_file(
         pb.set_position(new);
     }
 
-    pb.finish_with_message(format!("Downloaded {}", path));
-    Ok(remote_cache_key)
+    pb.finish_with_message(format!("Downloaded {}", path.display()));
+    Ok(())
+}
+
+pub async fn download_model(model_id: &str, prefer_local_file: bool) {
+    let (registry, name) = parse_model_id(model_id);
+
+    let registry = ModelRegistry::new(registry).await;
+
+    let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
+    download_model_impl(&registry, name, prefer_local_file)
+        .await
+        .unwrap_or_else(handler)
+}
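Mirror selection in `download_model_impl` is a plain substring match against the download host. A standalone sketch of the same rule; `pick_mirror` is a hypothetical helper and the second URL is an illustrative placeholder, not a real mirror:

```rust
// First registry URL containing the host wins; TABBY_DOWNLOAD_HOST overrides
// the default "huggingface.co", e.g. to point at a regional mirror.
fn pick_mirror<'a>(urls: &'a [String], host: &str) -> Option<&'a String> {
    urls.iter().find(|url| url.contains(host))
}

fn main() {
    let urls = vec![
        "https://huggingface.co/TabbyML/StarCoder-1B/resolve/main/ggml/q8_0.v2.gguf".to_owned(),
        "https://example-mirror.org/TabbyML/StarCoder-1B/ggml/q8_0.v2.gguf".to_owned(),
    ];
    let host = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
    match pick_mirror(&urls, &host) {
        Some(url) => println!("downloading from {}", url),
        None => eprintln!("no registry url matches host <{}>", host),
    }
}
```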
- return Ok("404".to_owned()); - } - if !res.status().is_success() { return Err(anyhow!(format!("Invalid url: {}", url))); } - let remote_cache_key = registry.build_cache_key(url).await?; - if local_cache_key == Some(remote_cache_key.as_str()) { - return Ok(remote_cache_key); - } - let total_size = res .content_length() .ok_or(anyhow!("No content length in headers"))?; @@ -130,7 +69,7 @@ async fn download_file( pb.set_style(ProgressStyle::default_bar() .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? .progress_chars("#>-")); - pb.set_message(format!("Downloading {}", path)); + pb.set_message(format!("Downloading {}", path.display())); // download chunks let mut file = fs::File::create(path)?; @@ -145,6 +84,17 @@ async fn download_file( pb.set_position(new); } - pb.finish_with_message(format!("Downloaded {}", path)); - Ok(remote_cache_key) + pb.finish_with_message(format!("Downloaded {}", path.display())); + Ok(()) +} + +pub async fn download_model(model_id: &str, prefer_local_file: bool) { + let (registry, name) = parse_model_id(model_id); + + let registry = ModelRegistry::new(registry).await; + + let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err); + download_model_impl(®istry, name, prefer_local_file) + .await + .unwrap_or_else(handler) } diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index bfea8b8ed94e..dcc36c1d116e 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -1,9 +1,7 @@ use clap::Args; -use tabby_download::Downloader; +use tabby_download::download_model; use tracing::info; -use crate::fatal; - #[derive(Args)] pub struct DownloadArgs { /// model id to fetch. @@ -16,12 +14,6 @@ pub struct DownloadArgs { } pub async fn main(args: &DownloadArgs) { - let downloader = Downloader::new(&args.model, args.prefer_local_file); - - downloader - .download_ggml_files() - .await - .unwrap_or_else(|err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err)); - + download_model(&args.model, args.prefer_local_file).await; info!("model '{}' is ready", args.model); } diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index eff29f0f926b..7ee767960d7f 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -1,28 +1,39 @@ -use std::path::Path; +use std::{fs, path::PathBuf}; use serde::Deserialize; -use tabby_common::path::ModelDir; +use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}; use tabby_inference::TextGeneration; use crate::fatal; -pub fn create_engine( - model: &str, +pub async fn create_engine( + model_id: &str, args: &crate::serve::ServeArgs, ) -> (Box, EngineInfo) { if args.device != super::Device::ExperimentalHttp { - let model_dir = get_model_dir(model); - let metadata = read_metadata(&model_dir); - let engine = create_ggml_engine(&args.device, &model_dir); - ( - engine, - EngineInfo { - prompt_template: metadata.prompt_template, - chat_template: metadata.chat_template, - }, - ) + if fs::metadata(model_id).is_ok() { + let path = PathBuf::from(model_id); + let model_path = path.join(GGML_MODEL_RELATIVE_PATH); + let engine = + create_ggml_engine(&args.device, model_path.display().to_string().as_str()); + let engine_info = EngineInfo::read(path.join("tabby.json")); + (engine, engine_info) + } else { + let (registry, name) = parse_model_id(model_id); + let registry = ModelRegistry::new(registry).await; + let model_path = 
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index eff29f0f926b..7ee767960d7f 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -1,28 +1,39 @@
-use std::path::Path;
+use std::{fs, path::PathBuf};
 
 use serde::Deserialize;
-use tabby_common::path::ModelDir;
+use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
 use tabby_inference::TextGeneration;
 
 use crate::fatal;
 
-pub fn create_engine(
-    model: &str,
+pub async fn create_engine(
+    model_id: &str,
     args: &crate::serve::ServeArgs,
 ) -> (Box<dyn TextGeneration>, EngineInfo) {
     if args.device != super::Device::ExperimentalHttp {
-        let model_dir = get_model_dir(model);
-        let metadata = read_metadata(&model_dir);
-        let engine = create_ggml_engine(&args.device, &model_dir);
-        (
-            engine,
-            EngineInfo {
-                prompt_template: metadata.prompt_template,
-                chat_template: metadata.chat_template,
-            },
-        )
+        if fs::metadata(model_id).is_ok() {
+            let path = PathBuf::from(model_id);
+            let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+            let engine =
+                create_ggml_engine(&args.device, model_path.display().to_string().as_str());
+            let engine_info = EngineInfo::read(path.join("tabby.json"));
+            (engine, engine_info)
+        } else {
+            let (registry, name) = parse_model_id(model_id);
+            let registry = ModelRegistry::new(registry).await;
+            let model_path = registry.get_model_path(name).display().to_string();
+            let model_info = registry.get_model_info(name);
+            let engine = create_ggml_engine(&args.device, &model_path);
+            (
+                engine,
+                EngineInfo {
+                    prompt_template: model_info.prompt_template.clone(),
+                    chat_template: model_info.chat_template.clone(),
+                },
+            )
+        }
     } else {
-        let (engine, prompt_template) = http_api_bindings::create(model);
+        let (engine, prompt_template) = http_api_bindings::create(model_id);
         (
             engine,
             EngineInfo {
@@ -33,38 +44,25 @@ pub fn create_engine(
     }
 }
 
+#[derive(Deserialize)]
 pub struct EngineInfo {
     pub prompt_template: Option<String>,
     pub chat_template: Option<String>,
 }
 
-fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
+impl EngineInfo {
+    fn read(filepath: PathBuf) -> EngineInfo {
+        serdeconv::from_json_file(&filepath)
+            .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
+    }
+}
+
+fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
-        .model_path(model_dir.ggml_q8_0_v2_file())
+        .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())
         .build()
         .unwrap();
 
     Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
 }
-
-fn get_model_dir(model: &str) -> ModelDir {
-    if Path::new(model).exists() {
-        ModelDir::from(model)
-    } else {
-        ModelDir::new(model)
-    }
-}
-
-#[derive(Deserialize)]
-struct Metadata {
-    #[allow(dead_code)]
-    auto_model: String,
-    prompt_template: Option<String>,
-    chat_template: Option<String>,
-}
-
-fn read_metadata(model_dir: &ModelDir) -> Metadata {
-    serdeconv::from_json_file(model_dir.metadata_file())
-        .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
-}
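For a local model directory, `EngineInfo::read` now needs only the two template fields from `tabby.json` (the removed `Metadata` struct additionally required `auto_model`). A sketch of the smallest metadata file that branch accepts, deserialized here with `serde_json` for brevity while the real code goes through `serdeconv`; the template value is illustrative:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct EngineInfo {
    prompt_template: Option<String>,
    chat_template: Option<String>,
}

fn main() -> anyhow::Result<()> {
    // Minimal tabby.json for a completion-only local model.
    let tabby_json = r#"{
        "prompt_template": "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
    }"#;
    let info: EngineInfo = serde_json::from_str(tabby_json)?;
    // Absent optional fields deserialize to None.
    assert!(info.chat_template.is_none());
    println!("{:?}", info);
    Ok(())
}
```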
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index c19270a970f3..c7db4e3ebb6f 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -7,6 +7,7 @@ mod search;
 mod ui;
 
 use std::{
+    fs,
     net::{Ipv4Addr, SocketAddr},
     sync::Arc,
     time::Duration,
@@ -16,7 +17,7 @@ use axum::{routing, Router, Server};
 use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
 use clap::Args;
 use tabby_common::{config::Config, usage};
-use tabby_download::Downloader;
+use tabby_download::download_model;
 use tokio::time::sleep;
 use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
 use tracing::{info, warn};
@@ -129,9 +130,13 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
     if args.device != Device::ExperimentalHttp {
-        download_model(&args.model).await;
-        if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model).await;
+        if fs::metadata(&args.model).is_ok() {
+            info!("Loading model from local path {}", &args.model);
+        } else {
+            download_model(&args.model, true).await;
+            if let Some(chat_model) = &args.chat_model {
+                download_model(chat_model, true).await;
+            }
         }
     } else {
         warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -144,7 +149,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
 
     let app = Router::new()
         .route("/", routing::get(ui::handler))
-        .merge(api_router(args, config))
+        .merge(api_router(args, config).await)
         .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
         .fallback(ui::handler);
 
@@ -165,7 +170,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
         .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
 }
 
-fn api_router(args: &ServeArgs, config: &Config) -> Router {
+async fn api_router(args: &ServeArgs, config: &Config) -> Router {
     let index_server = Arc::new(IndexServer::new());
     let completion_state = {
         let (
@@ -173,7 +178,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
             EngineInfo {
                 prompt_template, ..
             },
-        ) = create_engine(&args.model, args);
+        ) = create_engine(&args.model, args).await;
         let engine = Arc::new(engine);
         let state = completions::CompletionState::new(
             engine.clone(),
@@ -184,7 +189,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
     };
 
     let chat_state = if let Some(chat_model) = &args.chat_model {
-        let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
+        let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
         let Some(chat_template) = chat_template else {
             panic!("Chat model requires specifying prompt template");
         };
@@ -262,13 +267,6 @@ fn start_heartbeat(args: &ServeArgs) {
     });
 }
 
-async fn download_model(model: &str) {
-    let downloader = Downloader::new(model, /* prefer_local_file= */ true);
-    let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
-    let download_result = downloader.download_ggml_files().await;
-    download_result.unwrap_or_else(handler);
-}
-
 trait OpenApiOverride {
     fn override_doc(&mut self, args: &ServeArgs);
 }
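Model resolution during `serve` can be summarized with a small sketch; the layout comments assume the default `~/.tabby` data root, and both example arguments are illustrative:

```rust
use std::fs;

// How serve interprets --model under the registry scheme.
fn resolution(model: &str) -> &'static str {
    if fs::metadata(model).is_ok() {
        // Local directory: expected to contain ggml/q8_0.v2.gguf and a
        // tabby.json; no download or checksum verification is attempted.
        "local directory"
    } else {
        // Registry id "<org>/<name>": fetched into
        // <models_dir>/<org>/<name>/ggml/q8_0.v2.gguf. serve passes
        // prefer_local_file = true, so an already-downloaded file is reused
        // without re-checking its sha256.
        "registry id"
    }
}

fn main() {
    println!("{}", resolution("TabbyML/StarCoder-1B"));
    println!("{}", resolution("/data/models/my-local-model"));
}
```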