Skip to content

Commit

Permalink
Enable session resumption
Browse files Browse the repository at this point in the history
This implements `SSL_CTX_sess_set_cache_size`, `SSL_CTX_sess_get_cache_size`,
etc.

TODO: undo the context commit
  • Loading branch information
ctz committed Apr 11, 2024
1 parent c67bda9 commit f2e951b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
15 changes: 14 additions & 1 deletion rustls-libssl/src/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,14 @@ entry! {
inner.set_servername_callback_context(parg);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::SetSessCacheSize) => {
if larg < 0 {
return 0;
}
inner.set_session_cache_size(larg as usize);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::GetSessCacheSize) => inner.get_session_cache_size() as c_long,
Err(()) => {
log::warn!("unimplemented _SSL_CTX_ctrl(..., {cmd}, {larg}, ...)");
0
Expand Down Expand Up @@ -894,7 +902,10 @@ entry! {
C_INT_SUCCESS as i64
}
// not a defined operation in the OpenSSL API
Ok(SslCtrl::SetTlsExtServerNameCallback) | Ok(SslCtrl::SetTlsExtServerNameArg) => 0,
Ok(SslCtrl::SetTlsExtServerNameCallback)
| Ok(SslCtrl::SetTlsExtServerNameArg)
| Ok(SslCtrl::SetSessCacheSize)
| Ok(SslCtrl::GetSessCacheSize) => 0,
Err(()) => {
log::warn!("unimplemented _SSL_ctrl(..., {cmd}, {larg}, ...)");
0
Expand Down Expand Up @@ -1811,6 +1822,8 @@ num_enum! {
enum SslCtrl {
Mode = 33,
SetMsgCallbackArg = 16,
SetSessCacheSize = 42,
GetSessCacheSize = 43,
SetTlsExtServerNameCallback = 53,
SetTlsExtServerNameArg = 54,
SetTlsExtHostname = 55,
Expand Down
45 changes: 33 additions & 12 deletions rustls-libssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use openssl_sys::{
EVP_PKEY, SSL_ERROR_NONE, SSL_ERROR_SSL, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, X509,
X509_STORE, X509_V_ERR_UNSPECIFIED,
};
use rustls::client::Resumption;
use rustls::crypto::aws_lc_rs as provider;
use rustls::pki_types::{CertificateDer, ServerName};
use rustls::server::{Accepted, Acceptor};
Expand All @@ -19,6 +20,7 @@ use rustls::{
};

mod bio;
mod cache;
mod callbacks;
#[macro_use]
mod constants;
Expand Down Expand Up @@ -213,6 +215,7 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher {
pub struct SslContext {
method: &'static SslMethod,
ex_data: ex_data::ExData,
caches: cache::SessionCaches,
raw_options: u64,
verify_mode: VerifyMode,
verify_depth: c_int,
Expand All @@ -233,6 +236,7 @@ impl SslContext {
Self {
method,
ex_data: ex_data::ExData::default(),
caches: cache::SessionCaches::default(),
raw_options: 0,
verify_mode: VerifyMode::default(),
verify_depth: -1,
Expand Down Expand Up @@ -261,6 +265,11 @@ impl SslContext {
self.ex_data.get(idx)
}

fn option_is_set(&self, opt: SslOption) -> bool {
let mask = opt as u64;
self.raw_options & mask == mask
}

fn get_options(&self) -> u64 {
self.raw_options
}
Expand All @@ -275,6 +284,15 @@ impl SslContext {
self.raw_options
}

fn get_session_cache_size(&self) -> usize {
self.caches.size()
}

fn set_session_cache_size(&mut self, size: usize) {
// divergence: OpenSSL can adjust the cache size without emptying it.
self.caches = cache::SessionCaches::with_size(size);
}

fn set_max_early_data(&mut self, max: u32) {
self.max_early_data = max;
}
Expand Down Expand Up @@ -679,11 +697,7 @@ impl Ssl {
None => ServerName::try_from("0.0.0.0").unwrap(),
};

let method = self
.ctx
.lock()
.map(|ctx| ctx.method)
.map_err(|_| error::Error::cannot_lock())?;
let ctx = self.ctx.lock().map_err(|_| error::Error::cannot_lock())?;

let provider = Arc::new(provider::default_provider());
let verifier = Arc::new(verifier::ServerVerifier::new(
Expand All @@ -694,7 +708,7 @@ impl Ssl {
));

let wants_resolver = ClientConfig::builder_with_provider(provider)
.with_protocol_versions(method.client_versions)
.with_protocol_versions(ctx.method.client_versions)
.map_err(error::Error::from_rustls)?
.dangerous()
.with_custom_certificate_verifier(verifier.clone());
Expand All @@ -706,6 +720,7 @@ impl Ssl {
};

config.alpn_protocols.clone_from(&self.alpn);
config.resumption = Resumption::store(ctx.caches.get_client());

let client_conn = ClientConnection::new(Arc::new(config), sni_server_name.clone())
.map_err(error::Error::from_rustls)?;
Expand Down Expand Up @@ -770,11 +785,7 @@ impl Ssl {
}

fn init_server_conn(&mut self) -> Result<(), error::Error> {
let method = self
.ctx
.lock()
.map(|ctx| ctx.method)
.map_err(|_| error::Error::cannot_lock())?;
let ctx = self.ctx.lock().map_err(|_| error::Error::cannot_lock())?;

let provider = Arc::new(provider::default_provider());
let verifier = Arc::new(
Expand All @@ -792,13 +803,17 @@ impl Ssl {
.ok_or_else(|| error::Error::bad_data("missing server keys"))?;

let mut config = ServerConfig::builder_with_provider(provider)
.with_protocol_versions(method.server_versions)
.with_protocol_versions(ctx.method.server_versions)
.map_err(error::Error::from_rustls)?
.with_client_cert_verifier(verifier.clone())
.with_cert_resolver(resolver);

config.alpn_protocols = mem::take(&mut self.alpn);
config.max_early_data_size = self.max_early_data;
config.session_storage = ctx.caches.get_server();
if !ctx.option_is_set(SslOption::NoTicket) {
config.ticketer = ctx.caches.get_ticketer();
}

let accepted = match mem::replace(&mut self.conn, ConnState::Nothing) {
ConnState::Accepted(accepted) => accepted,
Expand Down Expand Up @@ -1245,6 +1260,12 @@ impl From<VerifyMode> for i32 {
}
}

/// Subset of SSL_OP values that we interpret.
#[repr(u64)]
enum SslOption {
NoTicket = 1 << 14,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit f2e951b

Please sign in to comment.