Skip to content

Commit

Permalink
Add whoami endpoint that can be hit by either user or automation token
Browse files Browse the repository at this point in the history
Also factors authorization so that all authorization types (and WS vs.
non-WS) all run through the same extractors
  • Loading branch information
jkeiser committed Jan 3, 2025
1 parent 5558101 commit 8aa67cf
Show file tree
Hide file tree
Showing 9 changed files with 410 additions and 148 deletions.
403 changes: 304 additions & 99 deletions lib/sdf-server/src/extract.rs

Large diffs are not rendered by default.

19 changes: 9 additions & 10 deletions lib/sdf-server/src/middleware/workspace_permission.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use axum::{
};
use futures::future::BoxFuture;
use permissions::{Permission, PermissionBuilder};
use si_jwt_public_key::SiJwtClaimRole;
use tower::{Layer, Service};

use crate::{
extract::{self, Authorization},
extract::{self, EndpointAuthorization},
AppState,
};

Expand Down Expand Up @@ -66,24 +65,24 @@ where
Box::pin(async move {
let (mut parts, body) = req.into_parts();

let Authorization(claim) =
match Authorization::from_request_parts(&mut parts, &me.state).await {
Ok(claim) => claim,
Err(err) => return Ok(err.into_response()),
};
let auth = match EndpointAuthorization::from_request_parts(&mut parts, &me.state).await
{
Ok(auth) => auth,
Err(err) => return Ok(err.into_response()),
};

if let Some(client) = me.state.spicedb_client() {
let is_allowed = match PermissionBuilder::new()
.workspace_object(claim.workspace_id())
.workspace_object(auth.workspace_id)
.permission(me.permission)
.user_subject(claim.user_id())
.user_subject(auth.user.pk())
.has_permission(client)
.await
{
Ok(is_allowed) => is_allowed,
Err(e) => return Ok(extract::unauthorized_error(e).into_response()),
};
if !is_allowed || !claim.authorized_for(SiJwtClaimRole::Web) {
if !is_allowed {
return Ok(
extract::unauthorized_error("not authorized for web role").into_response()
);
Expand Down
1 change: 1 addition & 0 deletions lib/sdf-server/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub fn routes(state: AppState) -> Router {
.nest("/api/module", crate::service::module::routes())
.nest("/api/variant", crate::service::variant::routes())
.nest("/api/v2", crate::service::v2::routes(state.clone()))
.nest("/api/whoami", crate::service::whoami::routes())
.layer(CompressionLayer::new())
// allows us to be permissive about cors from our owned subdomains
.layer(
Expand Down
1 change: 1 addition & 0 deletions lib/sdf-server/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod secret;
pub mod session;
pub mod v2;
pub mod variant;
pub mod whoami;
pub mod ws;

/// A module containing dev routes for local development only.
Expand Down
17 changes: 6 additions & 11 deletions lib/sdf-server/src/service/session/restore_authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use dal::{User, Workspace};
use serde::{Deserialize, Serialize};

use super::{SessionError, SessionResult};
use crate::extract::{AccessBuilder, Authorization, HandlerContext};
use crate::extract::{AccessBuilder, EndpointAuthorization, HandlerContext};

#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "camelCase")]
Expand All @@ -14,21 +14,16 @@ pub struct RestoreAuthenticationResponse {

pub async fn restore_authentication(
HandlerContext(builder): HandlerContext,
// NOTE: these two lines *both* go to the DB and check the token for web-level access.
// We should probably only do this once.
AccessBuilder(access_builder): AccessBuilder,
Authorization(claim): Authorization,
EndpointAuthorization {
user, workspace_id, ..
}: EndpointAuthorization,
) -> SessionResult<Json<RestoreAuthenticationResponse>> {
let ctx = builder.build_head(access_builder).await?;

let workspace = Workspace::get_by_pk(&ctx, &claim.workspace_id())
let workspace = Workspace::get_by_pk(&ctx, &workspace_id)
.await?
.ok_or(SessionError::InvalidWorkspace(claim.workspace_id()))?;

let user = User::get_by_pk(&ctx, claim.user_id())
.await?
.ok_or(SessionError::InvalidUser(claim.user_id()))?;

.ok_or(SessionError::InvalidWorkspace(workspace_id))?;
let reply = RestoreAuthenticationResponse { user, workspace };

Ok(Json(reply))
Expand Down
38 changes: 38 additions & 0 deletions lib/sdf-server/src/service/whoami.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use axum::{response::IntoResponse, routing::get, Json, Router};
use dal::{UserPk, WorkspacePk};
use serde::{Deserialize, Serialize};
use si_jwt_public_key::SiJwt;

use crate::{
extract::{AuthorizedForAutomationRole, EndpointAuthorization, ValidatedToken},
AppState,
};

pub fn routes() -> Router<AppState> {
Router::new().route("/", get(whoami))
}

#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct WhoamiResponse {
pub user_id: UserPk,
pub user_email: String,
pub workspace_id: WorkspacePk,
pub token: SiJwt,
}

async fn whoami(
// Just because this is the most permissive role we have right now
_: AuthorizedForAutomationRole,
ValidatedToken(token): ValidatedToken,
EndpointAuthorization {
workspace_id, user, ..
}: EndpointAuthorization,
) -> impl IntoResponse {
Json(WhoamiResponse {
workspace_id,
user_id: user.pk(),
user_email: user.email().clone(),
token,
})
}
8 changes: 5 additions & 3 deletions lib/sdf-server/src/service/ws/crdt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use y_sync::net::BroadcastGroup;

use super::WsError;
use crate::{
extract::{Nats, WsAuthorization},
extract::{EndpointAuthorization, Nats, TokenFromQueryParam},
nats_multiplexer::NatsMultiplexerClients,
};

Expand Down Expand Up @@ -62,16 +62,18 @@ pub struct Id {
id: String,
}

#[allow(clippy::too_many_arguments)]
pub async fn crdt(
wsu: WebSocketUpgrade,
Nats(nats): Nats,
WsAuthorization(claim): WsAuthorization,
_: TokenFromQueryParam, // This tells it to pull the token from the "token" param
auth: EndpointAuthorization,
Query(Id { id }): Query<Id>,
State(shutdown_token): State<CancellationToken>,
State(broadcast_groups): State<BroadcastGroups>,
State(nats_multiplexer_clients): State<NatsMultiplexerClients>,
) -> Result<impl IntoResponse, WsError> {
let workspace_pk = claim.workspace_id();
let workspace_pk = auth.workspace_id;
let channel_name = Subject::from(format!("crdt.{workspace_pk}.{id}"));

let receiver = nats_multiplexer_clients
Expand Down
7 changes: 4 additions & 3 deletions lib/sdf-server/src/service/ws/workspace_updates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@ use tokio_util::sync::CancellationToken;

use super::WsError;
use crate::{
extract::{Nats, WsAuthorization},
extract::{EndpointAuthorization, Nats, TokenFromQueryParam},
nats_multiplexer::NatsMultiplexerClients,
};

#[allow(clippy::unused_async)]
pub async fn workspace_updates(
wsu: WebSocketUpgrade,
Nats(nats): Nats,
WsAuthorization(claim): WsAuthorization,
_: TokenFromQueryParam, // This tells it to pull the token from the "token" param
auth: EndpointAuthorization,
State(shutdown_token): State<CancellationToken>,
State(channel_multiplexer_clients): State<NatsMultiplexerClients>,
) -> Result<impl IntoResponse, WsError> {
Ok(wsu.on_upgrade(move |socket| {
run_workspace_updates_proto(
socket,
nats,
claim.workspace_id(),
auth.workspace_id,
channel_multiplexer_clients.ws,
shutdown_token,
)
Expand Down
64 changes: 42 additions & 22 deletions lib/si-jwt-public-key/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ pub enum SiJwtClaimRole {
Automation,
}

impl SiJwtClaimRole {
pub fn is_superset_of(&self, other: Self) -> bool {
match (self, other) {
(Self::Web, Self::Web | Self::Automation) => true,
(Self::Automation, Self::Automation) => true,
(Self::Automation, Self::Web) => false,
}
}
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(untagged)]
pub enum SiJwtClaims {
Expand All @@ -94,6 +104,9 @@ pub enum SiJwtClaims {
V1(SiJwtClaimsV1),
}

/** The whole token */
pub type SiJwt = JWTClaims<SiJwtClaims>;

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct SiJwtClaimsV2 {
Expand Down Expand Up @@ -125,12 +138,15 @@ impl SiJwtClaims {
}
}

pub fn authorized_for(&self, required_role: SiJwtClaimRole) -> bool {
let role = match self {
pub fn role(&self) -> SiJwtClaimRole {
match self {
Self::V2(SiJwtClaimsV2 { role, .. }) => *role,
Self::V1(SiJwtClaimsV1 { .. }) => SiJwtClaimRole::Web,
};
role == required_role
}
}

pub fn authorized_for(&self, required_role: SiJwtClaimRole) -> bool {
self.role().is_superset_of(required_role)
}

pub fn for_web(user_id: UserPk, workspace_id: WorkspacePk) -> Self {
Expand All @@ -149,6 +165,14 @@ impl SiJwtClaims {
let claims = validate_bearer_token(public_key, token).await?;
Ok(claims.custom)
}

pub async fn from_raw_token(
public_key: JwtPublicSigningKeyChain,
token: impl Into<String>,
) -> JwtKeyResult<SiJwtClaims> {
let claims = validate_raw_token(public_key, token).await?;
Ok(claims.custom)
}
}

#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
Expand All @@ -160,23 +184,15 @@ pub enum JwtAlgo {

pub trait JwtPublicKeyVerify: std::fmt::Debug + Send + Sync {
fn algo(&self) -> JwtAlgo;
fn verify(
&self,
token: &str,
options: Option<VerificationOptions>,
) -> JwtKeyResult<JWTClaims<SiJwtClaims>>;
fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt>;
}

impl JwtPublicKeyVerify for RS256PublicKey {
fn algo(&self) -> JwtAlgo {
JwtAlgo::RS256
}

fn verify(
&self,
token: &str,
options: Option<VerificationOptions>,
) -> JwtKeyResult<JWTClaims<SiJwtClaims>> {
fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt> {
self.verify_token(token, options)
.map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}")))
}
Expand All @@ -187,11 +203,7 @@ impl JwtPublicKeyVerify for ES256PublicKey {
JwtAlgo::ES256
}

fn verify(
&self,
token: &str,
options: Option<VerificationOptions>,
) -> JwtKeyResult<JWTClaims<SiJwtClaims>> {
fn verify(&self, token: &str, options: Option<VerificationOptions>) -> JwtKeyResult<SiJwt> {
self.verify_token(token, options)
.map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}")))
}
Expand Down Expand Up @@ -223,7 +235,7 @@ impl JwtPublicSigningKeyChain {
&self,
token: &str,
options: Option<VerificationOptions>,
) -> JwtKeyResult<JWTClaims<SiJwtClaims>> {
) -> JwtKeyResult<SiJwt> {
match self.primary.verify(token, options.clone()) {
Ok(claims) => Ok(claims),
Err(err) => match self.secondary.as_ref() {
Expand All @@ -240,17 +252,25 @@ impl JwtPublicSigningKeyChain {
}
}

#[instrument(level = "debug", skip_all)]
pub async fn validate_bearer_token(
public_key: JwtPublicSigningKeyChain,
bearer_token: impl AsRef<str>,
) -> JwtKeyResult<JWTClaims<SiJwtClaims>> {
) -> JwtKeyResult<SiJwt> {
let token = bearer_token
.as_ref()
.strip_prefix("Bearer ")
.ok_or(JwtPublicSigningKeyError::BearerToken)?
.to_string();

validate_raw_token(public_key, token).await
}

#[instrument(level = "debug", skip_all)]
pub async fn validate_raw_token(
public_key: JwtPublicSigningKeyChain,
token: impl Into<String>,
) -> JwtKeyResult<SiJwt> {
let token = token.into();
let claims =
tokio::task::spawn_blocking(move || public_key.verify_token(&token, None)).await??;

Expand Down

0 comments on commit 8aa67cf

Please sign in to comment.