From ce338c7436bc7a045bb1d8c621fa23fbcebbea8a Mon Sep 17 00:00:00 2001 From: Eric Date: Fri, 17 Nov 2023 15:05:39 +0800 Subject: [PATCH] feat: validate token during worker registration (#803) * feat: validate token during worker registration * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * resolve comments * reslove comments * format file, update schema file * resolve comment --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- Cargo.lock | 125 ++++++++++++++++----- crates/tabby-common/src/path.rs | 2 +- crates/tabby/src/worker.rs | 7 ++ ee/tabby-webserver/Cargo.toml | 14 +++ ee/tabby-webserver/graphql/schema.graphql | 5 + ee/tabby-webserver/src/api.rs | 3 +- ee/tabby-webserver/src/db.rs | 127 ++++++++++++++++++++++ ee/tabby-webserver/src/lib.rs | 22 +++- ee/tabby-webserver/src/schema.rs | 18 ++- ee/tabby-webserver/src/server.rs | 29 ++++- 10 files changed, 318 insertions(+), 34 deletions(-) create mode 100644 ee/tabby-webserver/src/db.rs diff --git a/Cargo.lock b/Cargo.lock index dbb183b147a1..cf095c7e137b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1415,6 +1415,18 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastdivide" version = "0.4.0" @@ -1771,6 +1783,15 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashlink" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +dependencies = [ + "hashbrown 0.14.0", +] + [[package]] name = "headers" version = "0.3.8" @@ -1966,9 +1987,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.56" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" +checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2141,9 +2162,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.63" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f37a4a5928311ac501dee68b3c7613a1037d0edb30c8e5427bd832d55d1b790" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" dependencies = [ "wasm-bindgen", ] @@ -2262,6 +2283,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "libsqlite3-sys" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "libssh2-sys" version = "0.3.0" @@ -2311,9 +2343,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db" +checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" [[package]] name = "llama-cpp-bindings" @@ -3503,6 +3535,32 @@ dependencies = [ "serde", ] +[[package]] +name = "rusqlite" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" +dependencies = [ + "bitflags 2.4.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + +[[package]] +name = "rusqlite_migration" +version = "1.1.0-alpha.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef119690ca6bac53498f4478badf364840780248132ab5097891d0cfdf42eda" +dependencies = [ + "log", + "rusqlite", + "tokio", + "tokio-rusqlite", +] + [[package]] name = "rust-embed" version = "6.6.1" @@ -3661,7 +3719,7 @@ dependencies = [ "bitflags 2.4.0", "errno", "libc", - "linux-raw-sys 0.4.8", + "linux-raw-sys 0.4.10", "windows-sys 0.48.0", ] @@ -3773,9 +3831,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ "bitflags 1.3.2", "core-foundation", @@ -3786,9 +3844,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", @@ -4406,6 +4464,7 @@ dependencies = [ "anyhow", "axum", "bincode", + "chrono", "futures", "hyper", "juniper", @@ -4413,14 +4472,19 @@ dependencies = [ "lazy_static", "mime_guess", "pin-project", + "rusqlite", + "rusqlite_migration", "rust-embed 8.0.0", "serde", + "tabby-common", "tarpc", "thiserror", "tokio", + "tokio-rusqlite", "tokio-tungstenite", "tracing", "unicase", + "uuid 1.4.1", ] [[package]] @@ -4824,6 +4888,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rusqlite" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aa66395f5ff117faee90c9458232c936405f9227ad902038000b74b3bc1feac" +dependencies = [ + "crossbeam-channel", + "rusqlite", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -5670,9 +5745,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bba0e8cb82ba49ff4e229459ff22a191bbe9a1cb3a341610c9c33efc27ddf73" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -5680,9 +5755,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b04bc93f9d6bdee709f6bd2118f57dd6679cf1176a1af464fca3ab0d66d8fb" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" dependencies = [ "bumpalo", "log", @@ -5695,9 +5770,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.36" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d1985d03709c53167ce907ff394f5316aa22cb4e12761295c5dc57dacb6297e" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" dependencies = [ "cfg-if", "js-sys", @@ -5707,9 +5782,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14d6b024f1a526bb0234f52840389927257beb670610081360e5a03c5df9c258" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5717,9 +5792,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", @@ -5730,9 +5805,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "wasm-streams" @@ -5749,9 +5824,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.63" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bdd9ef4e984da1187bf8110c5cf5b845fbc87a23602cdf912386a76fcd3a7c2" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs index 125eaa34e41f..55f757e6fc5e 100644 --- a/crates/tabby-common/src/path.rs +++ b/crates/tabby-common/src/path.rs @@ -18,7 +18,7 @@ pub fn set_tabby_root(path: PathBuf) { cell.replace(path); } -fn tabby_root() -> PathBuf { +pub fn tabby_root() -> PathBuf { let mut cell = TABBY_ROOT.lock().unwrap(); cell.get_mut().clone() } diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index 71fe68d23598..0fb390e832b3 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -33,6 +33,10 @@ pub struct WorkerArgs { #[clap(long, default_value_t = 8080)] port: u16, + /// Server token to register this worker to. + #[clap(long)] + token: String, + /// Model id #[clap(long, help_heading=Some("Model Options"))] model: String, @@ -99,6 +103,7 @@ async fn request_register(kind: WorkerKind, args: &WorkerArgs) { args.port, args.model.to_owned(), args.device.to_string(), + args.token.clone(), ) .await { @@ -112,6 +117,7 @@ async fn request_register_impl( port: u16, name: String, device: String, + token: String, ) -> Result<()> { let client = tabby_webserver::api::create_client(url).await; let (cpu_info, cpu_count) = read_cpu_info(); @@ -127,6 +133,7 @@ async fn request_register_impl( cpu_info, cpu_count as i32, cuda_devices, + token, ) .await??; diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index ed0811faa55d..4f6f37e70fc5 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -9,6 +9,7 @@ homepage.workspace = true anyhow.workspace = true axum = { workspace = true, features = ["ws"] } bincode = "1.3.3" +chrono = "0.4" futures.workspace = true hyper = { workspace = true, features=["client"]} juniper.workspace = true @@ -16,14 +17,27 @@ juniper-axum = { path = "../../crates/juniper-axum" } lazy_static = "1.4.0" mime_guess = "2.0.4" pin-project = "1.1.3" +rusqlite = { version = "0.29.0", features = ["bundled"] } +# `async-tokio-rusqlite` is only available from 1.1.0-alpha.2, will bump up version when it's stable +rusqlite_migration = { version = "1.1.0-alpha.2", features = ["async-tokio-rusqlite"] } rust-embed = "8.0.0" serde.workspace = true +tabby-common = { path = "../../crates/tabby-common" } tarpc = { version = "0.33.0", features = ["serde-transport"] } thiserror.workspace = true tokio.workspace = true +tokio-rusqlite = "0.4.0" tokio-tungstenite = "0.20.1" tracing.workspace = true unicase = "2.7.0" +[dependencies.uuid] +version = "1.3.3" +features = [ + "v4", # Lets you generate random UUIDs + "fast-rng", # Use a faster (but still sufficiently random) RNG + "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs +] + [dev-dependencies] tokio = { workspace = true, features = ["macros"] } diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 4a14aacd164d..31c58c84cbac 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -3,6 +3,10 @@ enum WorkerKind { CHAT } +type Mutation { + resetRegistrationToken: String! +} + type Query { workers: [Worker!]! } @@ -20,4 +24,5 @@ type Worker { schema { query: Query + mutation: Mutation } diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index 39aabfed392e..4929cb84cb53 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -25,7 +25,7 @@ pub struct Worker { #[derive(Serialize, Deserialize, Error, Debug)] pub enum HubError { - #[error("Invalid worker token")] + #[error("Invalid token")] InvalidToken(String), #[error("Feature requires enterprise license")] @@ -43,6 +43,7 @@ pub trait Hub { cpu_info: String, cpu_count: i32, cuda_devices: Vec, + token: String, ) -> Result; } diff --git a/ee/tabby-webserver/src/db.rs b/ee/tabby-webserver/src/db.rs new file mode 100644 index 000000000000..46c61ae2072e --- /dev/null +++ b/ee/tabby-webserver/src/db.rs @@ -0,0 +1,127 @@ +use std::{path::PathBuf, sync::Arc}; + +use anyhow::Result; +use lazy_static::lazy_static; +use rusqlite::params; +use rusqlite_migration::{AsyncMigrations, M}; +use tabby_common::path::tabby_root; +use tokio_rusqlite::Connection; + +lazy_static! { + static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![M::up( + r#" + CREATE TABLE IF NOT EXISTS registration_token ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + updated_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_token` UNIQUE (`token`) + ); + "# + ),]); +} + +fn db_file() -> PathBuf { + tabby_root().join("db.sqlite3") +} + +pub struct DbConn { + conn: Arc, +} + +impl DbConn { + pub async fn new() -> Result { + let conn = Connection::open(db_file()).await?; + Self::init_db(conn).await + } + + /// Initialize database, create tables and insert first token if not exist + async fn init_db(mut conn: Connection) -> Result { + MIGRATIONS.to_latest(&mut conn).await?; + + let token = uuid::Uuid::new_v4().to_string(); + conn.call(move |c| { + c.execute( + r#"INSERT OR IGNORE INTO registration_token (id, token) VALUES (1, ?)"#, + params![token], + ) + }) + .await?; + + Ok(Self { + conn: Arc::new(conn), + }) + } + + /// Query token from database. + /// Since token is global unique for each tabby server, by right there's only one row in the table. + pub async fn read_registration_token(&self) -> Result { + let token = self + .conn + .call(|conn| { + conn.query_row( + r#"SELECT token FROM registration_token WHERE id = 1"#, + [], + |row| row.get(0), + ) + }) + .await?; + + Ok(token) + } + + /// Update token in database. + pub async fn reset_registration_token(&self) -> Result { + let token = uuid::Uuid::new_v4().to_string(); + let result = token.clone(); + let updated_at = chrono::Utc::now().timestamp() as u32; + + let res = self + .conn + .call(move |conn| { + conn.execute( + r#"UPDATE registration_token SET token = ?, updated_at = ? WHERE id = 1"#, + params![token, updated_at], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to update token")); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn new_in_memory() -> Result { + let conn = Connection::open_in_memory().await?; + DbConn::init_db(conn).await + } + + #[tokio::test] + async fn migrations_test() { + assert!(MIGRATIONS.validate().await.is_ok()); + } + + #[tokio::test] + async fn test_token() { + let conn = new_in_memory().await.unwrap(); + let token = conn.read_registration_token().await.unwrap(); + assert_eq!(token.len(), 36); + } + + #[tokio::test] + async fn test_update_token() { + let conn = new_in_memory().await.unwrap(); + + let old_token = conn.read_registration_token().await.unwrap(); + conn.reset_registration_token().await.unwrap(); + let new_token = conn.read_registration_token().await.unwrap(); + assert_eq!(new_token.len(), 36); + assert_ne!(old_token, new_token); + } +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index ad2404258b69..8c85cd4c44a7 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -2,8 +2,10 @@ pub mod api; mod schema; pub use schema::create_schema; +use tracing::error; use websocket::WebSocketTransport; +mod db; mod server; mod ui; mod websocket; @@ -25,7 +27,8 @@ use server::ServerContext; use tarpc::server::{BaseChannel, Channel}; pub async fn attach_webserver(router: Router) -> Router { - let ctx = Arc::new(ServerContext::default()); + let conn = db::DbConn::new().await.unwrap(); + let ctx = Arc::new(ServerContext::new(conn)); let schema = Arc::new(create_schema()); let app = Router::new() @@ -91,7 +94,24 @@ impl Hub for Arc { cpu_info: String, cpu_count: i32, cuda_devices: Vec, + token: String, ) -> Result { + if token.is_empty() { + return Err(HubError::InvalidToken("Empty worker token".to_string())); + } + let server_token = match self.ctx.read_registration_token().await { + Ok(t) => t, + Err(err) => { + error!("fetch server token: {}", err.to_string()); + return Err(HubError::InvalidToken( + "Failed to fetch server token".to_string(), + )); + } + }; + if server_token != token { + return Err(HubError::InvalidToken("Token mismatch".to_string())); + } + let worker = Worker { name, kind, diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs index f4364dd10ff7..21019c54d870 100644 --- a/ee/tabby-webserver/src/schema.rs +++ b/ee/tabby-webserver/src/schema.rs @@ -1,4 +1,4 @@ -use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode}; +use juniper::{graphql_object, EmptySubscription, FieldResult, RootNode}; use crate::{api::Worker, server::ServerContext}; @@ -15,9 +15,19 @@ impl Query { } } -pub type Schema = - RootNode<'static, Query, EmptyMutation, EmptySubscription>; +#[derive(Default)] +pub struct Mutation; + +#[graphql_object(context = ServerContext)] +impl Mutation { + async fn reset_registration_token(ctx: &ServerContext) -> FieldResult { + let token = ctx.reset_registration_token().await?; + Ok(token) + } +} + +pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; pub fn create_schema() -> Schema { - Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()) + Schema::new(Query, Mutation, EmptySubscription::new()) } diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/server.rs index 3ab1c4b9fadf..eca0c38d9798 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/server.rs @@ -3,19 +3,44 @@ mod worker; use std::net::SocketAddr; +use anyhow::Result; use axum::{http::Request, middleware::Next, response::IntoResponse}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; use tracing::{info, warn}; -use crate::api::{HubError, Worker, WorkerKind}; -#[derive(Default)] +use crate::{ + api::{HubError, Worker, WorkerKind}, + db::DbConn, +}; + pub struct ServerContext { client: Client, completion: worker::WorkerGroup, chat: worker::WorkerGroup, + db_conn: DbConn, } impl ServerContext { + pub fn new(db_conn: DbConn) -> Self { + Self { + client: Client::default(), + completion: worker::WorkerGroup::default(), + chat: worker::WorkerGroup::default(), + db_conn, + } + } + + /// Query current token from the database. + pub async fn read_registration_token(&self) -> Result { + self.db_conn.read_registration_token().await + } + + /// Generate new token, and update it in the database. + /// Return new token after update is done + pub async fn reset_registration_token(&self) -> Result { + self.db_conn.reset_registration_token().await + } + pub async fn register_worker(&self, worker: Worker) -> Result { let worker = match worker.kind { WorkerKind::Completion => self.completion.register(worker).await,