diff --git a/Cargo.lock b/Cargo.lock index f11d31521ad1..c9154a3b17bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -987,6 +987,17 @@ dependencies = [ "nom 4.1.1", ] +[[package]] +name = "cron" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ff76b51e4c068c52bfd2866e1567bee7c567ae8f24ada09fd4307019e25eab7" +dependencies = [ + "chrono", + "nom 7.1.3", + "once_cell", +] + [[package]] name = "crossbeam" version = "0.8.2" @@ -2258,7 +2269,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51f368c9c76dde2282714ae32dc274b79c27527a0c06c816f6dda048904d0d7c" dependencies = [ "chrono", - "cron", + "cron 0.6.1", "uuid 0.8.2", ] @@ -2907,6 +2918,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -4824,6 +4846,7 @@ dependencies = [ "tarpc", "thiserror", "tokio", + "tokio-cron-scheduler", "tokio-rusqlite", "tokio-tungstenite", "tower", @@ -5193,6 +5216,21 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-cron-scheduler" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de2c1fd54a857b29c6cd1846f31903d0ae8e28175615c14a277aed45c58d8e27" +dependencies = [ + "chrono", + "cron 0.12.0", + "num-derive", + "num-traits", + "tokio", + "tracing", + "uuid 1.4.1", +] + [[package]] name = "tokio-io-timeout" version = "1.2.0" diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 288e915e4146..b3da0bd6689d 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -29,6 +29,7 @@ tabby-common = { path = "../../crates/tabby-common" } tarpc = { version = "0.33.0", features = ["serde-transport"] } thiserror.workspace = true tokio = { workspace = true, features = ["fs"] } +tokio-cron-scheduler = "0.9.4" tokio-rusqlite = "0.4.0" tokio-tungstenite = "0.20.1" tower = { version = "0.4", features = ["util"] } diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index cdd700cf7b9f..366f52708bde 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -13,6 +13,7 @@ type Mutation { register(email: String!, password1: String!, password2: String!, invitationCode: String): RegisterResponse! tokenAuth(email: String!, password: String!): TokenAuthResponse! verifyToken(token: String!): VerifyTokenResponse! + refreshToken(refreshToken: String!): RefreshTokenResponse! createInvitation(email: String!): Int! deleteInvitation(id: Int!): Int! } @@ -63,6 +64,12 @@ type TokenAuthResponse { refreshToken: String! } +type RefreshTokenResponse { + accessToken: String! + refreshToken: String! + refreshExpiresAt: Float! +} + schema { query: Query mutation: Mutation diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 82a1f23c8776..7331ca3f67ae 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -7,6 +7,7 @@ use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use thiserror::Error; +use uuid::Uuid; use validator::ValidationErrors; use super::from_validation_errors; @@ -19,6 +20,7 @@ lazy_static! { jwt_token_secret().as_bytes() ); static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes + static ref JWT_REFRESH_PERIOD: i64 = 7 * 24 * 60 * 60; // 7 days } pub fn generate_jwt(claims: Claims) -> jwt::errors::Result { @@ -37,10 +39,15 @@ fn jwt_token_secret() -> String { std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string()) } +pub fn generate_refresh_token(utc_ts: i64) -> (String, i64) { + let token = Uuid::new_v4().to_string().replace('-', ""); + (token, utc_ts + *JWT_REFRESH_PERIOD) +} + #[derive(Debug, GraphQLObject)] pub struct RegisterResponse { access_token: String, - refresh_token: String, + pub refresh_token: String, } impl RegisterResponse { @@ -82,7 +89,7 @@ impl IntoFieldError for RegisterError { #[derive(Debug, GraphQLObject)] pub struct TokenAuthResponse { access_token: String, - refresh_token: String, + pub refresh_token: String, } impl TokenAuthResponse { @@ -127,11 +134,45 @@ impl IntoFieldError for TokenAuthError { } } -#[derive(Debug, Default, GraphQLObject)] +#[derive(Error, Debug)] +pub enum RefreshTokenError { + #[error("Invalid refresh token")] + InvalidRefreshToken, + + #[error("Expired refresh token")] + ExpiredRefreshToken, + + #[error("User not found")] + UserNotFound, + + #[error(transparent)] + Other(#[from] anyhow::Error), + + #[error("Unknown error")] + Unknown, +} + +impl IntoFieldError for RefreshTokenError { + fn into_field_error(self) -> FieldError { + self.into() + } +} + +#[derive(Debug, GraphQLObject)] pub struct RefreshTokenResponse { - access_token: String, - refresh_token: String, - refresh_expires_in: i32, + pub access_token: String, + pub refresh_token: String, + pub refresh_expires_at: f64, +} + +impl RefreshTokenResponse { + pub fn new(access_token: String, refresh_token: String, refresh_expires_at: f64) -> Self { + Self { + access_token, + refresh_token, + refresh_expires_at, + } + } } #[derive(Debug, GraphQLObject)] @@ -215,7 +256,10 @@ pub trait AuthenticationService: Send + Sync { password: String, ) -> std::result::Result; - async fn refresh_token(&self, refresh_token: String) -> Result; + async fn refresh_token( + &self, + refresh_token: String, + ) -> std::result::Result; async fn verify_token(&self, access_token: String) -> Result; async fn is_admin_initialized(&self) -> Result; @@ -245,4 +289,11 @@ mod tests { &UserInfo::new("test".to_string(), false) ); } + + #[test] + fn test_generate_refresh_token() { + let (token, exp) = generate_refresh_token(100); + assert_eq!(token.len(), 32); + assert_eq!(exp, 100 + *JWT_REFRESH_PERIOD); + } } diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 444713e3522e..04b486e39031 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -17,7 +17,10 @@ use self::{ worker::WorkerService, }; use crate::schema::{ - auth::{RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse}, + auth::{ + RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, + VerifyTokenResponse, + }, worker::Worker, }; @@ -142,6 +145,13 @@ impl Mutation { Ok(ctx.locator.auth().verify_token(token).await?) } + async fn refresh_token( + ctx: &Context, + refresh_token: String, + ) -> Result { + ctx.locator.auth().refresh_token(refresh_token).await + } + async fn create_invitation(ctx: &Context, email: String) -> Result { if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index bd14f3d9f809..833fe50e7901 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -9,9 +9,9 @@ use validator::Validate; use super::db::DbConn; use crate::schema::auth::{ - generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse, - RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo, - VerifyTokenResponse, + generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Claims, Invitation, + RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError, + TokenAuthResponse, UserInfo, VerifyTokenResponse, }; /// Input parameters for register mutation @@ -146,6 +146,10 @@ impl AuthenticationService for DbConn { .await?; let user = self.get_user(id).await?.unwrap(); + let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp()); + self.create_refresh_token(id, &refresh_token, expires_at) + .await?; + let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( user.email.clone(), user.is_admin, @@ -153,7 +157,7 @@ impl AuthenticationService for DbConn { return Err(RegisterError::Unknown); }; - let resp = RegisterResponse::new(access_token, "".to_string()); + let resp = RegisterResponse::new(access_token, refresh_token); Ok(resp) } @@ -173,6 +177,10 @@ impl AuthenticationService for DbConn { return Err(TokenAuthError::InvalidPassword); } + let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp()); + self.create_refresh_token(user.id, &refresh_token, expires_at) + .await?; + let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( user.email.clone(), user.is_admin, @@ -180,12 +188,39 @@ impl AuthenticationService for DbConn { return Err(TokenAuthError::Unknown); }; - let resp = TokenAuthResponse::new(access_token, "".to_string()); + let resp = TokenAuthResponse::new(access_token, refresh_token); Ok(resp) } - async fn refresh_token(&self, _refresh_token: String) -> Result { - Ok(RefreshTokenResponse::default()) + async fn refresh_token( + &self, + token: String, + ) -> std::result::Result { + let Some(refresh_token) = self.get_refresh_token(&token).await? else { + return Err(RefreshTokenError::InvalidRefreshToken); + }; + if refresh_token.is_expired() { + return Err(RefreshTokenError::ExpiredRefreshToken); + } + let Some(user) = self.get_user(refresh_token.user_id).await? else { + return Err(RefreshTokenError::UserNotFound); + }; + + let (new_token, _) = generate_refresh_token(chrono::Utc::now().timestamp()); + self.replace_refresh_token(&token, &new_token).await?; + + // refresh token update is done, generate new access token based on user info + let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( + user.email.clone(), + user.is_admin, + ))) else { + return Err(RefreshTokenError::Unknown); + }; + + let resp = + RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at as f64); + + Ok(resp) } async fn verify_token(&self, access_token: String) -> Result { @@ -256,7 +291,7 @@ mod tests { static ADMIN_EMAIL: &str = "test@example.com"; static ADMIN_PASSWORD: &str = "123456789"; - async fn create_admin_user(conn: &DbConn) -> i32 { + async fn register_admin_user(conn: &DbConn) -> RegisterResponse { conn.register( ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned(), @@ -264,8 +299,7 @@ mod tests { None, ) .await - .unwrap(); - 1 + .unwrap() } #[tokio::test] @@ -277,7 +311,7 @@ mod tests { Err(TokenAuthError::UserNotFound) ); - create_admin_user(&conn).await; + register_admin_user(&conn).await; assert_matches!( conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned()) @@ -285,10 +319,16 @@ mod tests { Err(TokenAuthError::InvalidPassword) ); - assert!(conn + let resp1 = conn .token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned()) .await - .is_ok()); + .unwrap(); + let resp2 = conn + .token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned()) + .await + .unwrap(); + // each auth should generate a new refresh token + assert_ne!(resp1.refresh_token, resp2.refresh_token); } #[tokio::test] @@ -296,7 +336,7 @@ mod tests { let conn = DbConn::new_in_memory().await.unwrap(); assert!(!conn.is_admin_initialized().await.unwrap()); - create_admin_user(&conn).await; + register_admin_user(&conn).await; let email = "user@user.com"; let password = "12345678"; @@ -351,4 +391,23 @@ mod tests { Err(RegisterError::DuplicateEmail) ); } + + #[tokio::test] + async fn test_refresh_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + let reg = register_admin_user(&conn).await; + + let resp1 = conn.refresh_token(reg.refresh_token.clone()).await.unwrap(); + // new access token should be valid + assert!(validate_jwt(&resp1.access_token).is_ok()); + // refresh token should be renewed + assert_ne!(reg.refresh_token, resp1.refresh_token); + + let resp2 = conn + .refresh_token(resp1.refresh_token.clone()) + .await + .unwrap(); + // expire time should be no change + assert_eq!(resp1.refresh_expires_at, resp2.refresh_expires_at); + } } diff --git a/ee/tabby-webserver/src/service/cron.rs b/ee/tabby-webserver/src/service/cron.rs new file mode 100644 index 000000000000..0ba79f769451 --- /dev/null +++ b/ee/tabby-webserver/src/service/cron.rs @@ -0,0 +1,62 @@ +use std::time::Duration; + +use anyhow::Result; +use tokio_cron_scheduler::{Job, JobScheduler}; +use tracing::{error, warn}; + +use crate::service::db::DbConn; + +async fn new_job_scheduler(jobs: Vec) -> Result { + let scheduler = JobScheduler::new().await?; + for job in jobs { + scheduler.add(job).await?; + } + scheduler.start().await?; + Ok(scheduler) +} + +async fn new_refresh_token_job(db_conn: DbConn) -> Result { + // job is run every 2 hours + let job = Job::new_async("0 0 1/2 * * * *", move |_, _| { + let utc_ts = chrono::Utc::now().timestamp(); + let db_conn = db_conn.clone(); + Box::pin(async move { + let res = db_conn.delete_expired_token(utc_ts).await; + if let Err(e) = res { + error!("failed to delete expired token: {}", e); + } + }) + })?; + + Ok(job) +} + +pub fn run_offline_job(db_conn: DbConn) { + tokio::spawn(async move { + let Ok(job) = new_refresh_token_job(db_conn.clone()).await else { + error!("failed to create db job"); + return; + }; + + let Ok(mut scheduler) = new_job_scheduler(vec![job]).await else { + error!("failed to start job scheduler"); + return; + }; + + loop { + match scheduler.time_till_next_job().await { + Ok(Some(duration)) => { + tokio::time::sleep(duration).await; + } + Ok(None) => { + warn!("no job available, exit scheduler"); + return; + } + Err(e) => { + error!("failed to get job sleep time: {}, re-try in 1 second", e); + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + } + }); +} diff --git a/ee/tabby-webserver/src/service/db.rs b/ee/tabby-webserver/src/service/db.rs index b456e8376960..7941576ef444 100644 --- a/ee/tabby-webserver/src/service/db.rs +++ b/ee/tabby-webserver/src/service/db.rs @@ -8,7 +8,7 @@ use tabby_common::path::tabby_root; use tokio_rusqlite::Connection; use uuid::Uuid; -use crate::schema::auth::Invitation; +use crate::{schema::auth::Invitation, service::cron::run_offline_job}; lazy_static! { static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![ @@ -51,6 +51,19 @@ lazy_static! { "# ) .down("DROP TABLE invitations"), + M::up( + r#" + CREATE TABLE refresh_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token VARCHAR(255) NOT NULL COLLATE NOCASE, + expires_at INTEGER NOT NULL, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_token` UNIQUE (`token`) + ); + "# + ) + .down("DROP TABLE refresh_tokens"), ]); } @@ -59,7 +72,7 @@ pub struct User { created_at: String, updated_at: String, - pub id: u32, + pub id: i32, pub email: String, pub password_encrypted: String, pub is_admin: bool, @@ -121,9 +134,12 @@ impl DbConn { }) .await?; - Ok(Self { + let res = Self { conn: Arc::new(conn), - }) + }; + run_offline_job(res.clone()); + + Ok(res) } } @@ -309,6 +325,114 @@ impl DbConn { } } +#[allow(unused)] +pub struct RefreshToken { + id: u32, + created_at: String, + + pub user_id: i32, + pub token: String, + pub expires_at: i64, +} + +impl RefreshToken { + fn select(clause: &str) -> String { + r#"SELECT id, user_id, token, expires_at, created_at FROM refresh_tokens WHERE "#.to_owned() + + clause + } + + fn from_row(row: &Row<'_>) -> std::result::Result { + Ok(RefreshToken { + id: row.get(0)?, + user_id: row.get(1)?, + token: row.get(2)?, + expires_at: row.get(3)?, + created_at: row.get(4)?, + }) + } + + pub fn is_expired(&self) -> bool { + let now = chrono::Utc::now().timestamp(); + self.expires_at < now + } +} + +/// db read/write operations for `refresh_tokens` table +impl DbConn { + pub async fn create_refresh_token( + &self, + user_id: i32, + token: &str, + expires_at: i64, + ) -> Result<()> { + let token = token.to_string(); + let res = self + .conn + .call(move |c| { + c.execute( + r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)"#, + params![user_id, token, expires_at], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to create refresh token")); + } + + Ok(()) + } + + pub async fn replace_refresh_token(&self, old: &str, new: &str) -> Result<()> { + let old = old.to_string(); + let new = new.to_string(); + let res = self + .conn + .call(move |c| { + c.execute( + r#"UPDATE refresh_tokens SET token = ? WHERE token = ?"#, + params![new, old], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to replace refresh token")); + } + + Ok(()) + } + + pub async fn delete_expired_token(&self, utc_ts: i64) -> Result { + let res = self + .conn + .call(move |c| { + c.execute( + r#"DELETE FROM refresh_tokens WHERE expires_at < ?"#, + params![utc_ts], + ) + }) + .await?; + + Ok(res as i32) + } + + pub async fn get_refresh_token(&self, token: &str) -> Result> { + let token = token.to_string(); + let token = self + .conn + .call(move |c| { + c.query_row( + RefreshToken::select("token = ?").as_str(), + params![token], + RefreshToken::from_row, + ) + .optional() + }) + .await?; + + Ok(token) + } +} + #[cfg(test)] mod tests { @@ -398,4 +522,33 @@ mod tests { let invitations = conn.list_invitations().await.unwrap(); assert!(invitations.is_empty()); } + + #[tokio::test] + async fn test_create_refresh_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + + conn.create_refresh_token(1, "test", 100).await.unwrap(); + + let token = conn.get_refresh_token("test").await.unwrap().unwrap(); + + assert_eq!(token.user_id, 1); + assert_eq!(token.token, "test"); + assert_eq!(token.expires_at, 100); + } + + #[tokio::test] + async fn test_replace_refresh_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + + conn.create_refresh_token(1, "test", 100).await.unwrap(); + conn.replace_refresh_token("test", "test2").await.unwrap(); + + let token = conn.get_refresh_token("test").await.unwrap(); + assert!(token.is_none()); + + let token = conn.get_refresh_token("test2").await.unwrap().unwrap(); + assert_eq!(token.user_id, 1); + assert_eq!(token.token, "test2"); + assert_eq!(token.expires_at, 100); + } } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index bd47418ed4c4..c25795d1ac0c 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -1,4 +1,5 @@ mod auth; +mod cron; mod db; mod proxy; mod worker;