Skip to content

Commit

Permalink
feat: validate token during worker registration (#803)
Browse files Browse the repository at this point in the history
* feat: validate token during worker registration

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* resolve comments

* reslove comments

* format file, update schema file

* resolve comment

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
darknight and autofix-ci[bot] authored Nov 17, 2023
1 parent 97f4989 commit ce338c7
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 34 deletions.
125 changes: 100 additions & 25 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/tabby-common/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn set_tabby_root(path: PathBuf) {
cell.replace(path);
}

fn tabby_root() -> PathBuf {
pub fn tabby_root() -> PathBuf {
let mut cell = TABBY_ROOT.lock().unwrap();
cell.get_mut().clone()
}
Expand Down
7 changes: 7 additions & 0 deletions crates/tabby/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ pub struct WorkerArgs {
#[clap(long, default_value_t = 8080)]
port: u16,

/// Server token to register this worker to.
#[clap(long)]
token: String,

/// Model id
#[clap(long, help_heading=Some("Model Options"))]
model: String,
Expand Down Expand Up @@ -99,6 +103,7 @@ async fn request_register(kind: WorkerKind, args: &WorkerArgs) {
args.port,
args.model.to_owned(),
args.device.to_string(),
args.token.clone(),
)
.await
{
Expand All @@ -112,6 +117,7 @@ async fn request_register_impl(
port: u16,
name: String,
device: String,
token: String,
) -> Result<()> {
let client = tabby_webserver::api::create_client(url).await;
let (cpu_info, cpu_count) = read_cpu_info();
Expand All @@ -127,6 +133,7 @@ async fn request_register_impl(
cpu_info,
cpu_count as i32,
cuda_devices,
token,
)
.await??;

Expand Down
14 changes: 14 additions & 0 deletions ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,35 @@ homepage.workspace = true
anyhow.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
chrono = "0.4"
futures.workspace = true
hyper = { workspace = true, features=["client"]}
juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
mime_guess = "2.0.4"
pin-project = "1.1.3"
rusqlite = { version = "0.29.0", features = ["bundled"] }
# `async-tokio-rusqlite` is only available from 1.1.0-alpha.2, will bump up version when it's stable
rusqlite_migration = { version = "1.1.0-alpha.2", features = ["async-tokio-rusqlite"] }
rust-embed = "8.0.0"
serde.workspace = true
tabby-common = { path = "../../crates/tabby-common" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio.workspace = true
tokio-rusqlite = "0.4.0"
tokio-tungstenite = "0.20.1"
tracing.workspace = true
unicase = "2.7.0"

[dependencies.uuid]
version = "1.3.3"
features = [
"v4", # Lets you generate random UUIDs
"fast-rng", # Use a faster (but still sufficiently random) RNG
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
]

[dev-dependencies]
tokio = { workspace = true, features = ["macros"] }
5 changes: 5 additions & 0 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ enum WorkerKind {
CHAT
}

type Mutation {
resetRegistrationToken: String!
}

type Query {
workers: [Worker!]!
}
Expand All @@ -20,4 +24,5 @@ type Worker {

schema {
query: Query
mutation: Mutation
}
3 changes: 2 additions & 1 deletion ee/tabby-webserver/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub struct Worker {

#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
#[error("Invalid worker token")]
#[error("Invalid token")]
InvalidToken(String),

#[error("Feature requires enterprise license")]
Expand All @@ -43,6 +43,7 @@ pub trait Hub {
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError>;
}

Expand Down
Loading

0 comments on commit ce338c7

Please sign in to comment.