Skip to content

Commit

Permalink
refactor: add experimental-http feature (#750)
Browse files Browse the repository at this point in the history
* add experimental-http feature, update code

* refactor: add experimental-http feature
  • Loading branch information
darknight authored Nov 11, 2023
1 parent f2ea57b commit 2259237
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 39 deletions.
3 changes: 2 additions & 1 deletion crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[features]
cuda = ["llama-cpp-bindings/cuda"]
experimental-http = ["dep:http-api-bindings"]

[dependencies]
tabby-common = { path = "../tabby-common" }
Expand Down Expand Up @@ -36,7 +37,7 @@ tantivy = { workspace = true }
anyhow = { workspace = true }
sysinfo = "0.29.8"
nvml-wrapper = "0.9.0"
http-api-bindings = { path = "../http-api-bindings" }
http-api-bindings = { path = "../http-api-bindings", optional = true } # included when build with `experimental-http` feature
async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] }
Expand Down
55 changes: 28 additions & 27 deletions crates/tabby/src/serve/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,40 @@ pub async fn create_engine(
model_id: &str,
args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, EngineInfo) {
if args.device != super::Device::ExperimentalHttp {
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine(
&args.device,
model_path.display().to_string().as_str(),
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
(
engine,
EngineInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
)
}
} else {
#[cfg(feature = "experimental-http")]
if args.device == crate::serve::Device::ExperimentalHttp {
let (engine, prompt_template) = http_api_bindings::create(model_id);
(
return (
engine,
EngineInfo {
prompt_template: Some(prompt_template),
chat_template: None,
},
);
}

if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine(
&args.device,
model_path.display().to_string().as_str(),
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
(
engine,
EngineInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
)
}
}
Expand Down
30 changes: 19 additions & 11 deletions crates/tabby/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tabby_common::{
use tabby_download::download_model;
use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::{info, warn};
use tracing::info;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

Expand Down Expand Up @@ -86,6 +86,7 @@ pub enum Device {
#[strum(serialize = "metal")]
Metal,

#[cfg(feature = "experimental-http")]
#[strum(serialize = "experimental_http")]
ExperimentalHttp,
}
Expand Down Expand Up @@ -131,18 +132,14 @@ pub struct ServeArgs {
}

pub async fn main(config: &Config, args: &ServeArgs) {
if args.device != Device::ExperimentalHttp {
if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model);
} else {
download_model(&args.model, true).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, true).await;
}
}
#[cfg(feature = "experimental-http")]
if args.device == Device::ExperimentalHttp {
tracing::warn!("HTTP device is unstable and does not comply with semver expectations.");
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
load_model(args).await;
}
#[cfg(not(feature = "experimental-http"))]
load_model(args).await;

info!("Starting server, this might takes a few minutes...");

Expand Down Expand Up @@ -172,6 +169,17 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}

async fn load_model(args: &ServeArgs) {
if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model);
} else {
download_model(&args.model, true).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, true).await;
}
}
}

async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let code = Arc::new(create_code_search());
let completion_state = {
Expand Down

0 comments on commit 2259237

Please sign in to comment.