Skip to content

Commit

Permalink
Status handling improvements (#3)
Browse files Browse the repository at this point in the history
Co-authored-by: Luke Mann <[email protected]>
Co-authored-by: Miguel Piedrafita <[email protected]>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent d3b7e92 commit 3217997
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 51 deletions.
42 changes: 41 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,41 @@
# A bridge between the World ID SDK and the World App
# Wallet Bridge

> **Warning** This project is still in early alpha.
An end-to-end encrypted bridge between the World ID SDK and World App. This bridge is used to pass zero-knowledge proofs for World ID verifications.

More details in the [docs](https://docs.worldcoin.org/further-reading/protocol-internals).

## Flow

```mermaid
sequenceDiagram
IDKit ->> Bridge: POST /request
Bridge ->> IDKit: <id>
IDKit ->> Bridge: Poll for updates GET /response/:id
WorldApp ->> Bridge: GET /request/:id
Bridge ->> WorldApp: <request>
WorldApp ->> Bridge: PUT /response/:id
IDKit ->> Bridge: Poll for updates GET /response/:id
Bridge ->> IDKit: <response>
```

```mermaid
flowchart
A[IDKit posts request /request] --> B[Request is stored in the bridge with status = initialized]
B --> C[IDKit starts polling /response/:id]
C --> D[User scans QR code with requestId & decryption key]
D --> E[App fetches request at /request/:id]
E --> F[Bridge updates status = retrieved]
F -- Status updated = retrieved --> C
F --> G[App generates proof and PUTs to /response/:id]
G --> H[Bridge stores response. One-time retrieval]
H -- Response provided --> C
```

## Endpoints

- `POST /request`: Called by IDKit. Initializes a proof verification request.
- `GET /request/:id`: Called by World App. Used to fetch the proof verification request. One time use.
- `PUT /response/:id`: Called by World App. Used to send the proof back to the application.
- `GET /response/:id`: Called by IDKit. Continuous pulling to fetch the status of the request and the response if available. Response can only be retrieved once.
3 changes: 1 addition & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use std::env;

mod routes;
mod server;

const EXPIRE_AFTER_SECONDS: usize = 60;
mod utils;

#[tokio::main]
async fn main() {
Expand Down
92 changes: 53 additions & 39 deletions src/routes/request.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,43 @@
use axum::{
extract::Path,
http::{Method, StatusCode},
routing::head,
routing::{head, post},
Extension, Json, Router,
};
use redis::{aio::ConnectionManager, AsyncCommands};
use tower_http::cors::{AllowHeaders, Any, CorsLayer};
use uuid::Uuid;

use crate::EXPIRE_AFTER_SECONDS;
use crate::utils::{
handle_redis_error, RequestPayload, RequestStatus, EXPIRE_AFTER_SECONDS, REQ_STATUS_PREFIX,
};

const REQ_PREFIX: &str = "req:";

#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct Request {
iv: String,
payload: String,
#[derive(Debug, serde::Serialize)]
struct CustomResponse {
request_id: Uuid,
}

pub fn handler() -> Router {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_headers(AllowHeaders::any())
.allow_methods([Method::PUT, Method::HEAD]);
.allow_methods([Method::POST, Method::HEAD]);

Router::new().route(
"/request/:request_id",
head(has_request)
.get(get_request)
.put(insert_request)
.layer(cors),
)
// You must chain the routes to the same Router instance
Router::new()
.route("/request", post(insert_request))
.route("/request/:request_id", head(has_request).get(get_request))
.layer(cors) // Apply the CORS layer to all routes
}

async fn has_request(
Path(request_id): Path<Uuid>,
Extension(mut redis): Extension<ConnectionManager>,
) -> StatusCode {
let Ok(exists) = redis
.exists::<_, bool>(format!("{REQ_PREFIX}{request_id}"))
.exists::<_, bool>(format!("{REQ_STATUS_PREFIX}{request_id}"))
.await
else {
return StatusCode::INTERNAL_SERVER_ERROR;
Expand All @@ -54,42 +53,57 @@ async fn has_request(
async fn get_request(
Path(request_id): Path<Uuid>,
Extension(mut redis): Extension<ConnectionManager>,
) -> Result<Json<Request>, StatusCode> {
) -> Result<Json<RequestPayload>, StatusCode> {
let value = redis
.get_del::<_, Option<Vec<u8>>>(format!("{REQ_PREFIX}{request_id}"))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(handle_redis_error)?;

if value.is_none() {
return Err(StatusCode::NOT_FOUND);
}

//ANCHOR - Update the status of the request
redis
.set_ex::<_, _, ()>(
format!("{REQ_STATUS_PREFIX}{request_id}"),
RequestStatus::Retrieved.to_string(),
EXPIRE_AFTER_SECONDS,
)
.await
.map_err(handle_redis_error)?;

value.map_or_else(
|| Err(StatusCode::NOT_FOUND),
|value| {
serde_json::from_slice(&value).map_or(Err(StatusCode::INTERNAL_SERVER_ERROR), |value| {
Ok(Json(value))
})
},
)
serde_json::from_slice(&value.unwrap())
.map_or(Err(StatusCode::INTERNAL_SERVER_ERROR), |value| {
Ok(Json(value))
})
}

async fn insert_request(
Path(request_id): Path<Uuid>,
Extension(mut redis): Extension<ConnectionManager>,
Json(request): Json<Request>,
) -> Result<StatusCode, StatusCode> {
if !redis
.set_nx::<_, _, bool>(
format!("{REQ_PREFIX}{request_id}"),
serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
Json(request): Json<RequestPayload>,
) -> Result<Json<CustomResponse>, StatusCode> {
let request_id = Uuid::new_v4();

//ANCHOR - Set request status
redis
.set_ex::<_, _, ()>(
format!("{REQ_STATUS_PREFIX}{request_id}"),
RequestStatus::Initialized.to_string(),
EXPIRE_AFTER_SECONDS,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
return Ok(StatusCode::CONFLICT);
}
.map_err(handle_redis_error)?;

//ANCHOR - Store payload
redis
.expire::<_, ()>(format!("{REQ_PREFIX}{request_id}"), EXPIRE_AFTER_SECONDS)
.set_ex::<_, _, ()>(
format!("{REQ_PREFIX}{request_id}"),
serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
EXPIRE_AFTER_SECONDS,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(handle_redis_error)?;

Ok(StatusCode::CREATED)
Ok(Json(CustomResponse { request_id }))
}
70 changes: 61 additions & 9 deletions src/routes/response.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
use std::str::FromStr;

use axum::{
body::Bytes,
extract::Path,
http::{Method, StatusCode},
routing::get,
Extension, Router,
Extension, Json, Router,
};
use redis::{aio::ConnectionManager, AsyncCommands};
use std::str;
use tower_http::cors::{AllowHeaders, Any, CorsLayer};
use uuid::Uuid;

use crate::EXPIRE_AFTER_SECONDS;
use crate::utils::{
handle_redis_error, RequestPayload, RequestStatus, EXPIRE_AFTER_SECONDS, REQ_STATUS_PREFIX,
};

const RES_PREFIX: &str = "res:";

#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct Response {
status: RequestStatus,
response: Option<RequestPayload>,
}

pub fn handler() -> Router {
let cors = CorsLayer::new()
.allow_origin(Any)
Expand All @@ -28,24 +38,59 @@ pub fn handler() -> Router {
async fn get_response(
Path(request_id): Path<Uuid>,
Extension(mut redis): Extension<ConnectionManager>,
) -> Result<Vec<u8>, StatusCode> {
) -> Result<Json<Response>, StatusCode> {
//ANCHOR - Return the response if available
let value = redis
.get_del::<_, Option<Vec<u8>>>(format!("{RES_PREFIX}{request_id}"))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(handle_redis_error)?;

if let Some(value) = value {
return serde_json::from_slice(&value).map_or(
Err(StatusCode::INTERNAL_SERVER_ERROR),
|value| {
Ok(Json(Response {
response: value,
status: RequestStatus::Completed,
}))
},
);
}

value.ok_or(StatusCode::NOT_FOUND)
//ANCHOR - Return the current status for the request
let Some(status) = redis
.get::<_, Option<String>>(format!("{REQ_STATUS_PREFIX}{request_id}"))
.await
.map_err(handle_redis_error)?
else {
//ANCHOR - Request ID does not exist
return Err(StatusCode::NOT_FOUND);
};

let status: RequestStatus = RequestStatus::from_str(&status).map_err(|e| {
tracing::error!("Failed to parse status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;

Ok(Json(Response {
status,
response: None,
}))
}

async fn insert_response(
Path(request_id): Path<Uuid>,
Extension(mut redis): Extension<ConnectionManager>,
body: Bytes,
Json(request): Json<RequestPayload>,
) -> Result<StatusCode, StatusCode> {
//ANCHOR - Store the response
if !redis
.set_nx::<_, _, bool>(format!("{RES_PREFIX}{request_id}"), body.to_vec())
.set_nx::<_, _, bool>(
format!("{RES_PREFIX}{request_id}"),
serde_json::to_vec(&request).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.map_err(handle_redis_error)?
{
return Ok(StatusCode::CONFLICT);
}
Expand All @@ -55,5 +100,12 @@ async fn insert_response(
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

//ANCHOR - Delete status
//NOTE - We can delete the status now as the presence of a response implies the request is complete
redis
.del::<_, Option<Vec<u8>>>(format!("{REQ_STATUS_PREFIX}{request_id}"))
.await
.map_err(handle_redis_error)?;

Ok(StatusCode::CREATED)
}
50 changes: 50 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::{fmt::Display, str::FromStr};

use axum::http::StatusCode;
use redis::RedisError;

pub const EXPIRE_AFTER_SECONDS: usize = 180;
pub const REQ_STATUS_PREFIX: &str = "req:status:";

#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RequestStatus {
Initialized,
Retrieved,
Completed,
}

impl Display for RequestStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Retrieved => write!(f, "retrieved"),
Self::Completed => write!(f, "completed"),
Self::Initialized => write!(f, "initialized"),
}
}
}

impl FromStr for RequestStatus {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"initialized" => Ok(Self::Initialized),
"retrieved" => Ok(Self::Retrieved),
"completed" => Ok(Self::Completed),
_ => Err(format!("Invalid status: {s}")),
}
}
}

#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct RequestPayload {
iv: String,
payload: String,
}

#[allow(clippy::needless_pass_by_value)]
pub fn handle_redis_error(e: RedisError) -> StatusCode {
tracing::error!("Redis error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
}

0 comments on commit 3217997

Please sign in to comment.