feat: add worker command worker::completion and worker::chat (#778)
wsxiaoys authored Nov 13, 2023
1 parent 510eddc commit e521f06
Showing 18 changed files with 489 additions and 144 deletions.
62 changes: 61 additions & 1 deletion Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered.)

3 changes: 3 additions & 0 deletions Makefile
@@ -35,3 +35,6 @@ update-openapi-doc:
 	["components", "schemas", "DebugOptions"] \
 	])' | jq '.servers[0] |= { url: "https://playground.app.tabbyml.com", description: "Playground server" }' \
 	> website/static/openapi.json
+
+update-graphql-schema:
+	cargo run --package tabby-webserver --example update-schema
2 changes: 2 additions & 0 deletions crates/tabby/Cargo.toml
@@ -47,6 +47,8 @@ async-trait.workspace = true
 tabby-webserver = { path = "../../ee/tabby-webserver" }
 thiserror.workspace = true
 chrono = "0.4.31"
+graphql_client = { version = "0.13.0", features = ["reqwest"] }
+reqwest.workspace = true
 
 [dependencies.uuid]
 version = "1.3.3"
23 changes: 23 additions & 0 deletions crates/tabby/graphql/worker.query.graphql
@@ -0,0 +1,23 @@
+mutation RegisterWorker(
+  $port: Int!
+  $kind: WorkerKind!
+  $name: String!
+  $device: String!
+  $arch: String!
+  $cpuInfo: String!
+  $cpuCount: Int!
+  $cudaDevices: [String!]!
+) {
+  worker: registerWorker(
+    port: $port
+    kind: $kind
+    name: $name
+    device: $device
+    arch: $arch
+    cpuInfo: $cpuInfo
+    cpuCount: $cpuCount
+    cudaDevices: $cudaDevices
+  ) {
+    addr
+  }
+}
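
The graphql_client and reqwest dependencies added in Cargo.toml above exist to drive this mutation from the new worker commands. Below is a minimal sketch of the generated-client flow using graphql_client's reqwest helper; the schema path, the generated enum variant name, the endpoint URL, and the anyhow-based error handling are assumptions, not part of this commit:

use graphql_client::{reqwest::post_graphql, GraphQLQuery};

// The derive generates a `register_worker` module containing typed
// `Variables`, `ResponseData`, and the `WorkerKind` enum from the schema.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "graphql/schema.graphql", // assumed location
    query_path = "graphql/worker.query.graphql",
    response_derives = "Debug"
)]
struct RegisterWorker;

async fn register(url: &str) -> anyhow::Result<String> {
    let variables = register_worker::Variables {
        port: 8080,
        kind: register_worker::WorkerKind::COMPLETION, // assumed variant name
        name: "worker-0".into(),
        device: "cuda".into(),
        arch: std::env::consts::ARCH.into(),
        cpu_info: "unknown".into(),
        cpu_count: 16,
        cuda_devices: vec![],
    };
    let client = reqwest::Client::new();
    let response = post_graphql::<RegisterWorker, _>(&client, url, variables).await?;
    let data = response.data.ok_or_else(|| anyhow::anyhow!("registration failed"))?;
    // `worker:` aliases registerWorker in the query, so the field is `worker`.
    Ok(data.worker.addr)
}
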
55 changes: 53 additions & 2 deletions crates/tabby/src/main.rs
@@ -1,9 +1,11 @@
 mod api;
-mod download;
 mod routes;
-mod serve;
 mod services;
 
+mod download;
+mod serve;
+mod worker;
+
 use clap::{Parser, Subcommand};
 use opentelemetry::{
     global,
@@ -36,6 +38,16 @@ pub enum Commands {
 
     /// Run scheduler progress for cron jobs integrating external code repositories.
     Scheduler(SchedulerArgs),
+
+    /// Run completion model as worker
+    #[clap(name = "worker::completion")]
+    #[command(arg_required_else_help = true)]
+    WorkerCompletion(worker::WorkerArgs),
+
+    /// Run chat model as worker
+    #[clap(name = "worker::chat")]
+    #[command(arg_required_else_help = true)]
+    WorkerChat(worker::WorkerArgs),
 }
 
 #[derive(clap::Args)]
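
Note: `#[clap(name = "worker::completion")]` keeps the double colon in the user-facing name, so these subcommands are invoked literally as `tabby worker::completion` and `tabby worker::chat`; `arg_required_else_help` makes a bare invocation print usage instead of starting a worker with defaults.
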
@@ -45,6 +57,41 @@ pub struct SchedulerArgs {
     now: bool,
 }
 
+#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
+pub enum Device {
+    #[strum(serialize = "cpu")]
+    Cpu,
+
+    #[cfg(feature = "cuda")]
+    #[strum(serialize = "cuda")]
+    Cuda,
+
+    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    #[strum(serialize = "metal")]
+    Metal,
+
+    #[cfg(feature = "experimental-http")]
+    #[strum(serialize = "experimental_http")]
+    ExperimentalHttp,
+}
+
+impl Device {
+    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    pub fn ggml_use_gpu(&self) -> bool {
+        *self == Device::Metal
+    }
+
+    #[cfg(feature = "cuda")]
+    pub fn ggml_use_gpu(&self) -> bool {
+        *self == Device::Cuda
+    }
+
+    #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
+    pub fn ggml_use_gpu(&self) -> bool {
+        false
+    }
+}
+
 #[tokio::main]
 async fn main() {
     let cli = Cli::parse();
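
Because Device derives ValueEnum, clap parses --device values straight into the enum, and the cfg attributes ensure exactly one ggml_use_gpu definition compiles per build. A small sketch, assuming the Device type above is in scope (on a CPU-only build only the fallback definition exists):

use clap::ValueEnum;

fn main() {
    // ValueEnum::from_str is the same parser clap uses for --device;
    // the second argument enables case-insensitive matching.
    let device = Device::from_str("cpu", true).expect("unknown device");
    // On a CPU-only build the fallback ggml_use_gpu compiles, so this always
    // holds; Metal/CUDA builds return true only for their matching variant.
    assert!(!device.ggml_use_gpu());
}
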
@@ -58,6 +105,10 @@ async fn main() {
         Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
             .await
             .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
+        Commands::WorkerCompletion(args) => {
+            worker::main(worker::WorkerKind::Completion, args).await
+        }
+        Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
     }
 
     opentelemetry::global::shutdown_tracer_provider();
93 changes: 25 additions & 68 deletions crates/tabby/src/serve/mod.rs → crates/tabby/src/serve.rs
@@ -1,5 +1,4 @@
 use std::{
-    fs,
     net::{Ipv4Addr, SocketAddr},
     sync::Arc,
     time::Duration,
@@ -9,7 +8,6 @@ use axum::{routing, Router, Server};
 use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
 use clap::Args;
 use tabby_common::{config::Config, usage};
-use tabby_download::download_model;
 use tabby_webserver::attach_webserver;
 use tokio::time::sleep;
 use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
@@ -20,7 +18,14 @@ use utoipa_swagger_ui::SwaggerUi;
 use crate::{
     api::{self},
     fatal, routes,
-    services::{chat, completion, event::create_event_logger, health, model},
+    services::{
+        chat::{self, create_chat_service},
+        completion::{self, create_completion_service},
+        event::create_logger,
+        health,
+        model::download_model_if_needed,
+    },
+    Device,
 };
 
 #[derive(OpenApi)]
@@ -62,41 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
 )]
 struct ApiDoc;
 
-#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
-pub enum Device {
-    #[strum(serialize = "cpu")]
-    Cpu,
-
-    #[cfg(feature = "cuda")]
-    #[strum(serialize = "cuda")]
-    Cuda,
-
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-    #[strum(serialize = "metal")]
-    Metal,
-
-    #[cfg(feature = "experimental-http")]
-    #[strum(serialize = "experimental_http")]
-    ExperimentalHttp,
-}
-
-impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-    pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Metal
-    }
-
-    #[cfg(feature = "cuda")]
-    pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Cuda
-    }
-
-    #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
-    pub fn ggml_use_gpu(&self) -> bool {
-        false
-    }
-}
-
 #[derive(Args)]
 pub struct ServeArgs {
     /// Model id for `/completions` API endpoint.
@@ -152,43 +122,30 @@ pub async fn main(config: &Config, args: &ServeArgs) {
 }
 
 async fn load_model(args: &ServeArgs) {
-    if fs::metadata(&args.model).is_ok() {
-        info!("Loading model from local path {}", &args.model);
-    } else {
-        download_model(&args.model, true).await;
-        if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model, true).await;
-        }
-    }
+    download_model_if_needed(&args.model).await;
+    if let Some(chat_model) = &args.chat_model {
+        download_model_if_needed(chat_model).await
+    }
 }
 
 async fn api_router(args: &ServeArgs, config: &Config) -> Router {
-    let logger = Arc::new(create_event_logger());
+    let logger = Arc::new(create_logger());
     let code = Arc::new(crate::services::code::create_code_search());
-    let completion_state = {
-        let (
-            engine,
-            model::PromptInfo {
-                prompt_template, ..
-            },
-        ) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
-        let state = completion::CompletionService::new(
-            engine.clone(),
+    let completion = Arc::new(
+        create_completion_service(
             code.clone(),
             logger.clone(),
-            prompt_template,
-        );
-        Arc::new(state)
-    };
+            &args.model,
+            &args.device,
+            args.parallelism,
+        )
+        .await,
+    );
 
-    let chat_state = if let Some(chat_model) = &args.chat_model {
-        let (engine, model::PromptInfo { chat_template, .. }) =
-            model::load_text_generation(chat_model, &args.device, args.parallelism).await;
-        let Some(chat_template) = chat_template else {
-            panic!("Chat model requires specifying prompt template");
-        };
-        let state = chat::ChatService::new(engine, chat_template);
-        Some(Arc::new(state))
+    let chat_state = if let Some(_chat_model) = &args.chat_model {
+        Some(Arc::new(
+            create_chat_service(&args.model, &args.device, args.parallelism).await,
+        ))
     } else {
         None
     };
@@ -220,7 +177,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
     Router::new()
         .route(
             "/v1/completions",
-            routing::post(routes::completions).with_state(completion_state),
+            routing::post(routes::completions).with_state(completion),
         )
         .layer(TimeoutLayer::new(Duration::from_secs(
             config.server.completion_timeout,
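
download_model_if_needed itself is not shown in this diff (it lives in services::model), but the deleted load_model body above suggests it simply folds in the old local-path check. A sketch under that assumption:

// Reconstructed from the removed lines above; the actual implementation
// in services::model may differ.
async fn download_model_if_needed(model: &str) {
    if std::fs::metadata(model).is_ok() {
        tracing::info!("Loading model from local path {}", model);
    } else {
        tabby_download::download_model(model, true).await;
    }
}
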
16 changes: 15 additions & 1 deletion crates/tabby/src/services/chat.rs
@@ -11,6 +11,9 @@ use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptio
 use tracing::debug;
 use utoipa::ToSchema;
 
+use super::model;
+use crate::{fatal, Device};
+
 #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
 #[schema(example=json!({
     "messages": [
@@ -40,7 +43,7 @@ pub struct ChatService {
 }
 
 impl ChatService {
-    pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
+    fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
         Self {
             engine,
             prompt_builder: ChatPromptBuilder::new(chat_template),
@@ -73,3 +76,14 @@ impl ChatService {
         Box::pin(s)
     }
 }
+
+pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
+    let (engine, model::PromptInfo { chat_template, .. }) =
+        model::load_text_generation(model, device, parallelism).await;
+
+    let Some(chat_template) = chat_template else {
+        fatal!("Chat model requires specifying prompt template");
+    };
+
+    ChatService::new(engine, chat_template)
+}
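
For orientation, a hypothetical consumer of the service built here; the generate method name and the ChatCompletionRequest type are assumptions, since only fragments of ChatService appear in this diff:

use futures::StreamExt;

// Streams chat output chunk by chunk; the `Box::pin(s)` above suggests the
// service returns a boxed stream of generated text.
async fn demo(chat: &ChatService, request: &ChatCompletionRequest) {
    let mut stream = chat.generate(request).await;
    while let Some(chunk) = stream.next().await {
        print!("{chunk}");
    }
}
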
(Diffs for the remaining changed files did not load and are omitted.)
