Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ee): implement user authentication api #912

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 174 additions & 27 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ homepage.workspace = true

[dependencies]
anyhow.workspace = true
argon2 = "0.5.1"
async-trait.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
chrono = "0.4"
futures.workspace = true
hyper = { workspace = true, features=["client"]}
jsonwebtoken = "9.1.0"
juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
Expand All @@ -33,6 +35,7 @@ tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing.workspace = true
unicase = "2.7.0"
validator = { version = "0.16.1", features = ["derive"] }

[dependencies.uuid]
version = "1.3.3"
Expand Down
38 changes: 37 additions & 1 deletion ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
type RegisterResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
darknight marked this conversation as resolved.
Show resolved Hide resolved
}

type AuthError {
message: String!
code: String!
}

enum WorkerKind {
COMPLETION
CHAT
}

type Mutation {
resetRegistrationToken: String!
resetRegistrationToken(token: String): String!
register(email: String!, password1: String!, password2: String!): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
}

type UserInfo {
email: String!
isAdmin: Boolean!
}

type VerifyTokenResponse {
errors: [AuthError!]!
darknight marked this conversation as resolved.
Show resolved Hide resolved
claims: Claims!
}

type Claims {
exp: Float!
iat: Float!
user: UserInfo!
}

type Query {
Expand All @@ -23,6 +53,12 @@ type Worker {
cudaDevices: [String!]!
}

type TokenAuthResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
darknight marked this conversation as resolved.
Show resolved Hide resolved
}

schema {
query: Query
mutation: Mutation
Expand Down
113 changes: 109 additions & 4 deletions ee/tabby-webserver/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::{path::PathBuf, sync::Arc};

use anyhow::Result;
use lazy_static::lazy_static;
use rusqlite::params;
use rusqlite::{params, OptionalExtension};
use rusqlite_migration::{AsyncMigrations, M};
use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;

lazy_static! {
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![M::up(
r#"
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![
M::up(
r#"
CREATE TABLE IF NOT EXISTS registration_token (
id INTEGER PRIMARY KEY AUTOINCREMENT,
token VARCHAR(255) NOT NULL,
Expand All @@ -18,7 +19,32 @@ lazy_static! {
CONSTRAINT `idx_token` UNIQUE (`token`)
);
"#
),]);
),
M::up(
r#"
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(150) NOT NULL COLLATE NOCASE,
password_encrypted VARCHAR(128) NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
updated_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_email` UNIQUE (`email`)
);
"#
),
]);
}

#[allow(unused)]
pub struct User {
created_at: String,
updated_at: String,

pub id: u32,
pub email: String,
pub password_encrypted: String,
pub is_admin: bool,
}

async fn db_path() -> Result<PathBuf> {
Expand All @@ -27,6 +53,7 @@ async fn db_path() -> Result<PathBuf> {
Ok(db_dir.join("db.sqlite"))
}

#[derive(Clone)]
pub struct DbConn {
conn: Arc<Connection>,
}
Expand Down Expand Up @@ -55,7 +82,10 @@ impl DbConn {
conn: Arc::new(conn),
})
}
}

/// db read/write operations for `registration_token` table
impl DbConn {
/// Query token from database.
/// Since token is global unique for each tabby server, by right there's only one row in the table.
pub async fn read_registration_token(&self) -> Result<String> {
Expand Down Expand Up @@ -96,6 +126,56 @@ impl DbConn {
}
}

/// db read/write operations for `users` table
impl DbConn {
pub async fn create_user(
&self,
email: String,
password_encrypted: String,
is_admin: bool,
) -> Result<()> {
let res = self
.conn
.call(move |c| {
c.execute(
r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#,
params![email, password_encrypted, is_admin],
)
})
.await?;
if res != 1 {
return Err(anyhow::anyhow!("failed to create user"));
}

Ok(())
}

pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
let email = email.to_string();
let user = self
.conn
.call(move |c| {
c.query_row(
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE email = ?"#,
params![email],
|row| {
Ok(User {
id: row.get(0)?,
email: row.get(1)?,
password_encrypted: row.get(2)?,
is_admin: row.get(3)?,
created_at: row.get(4)?,
updated_at: row.get(5)?,
})
},
).optional()
})
.await?;

Ok(user)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -127,4 +207,29 @@ mod tests {
assert_eq!(new_token.len(), 36);
assert_ne!(old_token, new_token);
}

#[tokio::test]
async fn test_create_user() {
let conn = new_in_memory().await.unwrap();

let email = "[email protected]";
let passwd = "123456";
let is_admin = true;
conn.create_user(email.to_string(), passwd.to_string(), is_admin)
.await
.unwrap();

let user = conn.get_user_by_email(email).await.unwrap().unwrap();
assert_eq!(user.id, 1);
}

#[tokio::test]
async fn test_get_user_by_email() {
let conn = new_in_memory().await.unwrap();

let email = "[email protected]";
let user = conn.get_user_by_email(email).await.unwrap();

assert!(user.is_none());
}
}
59 changes: 54 additions & 5 deletions ee/tabby-webserver/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use juniper::{graphql_object, EmptySubscription, FieldResult, RootNode};
pub mod auth;

use crate::{api::Worker, server::ServerContext};
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};

use crate::{
api::Worker,
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
server::{
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
ServerContext,
},
};

// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for ServerContext {}
Expand All @@ -25,9 +36,47 @@ pub struct Mutation;

#[graphql_object(context = ServerContext)]
impl Mutation {
async fn reset_registration_token(ctx: &ServerContext) -> FieldResult<String> {
let token = ctx.reset_registration_token().await?;
Ok(token)
async fn reset_registration_token(
ctx: &ServerContext,
token: Option<String>,
) -> FieldResult<String> {
if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) {
if claims.user_info().is_admin() {
let reg_token = ctx.reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}

async fn register(
ctx: &ServerContext,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
ctx.auth().register(input).await
}

async fn token_auth(
ctx: &ServerContext,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.auth().token_auth(input).await
}

async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.auth().verify_token(token).await
}
}

Expand Down
Loading
Loading