diff --git a/Cargo.lock b/Cargo.lock index 86fe9d3..53a4ac5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -659,7 +659,9 @@ dependencies = [ "nix", "pledge", "rand", + "rcgen", "rerun_except", + "rustls 0.23.12", "serde", "serde_json", "sha2", @@ -753,6 +755,18 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rcgen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" +dependencies = [ + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -838,6 +852,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls" +version = "0.23.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.7.0" @@ -846,9 +874,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" dependencies = [ "ring", "rustls-pki-types", @@ -969,9 +997,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.66" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -1126,7 +1154,7 @@ dependencies = [ "flate2", "log", "once_cell", - "rustls", + "rustls 0.22.4", "rustls-pki-types", "rustls-webpki", "url", @@ -1406,6 +1434,15 @@ version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 0e28f95..bf63fb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,8 @@ url = "2" urlencoding = "2" wait-timeout = "0.2" whoami = "1.5" +rustls = { version = "0.23.12", features = ["ring", "std"], default-features = false } +rcgen = { version = "0.13.1", features = ["crypto", "ring"], default-features = false } [target.'cfg(target_os="openbsd")'.dependencies] pledge = "0.4" diff --git a/src/config.l b/src/config.l index fcfc7e9..5a9460a 100644 --- a/src/config.l +++ b/src/config.l @@ -20,6 +20,7 @@ client_id "CLIENT_ID" client_secret "CLIENT_SECRET" error_notify_cmd "ERROR_NOTIFY_CMD" http_listen "HTTP_LISTEN" +https_listen "HTTPS_LISTEN" login_hint "LOGIN_HINT" transient_error_if_cmd "TRANSIENT_ERROR_IF_CMD" refresh_retry "REFRESH_RETRY" diff --git a/src/config.rs b/src/config.rs index 51f40e0..1e55311 100644 --- a/src/config.rs +++ b/src/config.rs @@ -26,6 +26,7 @@ const REFRESH_RETRY_DEFAULT: Duration = Duration::from_secs(40); const AUTH_NOTIFY_INTERVAL_DEFAULT: u64 = 15 * 60; /// What is the default bind() address for the HTTP server? const HTTP_LISTEN_DEFAULT: &str = "127.0.0.1:0"; +const HTTPS_LISTEN_DEFAULT: &str = "127.0.0.1:0"; #[derive(Debug)] pub struct Config { @@ -34,6 +35,7 @@ pub struct Config { pub auth_notify_interval: Duration, pub error_notify_cmd: Option, pub http_listen: String, + pub https_listen: String, pub transient_error_if_cmd: Option, refresh_at_least: Option, refresh_before_expiry: Option, @@ -69,6 +71,7 @@ impl Config { let mut auth_notify_interval = None; let mut error_notify_cmd = None; let mut http_listen = None; + let mut https_listen = None; let mut transient_error_if_cmd = None; let mut refresh_at_least = None; let mut refresh_before_expiry = None; @@ -130,6 +133,14 @@ impl Config { http_listen, )?) } + config_ast::TopLevel::HttpsListen(span) => { + https_listen = Some(check_not_assigned_str( + &lexer, + "https_listen", + span, + https_listen, + )?) + } config_ast::TopLevel::TransientErrorIfCmd(span) => { transient_error_if_cmd = Some(check_not_assigned_str( &lexer, @@ -188,6 +199,7 @@ impl Config { .unwrap_or_else(|| Duration::from_secs(AUTH_NOTIFY_INTERVAL_DEFAULT)), error_notify_cmd, http_listen: http_listen.unwrap_or_else(|| HTTP_LISTEN_DEFAULT.to_owned()), + https_listen: https_listen.unwrap_or_else(|| HTTPS_LISTEN_DEFAULT.to_owned()), transient_error_if_cmd, refresh_at_least, refresh_before_expiry, @@ -482,10 +494,15 @@ impl Account { && self.token_uri == act_dump.token_uri } - pub fn redirect_uri(&self, http_port: u16) -> Result> { + pub fn redirect_uri(&self, http_port: u16, https_port: u16) -> Result> { let mut url = Url::parse(&self.redirect_uri)?; - url.set_port(Some(http_port)) - .map_err(|_| "Cannot set port")?; + if self.redirect_uri.starts_with("https") { + url.set_port(Some(https_port)) + .map_err(|_| "Cannot set https port")?; + } else { + url.set_port(Some(http_port)) + .map_err(|_| "Cannot set http port")?; + } Ok(url) } @@ -786,6 +803,33 @@ mod test { invalid_uri("token_uri"); } + #[test] + fn valid_https_config() { + let c = Config::from_str( + r#" + https_listen = "127.0.0.1:56789"; + account "x" { + // Mandatory fields + auth_uri = "http://a.com"; + auth_uri_fields = {"l": "m", "n": "o", "l": "p"}; + client_id = "b"; + scopes = ["c", "d"]; + token_uri = "http://f.com"; + // Optional fields + redirect_uri = "https://e.com"; + } + "#, + ) + .unwrap(); + assert_eq!(c.https_listen, "127.0.0.1:56789".to_owned()); + let act = &c.accounts["x"]; + assert_eq!(act.redirect_uri, "https://e.com"); + let uri = act.redirect_uri(0, 56789).unwrap(); + assert_eq!(uri.scheme(), "https"); + assert_eq!(uri.port(), Some(56789)); + assert_eq!(uri.host_str(), Some("e.com")); + } + #[test] fn mandatory_account_fields() { let fields = &[ diff --git a/src/config.y b/src/config.y index 9f38eea..ab44488 100644 --- a/src/config.y +++ b/src/config.y @@ -17,6 +17,7 @@ TopLevel -> Result: | "AUTH_NOTIFY_INTERVAL" "=" "TIME" ";" { Ok(TopLevel::AuthNotifyInterval(map_err($3)?)) } | "ERROR_NOTIFY_CMD" "=" "STRING" ";" { Ok(TopLevel::ErrorNotifyCmd(map_err($3)?)) } | "HTTP_LISTEN" "=" "STRING" ";" { Ok(TopLevel::HttpListen(map_err($3)?)) } + | "HTTPS_LISTEN" "=" "STRING" ";" { Ok(TopLevel::HttpsListen(map_err($3)?)) } | "TRANSIENT_ERROR_IF_CMD" "=" "STRING" ";" { Ok(TopLevel::TransientErrorIfCmd(map_err($3)?)) } | "REFRESH_AT_LEAST" "=" "TIME" ";" { Ok(TopLevel::RefreshAtLeast(map_err($3)?)) } | "REFRESH_BEFORE_EXPIRY" "=" "TIME" ";" { Ok(TopLevel::RefreshBeforeExpiry(map_err($3)?)) } diff --git a/src/config_ast.rs b/src/config_ast.rs index ddc926f..822b6a9 100644 --- a/src/config_ast.rs +++ b/src/config_ast.rs @@ -7,6 +7,7 @@ pub enum TopLevel { AuthNotifyInterval(Span), ErrorNotifyCmd(Span), HttpListen(Span), + HttpsListen(Span), TransientErrorIfCmd(Span), RefreshAtLeast(Span), RefreshBeforeExpiry(Span), diff --git a/src/server/http_server.rs b/src/server/http_server.rs index c30e137..dc67a2f 100644 --- a/src/server/http_server.rs +++ b/src/server/http_server.rs @@ -1,7 +1,7 @@ use std::{ error::Error, - io::{BufRead, BufReader, Write}, - net::{TcpListener, TcpStream}, + io::{BufRead, BufReader, Read, Write}, + net::TcpListener, sync::Arc, thread, time::{Duration, Instant}, @@ -11,6 +11,12 @@ use log::warn; use serde_json::Value; use url::Url; +use rcgen::{generate_simple_self_signed, CertifiedKey}; +use rustls::{ + pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}, + ServerConfig, +}; + use super::{ eventer::TokenEvent, expiry_instant, AccountId, AuthenticatorState, Config, TokenState, UREQ_TIMEOUT, @@ -23,14 +29,16 @@ const RETRY_POST: u8 = 10; const RETRY_DELAY: u64 = 6; /// Handle an incoming (hopefully OAuth2) HTTP request. -fn request(pstate: Arc, mut stream: TcpStream) -> Result<(), Box> { +fn request( + pstate: Arc, + mut stream: T, +) -> Result<(), Box> { // This function is split into two halves. In the first half, we process the incoming HTTP // request: if there's a problem, it (mostly) means the request is mal-formed or stale, and // there's no effect on the tokenstate. In the second half we make a request to an OAuth // server: if there's a problem, we have to reset the tokenstate and force the user to make an // entirely fresh request. - - let uri = match parse_get(&mut stream) { + let uri = match parse_get(&mut stream, pstate.http_port, pstate.https_port) { Ok(x) => x, Err(_) => { // If someone couldn't even be bothered giving us a valid URI, it's unlikely this was a @@ -66,7 +74,7 @@ fn request(pstate: Arc, mut stream: TcpStream) -> Result<(), // Now that we know which account has been matched we can check if the full URI requested // matched the redirect URI we expected for that account. let act = ct_lk.account(act_id); - let expected_uri = act.redirect_uri(pstate.http_port)?; + let expected_uri = act.redirect_uri(pstate.http_port, pstate.https_port)?; if expected_uri.scheme() != uri.scheme() || expected_uri.host_str() != uri.host_str() || expected_uri.port() != uri.port() @@ -112,7 +120,9 @@ fn request(pstate: Arc, mut stream: TcpStream) -> Result<(), }; let token_uri = act.token_uri.clone(); let client_id = act.client_id.clone(); - let redirect_uri = act.redirect_uri(pstate.http_port)?.to_string(); + let redirect_uri = act + .redirect_uri(pstate.http_port, pstate.https_port)? + .to_string(); let mut pairs = vec![ ("code", code.as_str()), ("client_id", client_id.as_str()), @@ -262,13 +272,17 @@ fn fail( /// A very literal, and rather unforgiving, implementation of RFC2616 (HTTP/1.1), returning the URL /// of GET requests: returns `Err` for anything else. -fn parse_get(stream: &mut TcpStream) -> Result> { +fn parse_get( + stream: &mut T, + http_port: u16, + https_port: u16, +) -> Result> { let mut rdr = BufReader::new(stream); let mut req_line = String::new(); rdr.read_line(&mut req_line)?; // First the request line: - // Request-Line = Method SP Request-URI SP HTTP-Version CRLF + // Request-Line = Method SP Request-URI SP HTTP-Version CRLF // where Method = "GET" and `SP` is a single space character. let req_line_sp = req_line.split(' ').collect::>(); if !matches!(req_line_sp.as_slice(), &["GET", _, _]) { @@ -321,14 +335,25 @@ fn parse_get(stream: &mut TcpStream) -> Result> { } } + // If host is Some, use addressed port to select scheme (http / https) + // This works, as no HTTPS request will arrive until here on the HTTP port and vice versa match host { - Some(h) => Url::parse(&format!("http://{h:}{path:}")) - .map_err(|e| format!("Invalid request URI: {e:}").into()), + Some(h) => { + if h.contains(&https_port.to_string()) { + Url::parse(&format!("https://{h:}{path:}")) + .map_err(|e| format!("Invalid request URI: {e:}").into()) + } else if h.contains(&http_port.to_string()) { + Url::parse(&format!("http://{h:}{path:}")) + .map_err(|e| format!("Invalid request URI: {e:}").into()) + } else { + Err("Port is not valid for the HTTP request".into()) + } + } None => Err("No host field specified in HTTP request".into()), } } -fn http_200(mut stream: TcpStream, body: &str) { +fn http_200(mut stream: T, body: &str) { stream .write_all( format!("HTTP/1.1 200 OK\r\n\r\n

{body}

").as_bytes(), @@ -336,15 +361,16 @@ fn http_200(mut stream: TcpStream, body: &str) { .ok(); } -fn http_404(mut stream: TcpStream) { +fn http_404(mut stream: T) { stream.write_all(b"HTTP/1.1 404\r\n\r\n").ok(); } -fn http_400(mut stream: TcpStream) { +fn http_400(mut stream: T) { stream.write_all(b"HTTP/1.1 400\r\n\r\n").ok(); } pub fn http_server_setup(conf: &Config) -> Result<(u16, TcpListener), Box> { + // Bind TCP port for HTTP let listener = TcpListener::bind(&conf.http_listen)?; Ok((listener.local_addr()?.port(), listener)) } @@ -365,3 +391,57 @@ pub fn http_server( }); Ok(()) } + +pub fn https_server_setup( + conf: &Config, +) -> Result<(u16, TcpListener, CertifiedKey), Box> { + // Set a process wide default crypto provider. + let _ = rustls::crypto::ring::default_provider().install_default(); + + // Generate self-signed certificate + let cert = generate_simple_self_signed(vec![String::from("localhost"), String::from("127.0.0.1")])?; + + // Bind TCP port for HTTPS + let listener = TcpListener::bind(&conf.https_listen)?; + Ok((listener.local_addr()?.port(), listener, cert)) +} + +pub fn https_server( + pstate: Arc, + listener: TcpListener, + cert: CertifiedKey, +) -> Result<(), Box> { + // Build TLS configuration. + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + vec![cert.cert.into()], + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der())), + ) + .map_err(|e| e.to_string())?; + + // Negotiate application layer protocols: Only HTTP/1.1 and HTTP/1.0 are allowed + server_config.alpn_protocols = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + + thread::spawn(move || { + for mut stream in listener.incoming().flatten() { + // generate a new TLS connection + let conn = rustls::ServerConnection::new(Arc::new(server_config.clone())); + if let Err(e) = conn { + warn!("{e:}"); + continue; + } + let mut conn = conn.unwrap(); + + let pstate = Arc::clone(&pstate); + thread::spawn(move || { + // convert TCP stream into TLS stream + let stream = rustls::Stream::new(&mut conn, &mut stream); + if let Err(e) = request(pstate, stream) { + warn!("{e:}"); + } + }); + } + }); + Ok(()) +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 742b33f..c0669d4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -308,6 +308,9 @@ pub fn server(conf_path: PathBuf, conf: Config, cache_path: &Path) -> Result<(), pledge("stdio rpath wpath inet fattr unix dns proc exec", None).unwrap(); let (http_port, http_state) = http_server::http_server_setup(&conf)?; + let (https_port, https_state, certificate) = http_server::https_server_setup(&conf)?; + // TODO: Store certificate into trusted folder (OS dependent..)? + let eventer = Arc::new(Eventer::new()?); let notifier = Arc::new(Notifier::new()?); let refresher = Refresher::new(); @@ -316,12 +319,14 @@ pub fn server(conf_path: PathBuf, conf: Config, cache_path: &Path) -> Result<(), conf_path, conf, http_port, + https_port, Arc::clone(&eventer), Arc::clone(¬ifier), Arc::clone(&refresher), )); http_server::http_server(Arc::clone(&pstate), http_state)?; + http_server::https_server(Arc::clone(&pstate), https_state, certificate)?; eventer.eventer(Arc::clone(&pstate))?; refresher.refresher(Arc::clone(&pstate))?; notifier.notifier(Arc::clone(&pstate))?; diff --git a/src/server/request_token.rs b/src/server/request_token.rs index 445d9dd..0ebfa04 100644 --- a/src/server/request_token.rs +++ b/src/server/request_token.rs @@ -32,7 +32,7 @@ pub fn request_token( let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); let scopes_join = act.scopes.join(" "); - let redirect_uri = act.redirect_uri(pstate.http_port)?.to_string(); + let redirect_uri = act.redirect_uri(pstate.http_port, pstate.https_port)?.to_string(); let mut params = vec![ ("access_type", "offline"), ("code_challenge", &code_challenge), diff --git a/src/server/state.rs b/src/server/state.rs index e2e62b8..1a87d74 100644 --- a/src/server/state.rs +++ b/src/server/state.rs @@ -47,6 +47,8 @@ pub struct AuthenticatorState { locked_state: Mutex, /// port of the HTTP server required by OAuth. pub http_port: u16, + /// port of the HTTPS server required by OAuth. + pub https_port: u16, pub eventer: Arc, pub notifier: Arc, pub refresher: Arc, @@ -57,6 +59,7 @@ impl AuthenticatorState { conf_path: PathBuf, conf: Config, http_port: u16, + https_port: u16, eventer: Arc, notifier: Arc, refresher: Arc, @@ -65,6 +68,7 @@ impl AuthenticatorState { conf_path, locked_state: Mutex::new(LockedState::new(conf)), http_port, + https_port, eventer, notifier, refresher, @@ -586,7 +590,7 @@ mod test { let eventer = Arc::new(Eventer::new().unwrap()); let notifier = Arc::new(Notifier::new().unwrap()); let pstate = - AuthenticatorState::new(PathBuf::new(), conf, 0, eventer, notifier, Refresher::new()); + AuthenticatorState::new(PathBuf::new(), conf, 0, 0, eventer, notifier, Refresher::new()); let mut old_x_id; { let ct_lk = pstate.ct_lock(); @@ -712,7 +716,7 @@ mod test { let eventer = Arc::new(Eventer::new().unwrap()); let notifier = Arc::new(Notifier::new().unwrap()); let pstate = - AuthenticatorState::new(PathBuf::new(), conf, 0, eventer, notifier, Refresher::new()); + AuthenticatorState::new(PathBuf::new(), conf, 0, 0, eventer, notifier, Refresher::new()); let old_x_id; { let ct_lk = pstate.ct_lock(); @@ -786,7 +790,7 @@ mod test { let eventer = Arc::new(Eventer::new().unwrap()); let notifier = Arc::new(Notifier::new().unwrap()); let pstate = - AuthenticatorState::new(PathBuf::new(), conf, 0, eventer, notifier, Refresher::new()); + AuthenticatorState::new(PathBuf::new(), conf, 0, 0, eventer, notifier, Refresher::new()); let old_x_id; {