diff --git a/Cargo.toml b/Cargo.toml index 999dee1..6d817f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,20 +24,20 @@ clap = { version = "4.4", features = ["derive"] } csv = "1.3" email_address = "0.2" fs4 = "0.7" -hyper = "1.0" +hyper = "1.1" hyper-util = "0.1" lettre = { version = "0.11", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] } -oauth2 = { version = "4.4", default-features = false, features = ["rustls-tls"] } +oauth2 = { version = "4.4", default-features = false } rand = "0.8" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] } roxmltree = "0.19" rusqlite = { version = "0.30", features = ["bundled"], optional = true } -rustls = "0.21" -rustls-pemfile = "1.0" +rustls = "0.22" +rustls-pemfile = "2.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -tokio = { version = "1.35", features = ["full"] } -tokio-rustls = "0.24" +tokio = { version = "1.36", features = ["fs", "sync", "time", "macros", "rt-multi-thread"]} +tokio-rustls = "0.25" tower = { version = "0.4", features = ["util", "timeout"] } tower-http = { version = "0.5", features = ["fs", "trace", "compression-deflate"] } tower-service = "0.3" diff --git a/src/error.rs b/src/error.rs index ba0126c..e92acd0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,5 @@ +use std::fmt; + use axum::response::IntoResponse; use axum::Json; use hyper::StatusCode; @@ -48,6 +50,14 @@ pub enum Error { UnsupportedProjectVersion, } +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + #[cfg(feature = "sqlite")] #[allow(deprecated)] impl From for Error { diff --git a/src/server/auth.rs b/src/server/auth.rs index fa13488..4ad89a4 100644 --- a/src/server/auth.rs +++ b/src/server/auth.rs @@ -15,10 +15,9 @@ use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine; use hyper::{HeaderMap, StatusCode}; use oauth2::basic::BasicClient; -use oauth2::reqwest::async_http_client; use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, - TokenResponse, TokenUrl, + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, HttpRequest, HttpResponse, + RedirectUrl, Scope, TokenResponse, TokenUrl, }; use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -300,6 +299,28 @@ fn unix_secs() -> u64 { .as_secs() } +/// Asynchronous HTTP client. +async fn async_http_client(request: HttpRequest) -> Result { + let client = reqwest::Client::builder() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build()?; + + let request = client + .request(request.method, request.url) + .headers(request.headers) + .body(request.body) + .build()?; + + let response = client.execute(request).await?; + + Ok(HttpResponse { + status_code: response.status(), + headers: response.headers().to_owned(), + body: response.bytes().await?.to_vec(), + }) +} + #[cfg(test)] mod test { use super::Session; diff --git a/src/server/mod.rs b/src/server/mod.rs index 8fb79a3..fcb4896 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -16,7 +16,6 @@ use axum::Router; use hyper::body::Incoming; use hyper::Request; use hyper_util::rt::{TokioExecutor, TokioIo}; -use rustls::{Certificate, PrivateKey}; use tokio::net::TcpListener; use tokio_rustls::rustls::ServerConfig; use tokio_rustls::TlsAcceptor; @@ -128,19 +127,13 @@ async fn serve(host: SocketAddr, tls: ServerConfig, app: Router) -> io::Result<( fn load_tls_config(cert: &std::path::Path, key: &std::path::Path) -> io::Result { let certs = rustls_pemfile::certs(&mut io::BufReader::new(File::open(cert)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))? - .into_iter() - .map(Certificate) - .collect(); + .collect::>()?; let key = rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(File::open(key)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))? - .into_iter() .next() - .ok_or(io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?; + .ok_or(io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))??; rustls::ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, PrivateKey(key)) + .with_single_cert(certs, key.into()) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err)) }