Commit
update local model support
wsxiaoys committed Nov 2, 2023
1 parent d6ab5da commit bf07bed
Showing 4 changed files with 29 additions and 37 deletions.
crates/tabby-common/src/registry.rs (3 additions, 1 deletion)
@@ -62,7 +62,7 @@ impl ModelRegistry {
         models_dir()
             .join(&self.name)
             .join(name)
-            .join("ggml/q8_0.v2.gguf")
+            .join(GGML_MODEL_RELATIVE_PATH)
     }
 
     pub fn get_model_info(&self, name: &str) -> &ModelInfo {
@@ -81,3 +81,5 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) {
 
     (parts[0], parts[1])
 }
+
+pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";
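Hoisting the previously hard-coded relative path into `GGML_MODEL_RELATIVE_PATH` lets the registry and the new local-path branch in engine.rs (below) resolve weights identically. A minimal sketch of how these pieces compose; the models directory `/data/models` and the model id `TabbyML/StarCoder-1B` are assumptions for illustration only:

```rust
use std::path::{Path, PathBuf};

// Mirrors the constant introduced in registry.rs.
static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";

// Sketch of parse_model_id from the hunk above:
// "TabbyML/StarCoder-1B" -> ("TabbyML", "StarCoder-1B").
fn parse_model_id(model_id: &str) -> (&str, &str) {
    let parts: Vec<&str> = model_id.split('/').collect();
    (parts[0], parts[1])
}

// Sketch of ModelRegistry::get_model_path: registry name, then model
// name, then the shared relative path to the GGUF weights.
fn get_model_path(models_dir: &Path, model_id: &str) -> PathBuf {
    let (registry, name) = parse_model_id(model_id);
    models_dir
        .join(registry)
        .join(name)
        .join(GGML_MODEL_RELATIVE_PATH)
}

fn main() {
    let path = get_model_path(Path::new("/data/models"), "TabbyML/StarCoder-1B");
    // Prints: /data/models/TabbyML/StarCoder-1B/ggml/q8_0.v2.gguf
    println!("{}", path.display());
}
```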
crates/tabby-download/src/lib.rs (0 additions, 5 deletions)
@@ -89,11 +89,6 @@ async fn download_file(url: &str, path: &Path) -> Result<()> {
 }
 
 pub async fn download_model(model_id: &str, prefer_local_file: bool) {
-    // Local file path.
-    if fs::metadata(model_id).is_ok() {
-        return;
-    }
-
     let (registry, name) = parse_model_id(model_id);
 
     let registry = ModelRegistry::new(registry).await;
crates/tabby/src/serve/engine.rs (19 additions, 10 deletions)
@@ -1,22 +1,23 @@
-use std::fs;
+use std::{fs, path::PathBuf};
 
-use tabby_common::registry::{parse_model_id, ModelRegistry};
+use serde::Deserialize;
+use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
 use tabby_inference::TextGeneration;
 
 use crate::fatal;
 
 pub async fn create_engine(
     model_id: &str,
     args: &crate::serve::ServeArgs,
 ) -> (Box<dyn TextGeneration>, EngineInfo) {
     if args.device != super::Device::ExperimentalHttp {
         if fs::metadata(model_id).is_ok() {
-            let engine = create_ggml_engine(&args.device, model_id);
-            (
-                engine,
-                EngineInfo {
-                    prompt_template: args.prompt_template.clone(),
-                    chat_template: args.chat_template.clone(),
-                },
-            )
+            let path = PathBuf::from(model_id);
+            let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+            let engine =
+                create_ggml_engine(&args.device, model_path.display().to_string().as_str());
+            let engine_info = EngineInfo::read(path.join("tabby.json"));
+            (engine, engine_info)
         } else {
+            let (registry, name) = parse_model_id(model_id);
+            let registry = ModelRegistry::new(registry).await;
@@ -43,11 +44,19 @@ pub async fn create_engine(
     }
 }
 
+#[derive(Deserialize)]
 pub struct EngineInfo {
     pub prompt_template: Option<String>,
     pub chat_template: Option<String>,
 }
 
+impl EngineInfo {
+    fn read(filepath: PathBuf) -> EngineInfo {
+        serdeconv::from_json_file(&filepath)
+            .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
+    }
+}
+
 fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
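Because `EngineInfo` now derives `Deserialize` and is loaded via `EngineInfo::read` from a `tabby.json` next to the weights, template metadata travels with the model directory instead of being supplied through CLI flags. A plausible `tabby.json` under these assumptions: the two field names come from the struct above (both are `Option<String>`, so either may be omitted), while the template strings themselves are invented for illustration:

```json
{
  "prompt_template": "<PRE>{prefix}<SUF>{suffix}<MID>",
  "chat_template": "<|user|>\n{user}\n<|assistant|>\n"
}
```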
crates/tabby/src/serve/mod.rs (7 additions, 21 deletions)
@@ -109,18 +109,10 @@ pub struct ServeArgs {
     #[clap(long)]
     model: String,
 
-    /// Prompt template to be used when `--model` is a local file.
-    #[clap(long)]
-    prompt_template: Option<String>,
-
     /// Model id for `/chat/completions` API endpoints.
     #[clap(long)]
     chat_model: Option<String>,
 
-    /// Chat prompt template to be used when `--chat-model` is a local file.
-    #[clap(long)]
-    chat_template: Option<String>,
-
     #[clap(long, default_value_t = 8080)]
     port: u16,
 
@@ -138,9 +130,13 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
     if args.device != Device::ExperimentalHttp {
-        download_model(&args.model, true).await;
-        if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model, true).await;
+        if fs::metadata(&args.model).is_ok() {
+            info!("Loading model from local path {}", &args.model);
+        } else {
+            download_model(&args.model, true).await;
+            if let Some(chat_model) = &args.chat_model {
+                download_model(chat_model, true).await;
+            }
         }
     } else {
         warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -259,16 +255,6 @@ fn valid_args(args: &ServeArgs) {
     if !args.device_indices.is_empty() {
         warn!("--device-indices is deprecated and will be removed in future release.");
     }
-
-    if fs::metadata(&args.model).is_ok() && args.prompt_template.is_none() {
-        fatal!("When passing a local file to --model, --prompt-template is required to set.")
-    }
-
-    if let Some(chat_model) = &args.chat_model {
-        if fs::metadata(chat_model).is_ok() && args.chat_template.is_none() {
-            fatal!("When passing a local file to --chat-model, --chat-template is required to set.")
-        }
-    }
 }
 
 fn start_heartbeat(args: &ServeArgs) {
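Taken together across the four files: `--model` may now name a local directory, `serve` skips the download when that path exists, and the directory itself supplies both the weights (under the shared `ggml/q8_0.v2.gguf` relative path) and the `tabby.json` metadata that replaces the removed `--prompt-template`/`--chat-template` flags. A sketch of the expected layout and invocation, with the directory name invented for illustration:

```
my-local-model/
  ggml/
    q8_0.v2.gguf
  tabby.json

$ tabby serve --model ./my-local-model
```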
