diff --git a/src/app.rs b/src/app.rs index 9e6712c..a8ee92e 100644 --- a/src/app.rs +++ b/src/app.rs @@ -53,8 +53,10 @@ impl App { .split_for_parts(); #[cfg(feature = "openapi")] - router.merge(SwaggerUi::new("/swagger-ui") - .url("/api-docs/openapi.json", openapi.clone())) + let router = router.merge(SwaggerUi::new("/swagger-ui") + .url("/api-docs/openapi.json", openapi.clone())); + + router } } diff --git a/src/handlers.rs b/src/handlers.rs index 76425e4..7c896c6 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -32,7 +32,8 @@ use crate::claims::{ClientAssertion, JWTBearerAssertion}; ), responses( (status = OK, description = "Success", body = TokenResponse, content_type = "application/json"), - (status = BAD_REQUEST, description = "Bad request", body = ErrorResponse, content_type = "application/json") + (status = BAD_REQUEST, description = "Bad request", body = ErrorResponse, content_type = "application/json"), + (status = INTERNAL_SERVER_ERROR, description = "Server error", body = ErrorResponse, content_type = "application/json"), ) )] pub async fn token( @@ -101,6 +102,7 @@ pub async fn introspect( Ok((StatusCode::OK, Json(claims))) } +// TODO: rename to something more descriptive #[derive(serde::Deserialize)] struct Claims { iss: String, @@ -209,15 +211,24 @@ pub enum ApiError { impl IntoResponse for ApiError { fn into_response(self) -> Response { match &self { + // network error while talking to upstream ApiError::UpstreamRequest(err) => ( err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string(), ).into_response(), + // upstream responded with a non-json error? ApiError::JSON(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response(), - ApiError::Upstream{status_code, error} => { + // upstream successful responded with an oauth error + ApiError::Upstream { status_code, error } => { (status_code.clone(), Json(error.clone())).into_response() + + // TODO: map status code to the correct error code + //400, 500 -> verbatim + //* -> 500 } + // failed to validate token for introspection ApiError::Validate(_) => (StatusCode::BAD_REQUEST, self.to_string()).into_response(), + // failed to sign JWT assertion ApiError::Sign => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response(), } } diff --git a/src/identity_provider.rs b/src/identity_provider.rs index fe53f23..ab11ccf 100644 --- a/src/identity_provider.rs +++ b/src/identity_provider.rs @@ -34,17 +34,62 @@ pub enum TokenType { /// RFC6749 token response from section 5.2. #[derive(Serialize, Deserialize, ToSchema, Debug, Clone)] pub struct ErrorResponse { - pub error: String, + pub error: OAuthErrorCode, #[serde(rename = "error_description")] pub description: String, } impl Display for ErrorResponse { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {}", self.error, self.description) + write!(f, "{}: {}", serde_json::to_string(&self.error).unwrap(), self.description) } } +impl From for ErrorResponse { + fn from(err: ApiError) -> Self { + match err { + ApiError::Sign => ErrorResponse { + error: OAuthErrorCode::ServerError, + description: "Failed to sign assertion".to_string(), + }, + ApiError::UpstreamRequest(err) => ErrorResponse { + error: OAuthErrorCode::ServerError, + description: format!("Upstream request failed: {}", err), + }, + ApiError::JSON(err) => ErrorResponse { + error: OAuthErrorCode::ServerError, + description: format!("Failed to parse JSON: {}", err), + }, + ApiError::Upstream { status_code: _status_code, error } => ErrorResponse { + error: error.error, + description: error.description, + }, + ApiError::Validate(_) => ErrorResponse { + error: OAuthErrorCode::ServerError, + description: "Failed to validate token".to_string(), + } + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)] +pub enum OAuthErrorCode { + #[serde(rename = "invalid_request")] + InvalidRequest, + #[serde(rename = "invalid_client")] + InvalidClient, + #[serde(rename = "invalid_grant")] + InvalidGrant, + #[serde(rename = "unauthorized_client")] + UnauthorizedClient, + #[serde(rename = "unsupported_grant_type")] + UnsupportedGrantType, + #[serde(rename = "invalid_scope")] + InvalidScope, + #[serde(rename = "server_error")] + ServerError, +} + /// Identity provider for use with token fetch, exchange and validation. #[derive(Deserialize, Serialize, ToSchema, Clone, Debug)] pub enum IdentityProvider { @@ -260,7 +305,7 @@ where let err: ErrorResponse = response.json().await.map_err(ApiError::JSON)?; let err = ApiError::Upstream { status_code: status, - error: err + error: err, }; error!("get_token_with_config: {}", err); return Err(err);