From 9c0b0af4f77ccefcd742f303ebb15516088f9e8b Mon Sep 17 00:00:00 2001 From: neevek Date: Sat, 6 Apr 2024 12:13:48 +0800 Subject: [PATCH] support refreshing TLS certs for quinn server --- src/bin/rstund.rs | 2 +- src/server.rs | 75 +++++++++++++++++++++++++++---------------- src/tunnel.rs | 6 ++++ src/tunnel_message.rs | 2 +- 4 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/bin/rstund.rs b/src/bin/rstund.rs index d2df8d7..73db53b 100644 --- a/src/bin/rstund.rs +++ b/src/bin/rstund.rs @@ -72,7 +72,7 @@ async fn run(mut args: RstundArgs) -> Result<()> { config.max_idle_timeout_ms = args.max_idle_timeout_ms; let mut server = Server::new(config); - server.bind().await?; + server.bind()?; server.serve().await?; Ok(()) } diff --git a/src/server.rs b/src/server.rs index 1275b30..1e35735 100644 --- a/src/server.rs +++ b/src/server.rs @@ -30,8 +30,51 @@ impl Server { }) } - pub async fn bind(mut self: &mut Arc) -> Result { + pub fn bind(self: &mut Arc) -> Result { let config = &self.config; + let addr: SocketAddr = config + .addr + .parse() + .context(format!("invalid address: {}", config.addr))?; + + let quinn_server_cfg = Self::load_quinn_server_config(&self.config)?; + let endpoint = quinn::Endpoint::server(quinn_server_cfg, addr).map_err(|e| { + error!( + "failed to bind tunnel server on address: {}, error: {}", + addr, e + ); + e + })?; + + info!( + "tunnel server is bound on address: {}, idle_timeout: {}", + endpoint.local_addr()?, + config.max_idle_timeout_ms + ); + + let ep = endpoint.clone(); + let config = self.config.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(3600 * 24)).await; + match Self::load_quinn_server_config(&config) { + Ok(quinn_server_cfg) => { + info!("updated quinn server config!"); + ep.set_server_config(Some(quinn_server_cfg)); + } + Err(e) => { + error!("failed to load quinn server config:{e}"); + } + } + } + }); + + Arc::get_mut(self).unwrap().endpoint = Some(endpoint); + + Ok(addr) + } + + fn load_quinn_server_config(config: &ServerConfig) -> Result { let (certs, key) = Server::read_certs_and_key(config.cert_path.as_str(), config.key_path.as_str()) .context("failed to read certificate or key")?; @@ -44,7 +87,7 @@ impl Server { .with_single_cert(certs, key)?; let mut transport_cfg = TransportConfig::default(); - transport_cfg.stream_receive_window(quinn::VarInt::from_u32(1024 * 1024 * 1)); + transport_cfg.stream_receive_window(quinn::VarInt::from_u32(1024 * 1024)); transport_cfg.receive_window(quinn::VarInt::from_u32(1024 * 1024 * 8)); transport_cfg.send_window(1024 * 1024 * 8); transport_cfg.congestion_controller_factory(Arc::new(congestion::BbrConfig::default())); @@ -56,31 +99,9 @@ impl Server { } transport_cfg.max_concurrent_bidi_streams(VarInt::from_u32(1024)); - let mut cfg = quinn::ServerConfig::with_crypto(Arc::new(crypto)); - cfg.transport = Arc::new(transport_cfg); - - let addr: SocketAddr = config - .addr - .parse() - .context(format!("invalid address: {}", config.addr))?; - - let endpoint = quinn::Endpoint::server(cfg, addr).map_err(|e| { - error!( - "failed to bind tunnel server on address: {}, error: {}", - addr, e - ); - e - })?; - - info!( - "tunnel server is bound on address: {}, idle_timeout: {}", - endpoint.local_addr()?, - config.max_idle_timeout_ms - ); - - Arc::get_mut(&mut self).unwrap().endpoint = Some(endpoint); - - Ok(addr) + let mut server_cfg = quinn::ServerConfig::with_crypto(Arc::new(crypto)); + server_cfg.transport = Arc::new(transport_cfg); + Ok(server_cfg) } pub async fn serve(self: &Arc) -> Result<()> { diff --git a/src/tunnel.rs b/src/tunnel.rs index c9d5c78..267a8d3 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -75,3 +75,9 @@ impl Tunnel { } } } + +impl Default for Tunnel { + fn default() -> Self { + Self::new() + } +} diff --git a/src/tunnel_message.rs b/src/tunnel_message.rs index 62946af..5384365 100644 --- a/src/tunnel_message.rs +++ b/src/tunnel_message.rs @@ -55,7 +55,7 @@ impl TunnelMessage { pub fn handle_message(msg: &TunnelMessage) -> Result<()> { match msg { TunnelMessage::RespSuccess => Ok(()), - TunnelMessage::RespFailure(msg) => bail!(format!("received failure, err: {}", msg)), + TunnelMessage::RespFailure(msg) => bail!(format!("received failure, err: {msg}")), _ => bail!("unexpected message type"), } }