diff --git a/Cargo.toml b/Cargo.toml index b9acfdb..400c4e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,19 +11,20 @@ license = "AGPL-3.0-or-later" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -axum = "0.6" +axum = "0.7.4" features = "0.10.0" -futures = "0.3.23" -hyper = "0.14.20" +futures = "0.3" +hyper = "1" tokio = { version = "1.20.1", features = ["rt", "rt-multi-thread", "macros"] } tower = "0.4.13" -tower-http = { version = "0.4", features = ["trace"] } -trace = "0.1.6" +tower-http = { version = "0.5", features = ["trace"] } +trace = "0.1" privacypass = { git = "https://github.com/raphaelrobert/privacypass" } -async-trait = "0.1.57" -base64 = "0.21.0" -actix-web = "4.2.1" +async-trait = "0.1" +base64 = "0.21" +actix-web = "4.4" futures-util = "0.3.24" +http = "1" [dev-dependencies] -reqwest = "0.11.11" +reqwest = "0.11" diff --git a/src/actix_middleware.rs b/src/actix_middleware.rs index 4af0b4d..bd894e5 100644 --- a/src/actix_middleware.rs +++ b/src/actix_middleware.rs @@ -6,7 +6,6 @@ use actix_web::{ dev::{forward_ready, Response, Service, ServiceRequest, ServiceResponse, Transform}, error::{ErrorInternalServerError, ErrorUnauthorized}, http::{ - self, header::{self, HeaderValue}, StatusCode, }, @@ -31,7 +30,10 @@ use privacypass::{ Deserialize, NonceStore, }; -use crate::state::{PrivacyPassProvider, PrivacyPassState}; +use crate::{ + state::{PrivacyPassProvider, PrivacyPassState}, + utils::{header_name_to_http02, header_value_to_http02, header_value_to_http10, uri_to_http10}, +}; // There are two steps in middleware processing. // 1. Middleware initialization, middleware factory gets called with @@ -103,8 +105,9 @@ where let authorization = req.headers_mut().remove(header::AUTHORIZATION).next(); // Deserialize the token from the authorization header. - let token_option = - authorization.and_then(|header_value| parse_authorization_header(&header_value).ok()); + let token_option = authorization.and_then(|header_value| { + parse_authorization_header(&header_value_to_http10(header_value)).ok() + }); // If the token is present, then authenticate the token. if let Some(token) = token_option { @@ -123,11 +126,17 @@ where // If the token is not present, then issue a challenge. let public_key = state.public_key(); let token_key = serialize_public_key(*public_key); - let build_res = - build_www_authenticate_header(&challenge(req.request().uri()), &token_key, None); + let build_res = build_www_authenticate_header( + &challenge(&uri_to_http10(req.request().uri())), + &token_key, + None, + ); if let Ok((header_name, header_value)) = build_res { let response = HttpResponse::build(StatusCode::OK) - .append_header((header_name, header_value)) + .append_header(( + header_name_to_http02(header_name), + header_value_to_http02(header_value), + )) .finish(); let sr: ServiceResponse<B> = req.into_response(response); diff --git a/src/lib.rs b/src/lib.rs index 19ab134..c6b92bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,3 +6,5 @@ pub mod actix_middleware; pub mod axum_middleware; pub mod memory_stores; pub mod state; + +pub mod utils; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..ff7f8de --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2024 Phoenix R&D GmbH <hello@phnx.im> +// +// SPDX-License-Identifier: AGPL-3.0-or-later +//! This module contains utility functions for converting between actix/reqwest +//! and http types. This is necessary because actix/reqwest use a different http +//! library to hyper. + +use std::str::FromStr; + +pub fn header_name_to_http10( + header_name: actix_web::http::header::HeaderName, +) -> http::header::HeaderName { + http::header::HeaderName::from_str(header_name.as_str()).unwrap() +} + +pub fn header_value_to_http10( + header_value: actix_web::http::header::HeaderValue, +) -> http::header::HeaderValue { + http::header::HeaderValue::from_str(header_value.to_str().unwrap()).unwrap() +} + +pub fn header_name_to_http02( + header_name: http::header::HeaderName, +) -> actix_web::http::header::HeaderName { + actix_web::http::header::HeaderName::from_str(header_name.as_str()).unwrap() +} + +pub fn header_value_to_http02( + header_value: http::header::HeaderValue, +) -> actix_web::http::header::HeaderValue { + actix_web::http::header::HeaderValue::from_str(header_value.to_str().unwrap()).unwrap() +} + +pub fn uri_to_http10(uri: &actix_web::http::Uri) -> http::Uri { + http::Uri::from_str(uri.to_string().as_str()).unwrap() +} diff --git a/tests/actix.rs b/tests/actix.rs index 8633f78..6302964 100644 --- a/tests/actix.rs +++ b/tests/actix.rs @@ -15,6 +15,7 @@ use privacypass_middleware::{ actix_middleware::*, memory_stores::{MemoryKeyStore, MemoryNonceStore}, state::PrivacyPassState, + utils::{header_name_to_http02, header_value_to_http02, header_value_to_http10}, }; use std::{sync::Arc, thread}; @@ -79,7 +80,7 @@ async fn full_cycle_actix() { // Extract token challenge from header let header_name = http::header::WWW_AUTHENTICATE; - let header_value = res.headers().get(header_name).unwrap().clone(); + let header_value = header_value_to_http10(res.headers().get(header_name).unwrap().clone()); assert_eq!(res.bytes().await.unwrap().len(), 0); @@ -128,6 +129,9 @@ async fn full_cycle_actix() { // Redeem a token let (header_name, header_value) = build_authorization_header(&tokens[0]).unwrap(); + let header_name = header_name_to_http02(header_name); + let header_value = header_value_to_http02(header_value); + let res = http_client .get("http://localhost:3001/origin") .header(header_name.clone(), header_value.clone()) diff --git a/tests/axum.rs b/tests/axum.rs index fe3d666..9c42ab7 100644 --- a/tests/axum.rs +++ b/tests/axum.rs @@ -2,20 +2,26 @@ // // SPDX-License-Identifier: AGPL-3.0-or-later -use axum::http::{self, HeaderValue}; use axum::{ routing::{get, post}, Extension, Router, }; -use hyper::StatusCode; use privacypass::{ auth::{authenticate::parse_www_authenticate_header, authorize::build_authorization_header}, batched_tokens::{server::*, TokenResponse}, Serialize, }; -use privacypass_middleware::memory_stores::{MemoryKeyStore, MemoryNonceStore}; -use privacypass_middleware::{axum_middleware::*, state::PrivacyPassState}; -use std::{net::SocketAddr, sync::Arc}; +use privacypass_middleware::{ + axum_middleware::*, + memory_stores::{MemoryKeyStore, MemoryNonceStore}, + state::PrivacyPassState, + utils::{header_name_to_http02, header_value_to_http02, header_value_to_http10}, +}; +use reqwest::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, +}; +use std::sync::Arc; use tower_http::trace::TraceLayer; /// Sample server using axum. The server exposes two endpoints: @@ -38,11 +44,10 @@ async fn run_server() { .route_layer(Extension(privacy_pass_state.clone())) .layer(TraceLayer::new_for_http()); - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - axum::Server::bind(&addr) - .serve(app.into_make_service()) + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await .unwrap(); + axum::serve(listener, app).await.unwrap(); } /// Origin endpoint @@ -73,8 +78,8 @@ async fn full_cycle_axum() { assert_eq!(res.status(), StatusCode::OK); // Extract token challenge from header - let header_name = http::header::WWW_AUTHENTICATE; - let header_value = res.headers().get(header_name).unwrap().clone(); + let header_name = reqwest::header::WWW_AUTHENTICATE; + let header_value = header_value_to_http10(res.headers().get(header_name).unwrap().clone()); assert_eq!(res.bytes().await.unwrap().len(), 0); @@ -98,7 +103,7 @@ async fn full_cycle_axum() { let res = http_client .post("http://localhost:3000/issuer") .header( - http::header::CONTENT_TYPE, + CONTENT_TYPE, HeaderValue::from_static("message/token-request"), ) .body(token_request.tls_serialize_detached().unwrap()) @@ -108,7 +113,7 @@ async fn full_cycle_axum() { assert_eq!(res.status(), StatusCode::OK); assert_eq!( - res.headers().get(http::header::CONTENT_TYPE).unwrap(), + res.headers().get(reqwest::header::CONTENT_TYPE).unwrap(), HeaderValue::from_static("message/token-response") ); @@ -123,6 +128,9 @@ async fn full_cycle_axum() { // Redeem a token let (header_name, header_value) = build_authorization_header(&tokens[0]).unwrap(); + let header_name = header_name_to_http02(header_name); + let header_value = header_value_to_http02(header_value); + let res = http_client .get("http://localhost:3000/origin") .header(header_name.clone(), header_value.clone())