Skip to content

Commit

Permalink
support refreshing TLS certs for quinn server
Browse files Browse the repository at this point in the history
  • Loading branch information
neevek committed Apr 6, 2024
1 parent 46d0974 commit 9c0b0af
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/bin/rstund.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async fn run(mut args: RstundArgs) -> Result<()> {
config.max_idle_timeout_ms = args.max_idle_timeout_ms;

let mut server = Server::new(config);
server.bind().await?;
server.bind()?;
server.serve().await?;
Ok(())
}
Expand Down
75 changes: 48 additions & 27 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,51 @@ impl Server {
})
}

pub async fn bind(mut self: &mut Arc<Self>) -> Result<SocketAddr> {
pub fn bind(self: &mut Arc<Self>) -> Result<SocketAddr> {
let config = &self.config;
let addr: SocketAddr = config
.addr
.parse()
.context(format!("invalid address: {}", config.addr))?;

let quinn_server_cfg = Self::load_quinn_server_config(&self.config)?;
let endpoint = quinn::Endpoint::server(quinn_server_cfg, addr).map_err(|e| {
error!(
"failed to bind tunnel server on address: {}, error: {}",
addr, e
);
e
})?;

info!(
"tunnel server is bound on address: {}, idle_timeout: {}",
endpoint.local_addr()?,
config.max_idle_timeout_ms
);

let ep = endpoint.clone();
let config = self.config.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(3600 * 24)).await;
match Self::load_quinn_server_config(&config) {
Ok(quinn_server_cfg) => {
info!("updated quinn server config!");
ep.set_server_config(Some(quinn_server_cfg));
}
Err(e) => {
error!("failed to load quinn server config:{e}");
}
}
}
});

Arc::get_mut(self).unwrap().endpoint = Some(endpoint);

Ok(addr)
}

fn load_quinn_server_config(config: &ServerConfig) -> Result<quinn::ServerConfig> {
let (certs, key) =
Server::read_certs_and_key(config.cert_path.as_str(), config.key_path.as_str())
.context("failed to read certificate or key")?;
Expand All @@ -44,7 +87,7 @@ impl Server {
.with_single_cert(certs, key)?;

let mut transport_cfg = TransportConfig::default();
transport_cfg.stream_receive_window(quinn::VarInt::from_u32(1024 * 1024 * 1));
transport_cfg.stream_receive_window(quinn::VarInt::from_u32(1024 * 1024));
transport_cfg.receive_window(quinn::VarInt::from_u32(1024 * 1024 * 8));
transport_cfg.send_window(1024 * 1024 * 8);
transport_cfg.congestion_controller_factory(Arc::new(congestion::BbrConfig::default()));
Expand All @@ -56,31 +99,9 @@ impl Server {
}
transport_cfg.max_concurrent_bidi_streams(VarInt::from_u32(1024));

let mut cfg = quinn::ServerConfig::with_crypto(Arc::new(crypto));
cfg.transport = Arc::new(transport_cfg);

let addr: SocketAddr = config
.addr
.parse()
.context(format!("invalid address: {}", config.addr))?;

let endpoint = quinn::Endpoint::server(cfg, addr).map_err(|e| {
error!(
"failed to bind tunnel server on address: {}, error: {}",
addr, e
);
e
})?;

info!(
"tunnel server is bound on address: {}, idle_timeout: {}",
endpoint.local_addr()?,
config.max_idle_timeout_ms
);

Arc::get_mut(&mut self).unwrap().endpoint = Some(endpoint);

Ok(addr)
let mut server_cfg = quinn::ServerConfig::with_crypto(Arc::new(crypto));
server_cfg.transport = Arc::new(transport_cfg);
Ok(server_cfg)
}

pub async fn serve(self: &Arc<Self>) -> Result<()> {
Expand Down
6 changes: 6 additions & 0 deletions src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,9 @@ impl Tunnel {
}
}
}

impl Default for Tunnel {
fn default() -> Self {
Self::new()
}
}
2 changes: 1 addition & 1 deletion src/tunnel_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl TunnelMessage {
pub fn handle_message(msg: &TunnelMessage) -> Result<()> {
match msg {
TunnelMessage::RespSuccess => Ok(()),
TunnelMessage::RespFailure(msg) => bail!(format!("received failure, err: {}", msg)),
TunnelMessage::RespFailure(msg) => bail!(format!("received failure, err: {msg}")),
_ => bail!("unexpected message type"),
}
}
Expand Down

0 comments on commit 9c0b0af

Please sign in to comment.