From 3217997790f041de633fac3faaacee8053a76ef6 Mon Sep 17 00:00:00 2001 From: pdtfh <149602456+pdtfh@users.noreply.github.com> Date: Thu, 9 Nov 2023 11:44:09 -0800 Subject: [PATCH] Status handling improvements (#3) Co-authored-by: Luke Mann Co-authored-by: Miguel Piedrafita --- README.md | 42 ++++++++++++++++++- src/main.rs | 3 +- src/routes/request.rs | 92 ++++++++++++++++++++++++------------------ src/routes/response.rs | 70 +++++++++++++++++++++++++++----- src/utils.rs | 50 +++++++++++++++++++++++ 5 files changed, 206 insertions(+), 51 deletions(-) create mode 100644 src/utils.rs diff --git a/README.md b/README.md index 7d0a79b..92cbbfd 100644 --- a/README.md +++ b/README.md @@ -1 +1,41 @@ -# A bridge between the World ID SDK and the World App +# Wallet Bridge + +> **Warning** This project is still in early alpha. + +An end-to-end encrypted bridge between the World ID SDK and World App. This bridge is used to pass zero-knowledge proofs for World ID verifications. + +More details in the [docs](https://docs.worldcoin.org/further-reading/protocol-internals). + +## Flow + +```mermaid +sequenceDiagram +IDKit ->> Bridge: POST /request +Bridge ->> IDKit: +IDKit ->> Bridge: Poll for updates GET /response/:id +WorldApp ->> Bridge: GET /request/:id +Bridge ->> WorldApp: +WorldApp ->> Bridge: PUT /response/:id +IDKit ->> Bridge: Poll for updates GET /response/:id +Bridge ->> IDKit: +``` + +```mermaid +flowchart +A[IDKit posts request /request] --> B[Request is stored in the bridge with status = initialized] +B --> C[IDKit starts polling /response/:id] +C --> D[User scans QR code with requestId & decryption key] +D --> E[App fetches request at /request/:id] +E --> F[Bridge updates status = retrieved] +F -- Status updated = retrieved --> C +F --> G[App generates proof and PUTs to /response/:id] +G --> H[Bridge stores response. One-time retrieval] +H -- Response provided --> C +``` + +## Endpoints + +- `POST /request`: Called by IDKit. Initializes a proof verification request. +- `GET /request/:id`: Called by World App. Used to fetch the proof verification request. One time use. +- `PUT /response/:id`: Called by World App. Used to send the proof back to the application. +- `GET /response/:id`: Called by IDKit. Continuous pulling to fetch the status of the request and the response if available. Response can only be retrieved once. diff --git a/src/main.rs b/src/main.rs index 4abef64..591ba05 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,8 +6,7 @@ use std::env; mod routes; mod server; - -const EXPIRE_AFTER_SECONDS: usize = 60; +mod utils; #[tokio::main] async fn main() { diff --git a/src/routes/request.rs b/src/routes/request.rs index 0f2d12e..3119e0b 100644 --- a/src/routes/request.rs +++ b/src/routes/request.rs @@ -1,36 +1,35 @@ use axum::{ extract::Path, http::{Method, StatusCode}, - routing::head, + routing::{head, post}, Extension, Json, Router, }; use redis::{aio::ConnectionManager, AsyncCommands}; use tower_http::cors::{AllowHeaders, Any, CorsLayer}; use uuid::Uuid; -use crate::EXPIRE_AFTER_SECONDS; +use crate::utils::{ + handle_redis_error, RequestPayload, RequestStatus, EXPIRE_AFTER_SECONDS, REQ_STATUS_PREFIX, +}; const REQ_PREFIX: &str = "req:"; -#[derive(Debug, serde::Deserialize, serde::Serialize)] -struct Request { - iv: String, - payload: String, +#[derive(Debug, serde::Serialize)] +struct CustomResponse { + request_id: Uuid, } pub fn handler() -> Router { let cors = CorsLayer::new() .allow_origin(Any) .allow_headers(AllowHeaders::any()) - .allow_methods([Method::PUT, Method::HEAD]); + .allow_methods([Method::POST, Method::HEAD]); - Router::new().route( - "/request/:request_id", - head(has_request) - .get(get_request) - .put(insert_request) - .layer(cors), - ) + // You must chain the routes to the same Router instance + Router::new() + .route("/request", post(insert_request)) + .route("/request/:request_id", head(has_request).get(get_request)) + .layer(cors) // Apply the CORS layer to all routes } async fn has_request( @@ -38,7 +37,7 @@ async fn has_request( Extension(mut redis): Extension, ) -> StatusCode { let Ok(exists) = redis - .exists::<_, bool>(format!("{REQ_PREFIX}{request_id}")) + .exists::<_, bool>(format!("{REQ_STATUS_PREFIX}{request_id}")) .await else { return StatusCode::INTERNAL_SERVER_ERROR; @@ -54,42 +53,57 @@ async fn has_request( async fn get_request( Path(request_id): Path, Extension(mut redis): Extension, -) -> Result, StatusCode> { +) -> Result, StatusCode> { let value = redis .get_del::<_, Option>>(format!("{REQ_PREFIX}{request_id}")) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(handle_redis_error)?; + + if value.is_none() { + return Err(StatusCode::NOT_FOUND); + } + + //ANCHOR - Update the status of the request + redis + .set_ex::<_, _, ()>( + format!("{REQ_STATUS_PREFIX}{request_id}"), + RequestStatus::Retrieved.to_string(), + EXPIRE_AFTER_SECONDS, + ) + .await + .map_err(handle_redis_error)?; - value.map_or_else( - || Err(StatusCode::NOT_FOUND), - |value| { - serde_json::from_slice(&value).map_or(Err(StatusCode::INTERNAL_SERVER_ERROR), |value| { - Ok(Json(value)) - }) - }, - ) + serde_json::from_slice(&value.unwrap()) + .map_or(Err(StatusCode::INTERNAL_SERVER_ERROR), |value| { + Ok(Json(value)) + }) } async fn insert_request( - Path(request_id): Path, Extension(mut redis): Extension, - Json(request): Json, -) -> Result { - if !redis - .set_nx::<_, _, bool>( - format!("{REQ_PREFIX}{request_id}"), - serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + Json(request): Json, +) -> Result, StatusCode> { + let request_id = Uuid::new_v4(); + + //ANCHOR - Set request status + redis + .set_ex::<_, _, ()>( + format!("{REQ_STATUS_PREFIX}{request_id}"), + RequestStatus::Initialized.to_string(), + EXPIRE_AFTER_SECONDS, ) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? - { - return Ok(StatusCode::CONFLICT); - } + .map_err(handle_redis_error)?; + //ANCHOR - Store payload redis - .expire::<_, ()>(format!("{REQ_PREFIX}{request_id}"), EXPIRE_AFTER_SECONDS) + .set_ex::<_, _, ()>( + format!("{REQ_PREFIX}{request_id}"), + serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + EXPIRE_AFTER_SECONDS, + ) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(handle_redis_error)?; - Ok(StatusCode::CREATED) + Ok(Json(CustomResponse { request_id })) } diff --git a/src/routes/response.rs b/src/routes/response.rs index f612351..f009f1a 100644 --- a/src/routes/response.rs +++ b/src/routes/response.rs @@ -1,18 +1,28 @@ +use std::str::FromStr; + use axum::{ - body::Bytes, extract::Path, http::{Method, StatusCode}, routing::get, - Extension, Router, + Extension, Json, Router, }; use redis::{aio::ConnectionManager, AsyncCommands}; +use std::str; use tower_http::cors::{AllowHeaders, Any, CorsLayer}; use uuid::Uuid; -use crate::EXPIRE_AFTER_SECONDS; +use crate::utils::{ + handle_redis_error, RequestPayload, RequestStatus, EXPIRE_AFTER_SECONDS, REQ_STATUS_PREFIX, +}; const RES_PREFIX: &str = "res:"; +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct Response { + status: RequestStatus, + response: Option, +} + pub fn handler() -> Router { let cors = CorsLayer::new() .allow_origin(Any) @@ -28,24 +38,59 @@ pub fn handler() -> Router { async fn get_response( Path(request_id): Path, Extension(mut redis): Extension, -) -> Result, StatusCode> { +) -> Result, StatusCode> { + //ANCHOR - Return the response if available let value = redis .get_del::<_, Option>>(format!("{RES_PREFIX}{request_id}")) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(handle_redis_error)?; + + if let Some(value) = value { + return serde_json::from_slice(&value).map_or( + Err(StatusCode::INTERNAL_SERVER_ERROR), + |value| { + Ok(Json(Response { + response: value, + status: RequestStatus::Completed, + })) + }, + ); + } - value.ok_or(StatusCode::NOT_FOUND) + //ANCHOR - Return the current status for the request + let Some(status) = redis + .get::<_, Option>(format!("{REQ_STATUS_PREFIX}{request_id}")) + .await + .map_err(handle_redis_error)? + else { + //ANCHOR - Request ID does not exist + return Err(StatusCode::NOT_FOUND); + }; + + let status: RequestStatus = RequestStatus::from_str(&status).map_err(|e| { + tracing::error!("Failed to parse status: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(Response { + status, + response: None, + })) } async fn insert_response( Path(request_id): Path, Extension(mut redis): Extension, - body: Bytes, + Json(request): Json, ) -> Result { + //ANCHOR - Store the response if !redis - .set_nx::<_, _, bool>(format!("{RES_PREFIX}{request_id}"), body.to_vec()) + .set_nx::<_, _, bool>( + format!("{RES_PREFIX}{request_id}"), + serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, + ) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .map_err(handle_redis_error)? { return Ok(StatusCode::CONFLICT); } @@ -55,5 +100,12 @@ async fn insert_response( .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + //ANCHOR - Delete status + //NOTE - We can delete the status now as the presence of a response implies the request is complete + redis + .del::<_, Option>>(format!("{REQ_STATUS_PREFIX}{request_id}")) + .await + .map_err(handle_redis_error)?; + Ok(StatusCode::CREATED) } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..e9b7247 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,50 @@ +use std::{fmt::Display, str::FromStr}; + +use axum::http::StatusCode; +use redis::RedisError; + +pub const EXPIRE_AFTER_SECONDS: usize = 180; +pub const REQ_STATUS_PREFIX: &str = "req:status:"; + +#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RequestStatus { + Initialized, + Retrieved, + Completed, +} + +impl Display for RequestStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Retrieved => write!(f, "retrieved"), + Self::Completed => write!(f, "completed"), + Self::Initialized => write!(f, "initialized"), + } + } +} + +impl FromStr for RequestStatus { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "initialized" => Ok(Self::Initialized), + "retrieved" => Ok(Self::Retrieved), + "completed" => Ok(Self::Completed), + _ => Err(format!("Invalid status: {s}")), + } + } +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct RequestPayload { + iv: String, + payload: String, +} + +#[allow(clippy::needless_pass_by_value)] +pub fn handle_redis_error(e: RedisError) -> StatusCode { + tracing::error!("Redis error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR +}