From 8aa67cf456c0ee2ee98c94229c39b478b85ad109 Mon Sep 17 00:00:00 2001 From: John Keiser Date: Thu, 2 Jan 2025 21:22:54 -0800 Subject: [PATCH] Add whoami endpoint that can be hit by either user or automation token Also factors authorization so that all authorization types (and WS vs. non-WS) all run through the same extractors --- lib/sdf-server/src/extract.rs | 403 +++++++++++++----- .../src/middleware/workspace_permission.rs | 19 +- lib/sdf-server/src/routes.rs | 1 + lib/sdf-server/src/service.rs | 1 + .../service/session/restore_authentication.rs | 17 +- lib/sdf-server/src/service/whoami.rs | 38 ++ lib/sdf-server/src/service/ws/crdt.rs | 8 +- .../src/service/ws/workspace_updates.rs | 7 +- lib/si-jwt-public-key/src/lib.rs | 64 ++- 9 files changed, 410 insertions(+), 148 deletions(-) create mode 100644 lib/sdf-server/src/service/whoami.rs diff --git a/lib/sdf-server/src/extract.rs b/lib/sdf-server/src/extract.rs index fe33d616af..7b81a15d17 100644 --- a/lib/sdf-server/src/extract.rs +++ b/lib/sdf-server/src/extract.rs @@ -1,36 +1,42 @@ -use std::{collections::HashMap, fmt}; +use std::fmt; use axum::{ async_trait, extract::{FromRequestParts, Query}, http::request::Parts, + http::StatusCode, Json, }; use dal::{ context::{self, DalContextBuilder}, - User, + User, WorkspacePk, }; -use derive_more::Deref; -use hyper::StatusCode; -use si_jwt_public_key::{SiJwtClaimRole, SiJwtClaims}; +use derive_more::{Deref, Into}; +use serde::Deserialize; +use si_jwt_public_key::{validate_raw_token, SiJwt, SiJwtClaimRole}; use crate::app_state::AppState; +type ErrorResponse = (StatusCode, Json); + +/// An authorized user + workspace +#[derive(Clone, Debug, Deref, Into)] pub struct AccessBuilder(pub context::AccessBuilder); #[async_trait] impl FromRequestParts for AccessBuilder { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { - let Authorization(claim) = Authorization::from_request_parts(parts, state).await?; + // Ensure the endpoint is authorized + let auth = EndpointAuthorization::from_request_parts(parts, state).await?; Ok(Self(context::AccessBuilder::new( - dal::Tenancy::new(claim.workspace_id()), - dal::HistoryActor::from(claim.user_id()), + dal::Tenancy::new(auth.workspace_id), + dal::HistoryActor::from(auth.user.pk()), ))) } } @@ -43,7 +49,7 @@ pub struct AdminAccessBuilder; #[async_trait] impl FromRequestParts for AdminAccessBuilder { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( parts: &mut Parts, @@ -97,37 +103,12 @@ impl FromRequestParts for AdminAccessBuilder { } } -pub struct RawAccessToken(pub String); - -#[async_trait] -impl FromRequestParts for RawAccessToken { - type Rejection = (StatusCode, Json); - - async fn from_request_parts( - parts: &mut Parts, - _state: &AppState, - ) -> Result { - let raw_token_header = &parts - .headers - .get("Authorization") - .ok_or_else(|| unauthorized_error("no Authorization header"))?; - - let full_raw_token = raw_token_header.to_str().map_err(unauthorized_error)?; - - // token looks like "Bearer asdf" so we strip off the "bearer" - let raw_token = full_raw_token - .strip_prefix("Bearer ") - .ok_or_else(|| unauthorized_error("No Bearer in Authorization header"))?; - - Ok(Self(raw_token.to_owned())) - } -} - +#[derive(Clone, Debug, Deref, Into)] pub struct HandlerContext(pub DalContextBuilder); #[async_trait] impl FromRequestParts for HandlerContext { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( _parts: &mut Parts, @@ -142,12 +123,12 @@ impl FromRequestParts for HandlerContext { } } -#[derive(Deref)] +#[derive(Clone, Debug, Deref, Into)] pub struct AssetSprayer(pub asset_sprayer::AssetSprayer); #[async_trait] impl FromRequestParts for AssetSprayer { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( _parts: &mut Parts, @@ -160,11 +141,12 @@ impl FromRequestParts for AssetSprayer { } } +#[derive(Clone, Debug, Deref, Into)] pub struct PosthogClient(pub crate::app_state::PosthogClient); #[async_trait] impl FromRequestParts for PosthogClient { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( _parts: &mut Parts, @@ -174,11 +156,12 @@ impl FromRequestParts for PosthogClient { } } +#[derive(Clone, Debug, Deref, Into)] pub struct Nats(pub si_data_nats::NatsClient); #[async_trait] impl FromRequestParts for Nats { - type Rejection = (StatusCode, Json); + type Rejection = ErrorResponse; async fn from_request_parts( _parts: &mut Parts, @@ -189,105 +172,327 @@ impl FromRequestParts for Nats { } } -/** Represents a user who is authorized for the web */ +/// +/// Handles the whole endpoint authorization (checking if the user is a member of the workspace +/// as well as checking that their token has the correct role). +/// +/// Equivalent to calling both AuthorizedRole (or AuthorizedForWeb/AutomationRole) and WorkspaceMember. +/// +/// Unless you have already used the `TokenParamAccessToken` extractor to get the token from +/// query parameters, this will retrieve the token from the Authorization header. +/// +/// Unless you have already used the `AuthorizeForAutomationRole` extractor to check that the +/// token has the automation role, this will check for maximal permissions (the web role). +/// +#[derive(Clone, Debug)] +pub struct EndpointAuthorization { + pub user: User, + pub workspace_id: WorkspacePk, + pub authorized_role: SiJwtClaimRole, +} + +#[async_trait] +impl FromRequestParts for EndpointAuthorization { + type Rejection = ErrorResponse; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let WorkspaceMember { user, workspace_id } = + WorkspaceMember::from_request_parts(parts, state).await?; + let AuthorizedRole(authorized_role) = + AuthorizedRole::from_request_parts(parts, state).await?; + Ok(Self { + user, + workspace_id, + authorized_role, + }) + } +} + +/// +/// A user who has been validated as a member of the workspace, but whose role has *not* +/// been checked for authorization. +/// #[derive(Clone, Debug)] -pub struct Authorization(pub SiJwtClaims); +struct WorkspaceMember { + pub user: User, + pub workspace_id: WorkspacePk, +} #[async_trait] -impl FromRequestParts for Authorization { - type Rejection = (StatusCode, Json); +impl FromRequestParts for WorkspaceMember { + type Rejection = ErrorResponse; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { - // If we already authorized this request for the web, don't do it again - if let Some(authorization) = parts.extensions.get::() { - return Ok(authorization.clone()); + if let Some(result) = parts.extensions.get::() { + return Ok(result.clone()); } + // Get the claims from the JWT + let token = ValidatedToken::from_request_parts(parts, state).await?.0; + let workspace_id = token.custom.workspace_id(); + + // Get a context associated with the workspace let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); + ctx.update_tenancy(dal::Tenancy::new(workspace_id)); - let headers = &parts.headers; - let authorization_header_value = headers - .get("Authorization") - .ok_or_else(|| unauthorized_error("no Authorization header"))?; - let authorization = authorization_header_value - .to_str() - .map_err(internal_error)?; - let claim = SiJwtClaims::from_bearer_token(jwt_public_signing_key, authorization) + // Check if the user is a member of the workspace (and get the record if so) + let workspace_members = User::list_members_for_workspace(&ctx, workspace_id.to_string()) .await - .map_err(unauthorized_error)?; - ctx.update_tenancy(dal::Tenancy::new(claim.workspace_id())); + .map_err(internal_error)?; + let user = workspace_members + .into_iter() + .find(|m| m.pk() == token.custom.user_id()) + .ok_or_else(|| unauthorized_error("User not a member of the workspace"))?; + + // Stash and return the result + let result = Self { user, workspace_id }; + parts.extensions.insert(result.clone()); + Ok(result) + } +} - if !is_authorized_for(&ctx, &claim, SiJwtClaimRole::Web) - .await - .map_err(internal_error)? - { - return Err(unauthorized_error("not authorized for web role")); +/// +/// Confirms that this endpoint has been authorized for the desired role, but *not* that they +/// are . +/// +/// Stores the role that was authorized. +/// +/// To authorize for something other than web, use the `AuthorizeForAutomationRole` extractor. +/// +/// If it has not been authorized, this requires both that the maximal permissions (the web role). +/// +#[derive(Clone, Copy, Debug)] +struct AuthorizedRole(pub SiJwtClaimRole); + +impl AuthorizedRole { + async fn authorize_for( + parts: &mut Parts, + state: &AppState, + role: SiJwtClaimRole, + ) -> Result { + // This must not be done twice. + if parts.extensions.get::().is_some() { + return Err(internal_error( + "Must only specify explicit endpoint authorization once", + )); } - parts.extensions.insert(Self); + // Validate the token meets the role + let token = ValidatedToken::from_request_parts(parts, state).await?.0; + if !token.custom.authorized_for(role) { + return Err(unauthorized_error("Not authorized for web role")); + } + + // Stash the authorization + parts.extensions.insert(AuthorizedRole(role)); + + Ok(AuthorizedRole(role)) + } +} + +#[async_trait] +impl FromRequestParts for AuthorizedRole { + type Rejection = ErrorResponse; - Ok(Self(claim)) + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + if let Some(&result) = parts.extensions.get::() { + return Ok(result); + } + AuthorizedRole::authorize_for(parts, state, SiJwtClaimRole::Web).await } } -async fn is_authorized_for( - ctx: &dal::DalContext, - claim: &SiJwtClaims, - role: SiJwtClaimRole, -) -> dal::UserResult { - let workspace_members = - User::list_members_for_workspace(ctx, claim.workspace_id().to_string()).await?; +/// +/// A user who has been authorized for the given workspace for the web role. +/// +/// Does *not* validate that the user is a member of the workspace. EndpointAuthorization +/// (and WorkspaceMember) handle that. +/// +#[derive(Clone, Copy, Debug)] +pub struct AuthorizedForWebRole; - let is_member = workspace_members - .into_iter() - .any(|m| m.pk() == claim.user_id()); +#[async_trait] +impl FromRequestParts for AuthorizedForWebRole { + type Rejection = ErrorResponse; - Ok(is_member && claim.authorized_for(role)) + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + AuthorizedRole::authorize_for(parts, state, SiJwtClaimRole::Web).await?; + Ok(Self) + } } -pub struct WsAuthorization(pub SiJwtClaims); +/// +/// A user who has been authorized for the given workspace for the web role. +/// +#[derive(Clone, Copy, Debug)] +pub struct AuthorizedForAutomationRole; #[async_trait] -impl FromRequestParts for WsAuthorization { - type Rejection = (StatusCode, Json); +impl FromRequestParts for AuthorizedForAutomationRole { + type Rejection = ErrorResponse; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { - let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; - let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); + AuthorizedRole::authorize_for(parts, state, SiJwtClaimRole::Automation).await?; + Ok(Self) + } +} - let query: Query> = Query::from_request_parts(parts, state) - .await - .map_err(unauthorized_error)?; - let authorization = query - .get("token") - .ok_or_else(|| unauthorized_error("No token in query"))?; +/// +/// Validated JWT with unverified claims inside. +/// +/// Will retrieve this from RawAccessToken, which defaults to getting the Authorization header. +/// Use TokenParamAccessToken to get it from query parameters instead (for WS connections). +/// +/// Have not checked whether the user is a member of the workspace or has permissions. +/// +#[derive(Clone, Debug, Deref, Into)] +pub struct ValidatedToken(pub SiJwt); + +#[async_trait] +impl FromRequestParts for ValidatedToken { + type Rejection = ErrorResponse; - let claim = SiJwtClaims::from_bearer_token(jwt_public_signing_key, authorization) + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + if let Some(Self(claims)) = parts.extensions.get::() { + return Ok(Self(claims.clone())); + } + + let raw_token = RawAccessToken::from_request_parts(parts, state).await?.0; + + let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); + let token = validate_raw_token(jwt_public_signing_key, raw_token) .await .map_err(unauthorized_error)?; - ctx.update_tenancy(dal::Tenancy::new(claim.workspace_id())); + parts.extensions.insert(Self(token.clone())); + Ok(Self(token)) + } +} + +/// The raw JWT token string. +/// +/// If this has not been extracted from the request, it will be extracted from the +/// Authorization header. +/// +/// Call TokenParamAccessToken to get the token from the query parameters (for WS connections) +#[derive(Clone, Debug, Deref, Into)] +pub struct RawAccessToken(pub String); - if !is_authorized_for(&ctx, &claim, SiJwtClaimRole::Web) +#[async_trait] +impl FromRequestParts for RawAccessToken { + type Rejection = ErrorResponse; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + if let Some(RawAccessToken(token)) = parts.extensions.get::() { + return Ok(Self(token.clone())); + } + + let token = TokenFromAuthorizationHeader::from_request_parts(parts, state) + .await? + .0; + Ok(Self(token)) + } +} + +/// Gets the access token from the Authorization: header and strips the "Bearer" prefix +#[derive(Clone, Debug, Deref, Into)] +pub struct TokenFromAuthorizationHeader(pub String); + +#[async_trait] +impl FromRequestParts for TokenFromAuthorizationHeader { + type Rejection = ErrorResponse; + + async fn from_request_parts( + parts: &mut Parts, + _state: &AppState, + ) -> Result { + if let Some(RawAccessToken(token)) = parts.extensions.get::() { + return Ok(Self(token.clone())); + } + + let raw_token_header = &parts + .headers + .get("Authorization") + .ok_or_else(|| unauthorized_error("no Authorization header"))?; + + let bearer_token = raw_token_header.to_str().map_err(unauthorized_error)?; + + // token looks like "Bearer asdf" so we strip off the "bearer" + let token = bearer_token + .strip_prefix("Bearer ") + .ok_or_else(|| unauthorized_error("No Bearer in Authorization header"))? + .to_owned(); + + parts.extensions.insert(RawAccessToken(token.clone())); + + Ok(Self(token)) + } +} + +/// Gets the access token from the "token" query parameter and strips the "Bearer" prefix +#[derive(Clone, Debug, Deref, Into)] +pub struct TokenFromQueryParam(pub String); + +#[derive(Deserialize)] +struct TokenParam { + token: String, +} + +#[async_trait] +impl FromRequestParts for TokenFromQueryParam { + type Rejection = ErrorResponse; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let TokenParam { token } = Query::from_request_parts(parts, state) .await - .map_err(internal_error)? - { - return Err(unauthorized_error("not authorized for web role")); + .map_err(unauthorized_error)? + .0; + + // TODO there is a chance that somebody retrieved the token during the await, though + // the headers should come *after* the params in all cases ... may need to do something + // to force other extractors to wait (like put an awaitable rawaccesstoken instead of a + // finished one). + if parts.extensions.get::().is_some() { + return Err(internal_error("Token was already extracted!")); } - Ok(Self(claim)) + // token looks like "Bearer asdf" so we strip off the "bearer" + let token = token + .strip_prefix("Bearer ") + .ok_or_else(|| unauthorized_error("No Bearer in token query parameter"))? + .to_owned(); + + parts.extensions.insert(RawAccessToken(token.clone())); + + Ok(Self(token)) } } -fn internal_error(message: impl fmt::Display) -> (StatusCode, Json) { +fn internal_error(message: impl fmt::Display) -> ErrorResponse { let status_code = StatusCode::INTERNAL_SERVER_ERROR; ( status_code, @@ -301,7 +506,7 @@ fn internal_error(message: impl fmt::Display) -> (StatusCode, Json (StatusCode, Json) { +pub fn unauthorized_error(message: impl fmt::Display) -> ErrorResponse { let status_code = StatusCode::UNAUTHORIZED; ( status_code, @@ -315,7 +520,7 @@ pub fn unauthorized_error(message: impl fmt::Display) -> (StatusCode, Json (StatusCode, Json) { +fn not_found_error(message: &str) -> ErrorResponse { let status_code = StatusCode::NOT_FOUND; ( status_code, diff --git a/lib/sdf-server/src/middleware/workspace_permission.rs b/lib/sdf-server/src/middleware/workspace_permission.rs index 7ae4dac803..b671fae994 100644 --- a/lib/sdf-server/src/middleware/workspace_permission.rs +++ b/lib/sdf-server/src/middleware/workspace_permission.rs @@ -8,11 +8,10 @@ use axum::{ }; use futures::future::BoxFuture; use permissions::{Permission, PermissionBuilder}; -use si_jwt_public_key::SiJwtClaimRole; use tower::{Layer, Service}; use crate::{ - extract::{self, Authorization}, + extract::{self, EndpointAuthorization}, AppState, }; @@ -66,24 +65,24 @@ where Box::pin(async move { let (mut parts, body) = req.into_parts(); - let Authorization(claim) = - match Authorization::from_request_parts(&mut parts, &me.state).await { - Ok(claim) => claim, - Err(err) => return Ok(err.into_response()), - }; + let auth = match EndpointAuthorization::from_request_parts(&mut parts, &me.state).await + { + Ok(auth) => auth, + Err(err) => return Ok(err.into_response()), + }; if let Some(client) = me.state.spicedb_client() { let is_allowed = match PermissionBuilder::new() - .workspace_object(claim.workspace_id()) + .workspace_object(auth.workspace_id) .permission(me.permission) - .user_subject(claim.user_id()) + .user_subject(auth.user.pk()) .has_permission(client) .await { Ok(is_allowed) => is_allowed, Err(e) => return Ok(extract::unauthorized_error(e).into_response()), }; - if !is_allowed || !claim.authorized_for(SiJwtClaimRole::Web) { + if !is_allowed { return Ok( extract::unauthorized_error("not authorized for web role").into_response() ); diff --git a/lib/sdf-server/src/routes.rs b/lib/sdf-server/src/routes.rs index 79b12a8865..c9ef2d3363 100644 --- a/lib/sdf-server/src/routes.rs +++ b/lib/sdf-server/src/routes.rs @@ -65,6 +65,7 @@ pub fn routes(state: AppState) -> Router { .nest("/api/module", crate::service::module::routes()) .nest("/api/variant", crate::service::variant::routes()) .nest("/api/v2", crate::service::v2::routes(state.clone())) + .nest("/api/whoami", crate::service::whoami::routes()) .layer(CompressionLayer::new()) // allows us to be permissive about cors from our owned subdomains .layer( diff --git a/lib/sdf-server/src/service.rs b/lib/sdf-server/src/service.rs index 40dbbc71b4..a3025bed8e 100644 --- a/lib/sdf-server/src/service.rs +++ b/lib/sdf-server/src/service.rs @@ -23,6 +23,7 @@ pub mod secret; pub mod session; pub mod v2; pub mod variant; +pub mod whoami; pub mod ws; /// A module containing dev routes for local development only. diff --git a/lib/sdf-server/src/service/session/restore_authentication.rs b/lib/sdf-server/src/service/session/restore_authentication.rs index cd1f46804b..0fb3040c55 100644 --- a/lib/sdf-server/src/service/session/restore_authentication.rs +++ b/lib/sdf-server/src/service/session/restore_authentication.rs @@ -3,7 +3,7 @@ use dal::{User, Workspace}; use serde::{Deserialize, Serialize}; use super::{SessionError, SessionResult}; -use crate::extract::{AccessBuilder, Authorization, HandlerContext}; +use crate::extract::{AccessBuilder, EndpointAuthorization, HandlerContext}; #[derive(Deserialize, Serialize, Debug)] #[serde(rename_all = "camelCase")] @@ -14,21 +14,16 @@ pub struct RestoreAuthenticationResponse { pub async fn restore_authentication( HandlerContext(builder): HandlerContext, - // NOTE: these two lines *both* go to the DB and check the token for web-level access. - // We should probably only do this once. AccessBuilder(access_builder): AccessBuilder, - Authorization(claim): Authorization, + EndpointAuthorization { + user, workspace_id, .. + }: EndpointAuthorization, ) -> SessionResult> { let ctx = builder.build_head(access_builder).await?; - let workspace = Workspace::get_by_pk(&ctx, &claim.workspace_id()) + let workspace = Workspace::get_by_pk(&ctx, &workspace_id) .await? - .ok_or(SessionError::InvalidWorkspace(claim.workspace_id()))?; - - let user = User::get_by_pk(&ctx, claim.user_id()) - .await? - .ok_or(SessionError::InvalidUser(claim.user_id()))?; - + .ok_or(SessionError::InvalidWorkspace(workspace_id))?; let reply = RestoreAuthenticationResponse { user, workspace }; Ok(Json(reply)) diff --git a/lib/sdf-server/src/service/whoami.rs b/lib/sdf-server/src/service/whoami.rs new file mode 100644 index 0000000000..a2962ffd4b --- /dev/null +++ b/lib/sdf-server/src/service/whoami.rs @@ -0,0 +1,38 @@ +use axum::{response::IntoResponse, routing::get, Json, Router}; +use dal::{UserPk, WorkspacePk}; +use serde::{Deserialize, Serialize}; +use si_jwt_public_key::SiJwt; + +use crate::{ + extract::{AuthorizedForAutomationRole, EndpointAuthorization, ValidatedToken}, + AppState, +}; + +pub fn routes() -> Router { + Router::new().route("/", get(whoami)) +} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(rename_all = "camelCase")] +struct WhoamiResponse { + pub user_id: UserPk, + pub user_email: String, + pub workspace_id: WorkspacePk, + pub token: SiJwt, +} + +async fn whoami( + // Just because this is the most permissive role we have right now + _: AuthorizedForAutomationRole, + ValidatedToken(token): ValidatedToken, + EndpointAuthorization { + workspace_id, user, .. + }: EndpointAuthorization, +) -> impl IntoResponse { + Json(WhoamiResponse { + workspace_id, + user_id: user.pk(), + user_email: user.email().clone(), + token, + }) +} diff --git a/lib/sdf-server/src/service/ws/crdt.rs b/lib/sdf-server/src/service/ws/crdt.rs index 658cf10df9..17eef67a28 100644 --- a/lib/sdf-server/src/service/ws/crdt.rs +++ b/lib/sdf-server/src/service/ws/crdt.rs @@ -21,7 +21,7 @@ use y_sync::net::BroadcastGroup; use super::WsError; use crate::{ - extract::{Nats, WsAuthorization}, + extract::{EndpointAuthorization, Nats, TokenFromQueryParam}, nats_multiplexer::NatsMultiplexerClients, }; @@ -62,16 +62,18 @@ pub struct Id { id: String, } +#[allow(clippy::too_many_arguments)] pub async fn crdt( wsu: WebSocketUpgrade, Nats(nats): Nats, - WsAuthorization(claim): WsAuthorization, + _: TokenFromQueryParam, // This tells it to pull the token from the "token" param + auth: EndpointAuthorization, Query(Id { id }): Query, State(shutdown_token): State, State(broadcast_groups): State, State(nats_multiplexer_clients): State, ) -> Result { - let workspace_pk = claim.workspace_id(); + let workspace_pk = auth.workspace_id; let channel_name = Subject::from(format!("crdt.{workspace_pk}.{id}")); let receiver = nats_multiplexer_clients diff --git a/lib/sdf-server/src/service/ws/workspace_updates.rs b/lib/sdf-server/src/service/ws/workspace_updates.rs index 52242f8f42..b42db33b6a 100644 --- a/lib/sdf-server/src/service/ws/workspace_updates.rs +++ b/lib/sdf-server/src/service/ws/workspace_updates.rs @@ -12,7 +12,7 @@ use tokio_util::sync::CancellationToken; use super::WsError; use crate::{ - extract::{Nats, WsAuthorization}, + extract::{EndpointAuthorization, Nats, TokenFromQueryParam}, nats_multiplexer::NatsMultiplexerClients, }; @@ -20,7 +20,8 @@ use crate::{ pub async fn workspace_updates( wsu: WebSocketUpgrade, Nats(nats): Nats, - WsAuthorization(claim): WsAuthorization, + _: TokenFromQueryParam, // This tells it to pull the token from the "token" param + auth: EndpointAuthorization, State(shutdown_token): State, State(channel_multiplexer_clients): State, ) -> Result { @@ -28,7 +29,7 @@ pub async fn workspace_updates( run_workspace_updates_proto( socket, nats, - claim.workspace_id(), + auth.workspace_id, channel_multiplexer_clients.ws, shutdown_token, ) diff --git a/lib/si-jwt-public-key/src/lib.rs b/lib/si-jwt-public-key/src/lib.rs index fc9275ef5c..974677ff01 100644 --- a/lib/si-jwt-public-key/src/lib.rs +++ b/lib/si-jwt-public-key/src/lib.rs @@ -86,6 +86,16 @@ pub enum SiJwtClaimRole { Automation, } +impl SiJwtClaimRole { + pub fn is_superset_of(&self, other: Self) -> bool { + match (self, other) { + (Self::Web, Self::Web | Self::Automation) => true, + (Self::Automation, Self::Automation) => true, + (Self::Automation, Self::Web) => false, + } + } +} + #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] #[serde(untagged)] pub enum SiJwtClaims { @@ -94,6 +104,9 @@ pub enum SiJwtClaims { V1(SiJwtClaimsV1), } +/** The whole token */ +pub type SiJwt = JWTClaims; + #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SiJwtClaimsV2 { @@ -125,12 +138,15 @@ impl SiJwtClaims { } } - pub fn authorized_for(&self, required_role: SiJwtClaimRole) -> bool { - let role = match self { + pub fn role(&self) -> SiJwtClaimRole { + match self { Self::V2(SiJwtClaimsV2 { role, .. }) => *role, Self::V1(SiJwtClaimsV1 { .. }) => SiJwtClaimRole::Web, - }; - role == required_role + } + } + + pub fn authorized_for(&self, required_role: SiJwtClaimRole) -> bool { + self.role().is_superset_of(required_role) } pub fn for_web(user_id: UserPk, workspace_id: WorkspacePk) -> Self { @@ -149,6 +165,14 @@ impl SiJwtClaims { let claims = validate_bearer_token(public_key, token).await?; Ok(claims.custom) } + + pub async fn from_raw_token( + public_key: JwtPublicSigningKeyChain, + token: impl Into, + ) -> JwtKeyResult { + let claims = validate_raw_token(public_key, token).await?; + Ok(claims.custom) + } } #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] @@ -160,11 +184,7 @@ pub enum JwtAlgo { pub trait JwtPublicKeyVerify: std::fmt::Debug + Send + Sync { fn algo(&self) -> JwtAlgo; - fn verify( - &self, - token: &str, - options: Option, - ) -> JwtKeyResult>; + fn verify(&self, token: &str, options: Option) -> JwtKeyResult; } impl JwtPublicKeyVerify for RS256PublicKey { @@ -172,11 +192,7 @@ impl JwtPublicKeyVerify for RS256PublicKey { JwtAlgo::RS256 } - fn verify( - &self, - token: &str, - options: Option, - ) -> JwtKeyResult> { + fn verify(&self, token: &str, options: Option) -> JwtKeyResult { self.verify_token(token, options) .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) } @@ -187,11 +203,7 @@ impl JwtPublicKeyVerify for ES256PublicKey { JwtAlgo::ES256 } - fn verify( - &self, - token: &str, - options: Option, - ) -> JwtKeyResult> { + fn verify(&self, token: &str, options: Option) -> JwtKeyResult { self.verify_token(token, options) .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) } @@ -223,7 +235,7 @@ impl JwtPublicSigningKeyChain { &self, token: &str, options: Option, - ) -> JwtKeyResult> { + ) -> JwtKeyResult { match self.primary.verify(token, options.clone()) { Ok(claims) => Ok(claims), Err(err) => match self.secondary.as_ref() { @@ -240,17 +252,25 @@ impl JwtPublicSigningKeyChain { } } -#[instrument(level = "debug", skip_all)] pub async fn validate_bearer_token( public_key: JwtPublicSigningKeyChain, bearer_token: impl AsRef, -) -> JwtKeyResult> { +) -> JwtKeyResult { let token = bearer_token .as_ref() .strip_prefix("Bearer ") .ok_or(JwtPublicSigningKeyError::BearerToken)? .to_string(); + validate_raw_token(public_key, token).await +} + +#[instrument(level = "debug", skip_all)] +pub async fn validate_raw_token( + public_key: JwtPublicSigningKeyChain, + token: impl Into, +) -> JwtKeyResult { + let token = token.into(); let claims = tokio::task::spawn_blocking(move || public_key.verify_token(&token, None)).await??;