From a03b4b975e1cdc65df476e8fb74a62a1a38cd83a Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Wed, 10 Apr 2024 09:59:56 +0100 Subject: [PATCH] Implement protocol versions configuration This provides: - `SSL_CTX_set_min_proto_version` - `SSL_CTX_set_max_proto_version` - `SSL_CTX_get_min_proto_version` - `SSL_CTX_get_max_proto_version` - `SSL_set_min_proto_version` - `SSL_set_max_proto_version` - `SSL_get_min_proto_version` - `SSL_get_max_proto_version` (All of those functions are routed into `SSL_CTX_ctrl` & `SSL_ctrl`.) --- rustls-libssl/src/entry.rs | 35 +++++++++-- rustls-libssl/src/lib.rs | 109 ++++++++++++++++++++++++++++++++++- rustls-libssl/tests/client.c | 12 ++++ 3 files changed, 149 insertions(+), 7 deletions(-) diff --git a/rustls-libssl/src/entry.rs b/rustls-libssl/src/entry.rs index 40b2454..5992f28 100644 --- a/rustls-libssl/src/entry.rs +++ b/rustls-libssl/src/entry.rs @@ -218,10 +218,22 @@ entry! { log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_CTX_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } + Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(), Ok(SslCtrl::SetTlsExtHostname) | Ok(SslCtrl::SetTlsExtServerNameCallback) => { // not a defined operation in the OpenSSL API 0 @@ -846,10 +858,22 @@ entry! { log::warn!("unimplemented SSL_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } + Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(), Ok(SslCtrl::SetTlsExtHostname) => { let hostname = try_str!(parg as *const c_char); inner.set_sni_hostname(hostname) as c_long @@ -1795,7 +1819,10 @@ num_enum! { SetTlsExtServerNameArg = 54, SetTlsExtHostname = 55, SetChain = 88, + SetMinProtoVersion = 123, SetMaxProtoVersion = 124, + GetMinProtoVersion = 130, + GetMaxProtoVersion = 131, } } diff --git a/rustls-libssl/src/lib.rs b/rustls-libssl/src/lib.rs index 96390d1..03d2c2d 100644 --- a/rustls-libssl/src/lib.rs +++ b/rustls-libssl/src/lib.rs @@ -15,7 +15,8 @@ use rustls::crypto::aws_lc_rs as provider; use rustls::pki_types::{CertificateDer, ServerName}; use rustls::server::{Accepted, Acceptor}; use rustls::{ - CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore, ServerConfig, + CipherSuite, ClientConfig, ClientConnection, Connection, ProtocolVersion, RootCertStore, + ServerConfig, }; mod bio; @@ -213,6 +214,7 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher { pub struct SslContext { method: &'static SslMethod, ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, verify_mode: VerifyMode, verify_depth: c_int, @@ -233,6 +235,7 @@ impl SslContext { Self { method, ex_data: ex_data::ExData::default(), + versions: EnabledVersions::default(), raw_options: 0, verify_mode: VerifyMode::default(), verify_depth: -1, @@ -275,6 +278,36 @@ impl SslContext { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + fn set_max_early_data(&mut self, max: u32) { self.max_early_data = max; } @@ -437,6 +470,7 @@ fn encode_alpn<'a>(iter: impl Iterator) -> Vec { struct Ssl { ctx: Arc>, ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, mode: ConnMode, verify_mode: VerifyMode, @@ -472,6 +506,7 @@ impl Ssl { Ok(Self { ctx, ex_data: ex_data::ExData::default(), + versions: inner.versions.clone(), raw_options: inner.raw_options, mode: inner.method.mode(), verify_mode: inner.verify_mode, @@ -532,6 +567,36 @@ impl Ssl { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + fn set_alpn_offer(&mut self, alpn: Vec>) { self.alpn = alpn; } @@ -697,8 +762,10 @@ impl Ssl { &self.verify_server_name, )); + let versions = self.versions.reduce_versions(method.client_versions)?; + let wants_resolver = ClientConfig::builder_with_provider(provider) - .with_protocol_versions(method.client_versions) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .dangerous() .with_custom_certificate_verifier(verifier.clone()); @@ -797,8 +864,10 @@ impl Ssl { .server_resolver() .ok_or_else(|| error::Error::bad_data("missing server keys"))?; + let versions = self.versions.reduce_versions(method.server_versions)?; + let mut config = ServerConfig::builder_with_provider(provider) - .with_protocol_versions(method.server_versions) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .with_client_cert_verifier(verifier.clone()) .with_cert_resolver(resolver); @@ -1277,6 +1346,40 @@ impl From for i32 { } } +#[derive(Debug, Default, Clone)] +struct EnabledVersions { + min: Option, + max: Option, +} + +impl EnabledVersions { + fn reduce_versions( + &self, + method_versions: &'static [&'static rustls::SupportedProtocolVersion], + ) -> Result, error::Error> { + let acceptable: Vec<&'static rustls::SupportedProtocolVersion> = method_versions + .iter() + .cloned() + .filter(|v| self.satisfies(v.version)) + .collect(); + + if acceptable.is_empty() { + Err(error::Error::bad_data(&format!( + "no versions usable: method enabled {method_versions:?}, filter {self:?}" + ))) + } else { + Ok(acceptable) + } + } + + fn satisfies(&self, v: ProtocolVersion) -> bool { + let min = self.min.map(u16::from).unwrap_or(0); + let max = self.max.map(u16::from).unwrap_or(0xffff); + let v = u16::from(v); + min <= v && v <= max + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rustls-libssl/tests/client.c b/rustls-libssl/tests/client.c index e353358..ab8416c 100644 --- a/rustls-libssl/tests/client.c +++ b/rustls-libssl/tests/client.c @@ -61,6 +61,16 @@ int main(int argc, char **argv) { } printf("SSL_CTX_get_verify_depth default %d\n", SSL_CTX_get_verify_depth(ctx)); + printf("SSL_CTX_get_min_proto_version default 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version default 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); + TRACE(SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION)); + TRACE(SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION)); + printf("SSL_CTX_get_min_proto_version 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); X509 *client_cert = NULL; EVP_PKEY *client_key = NULL; @@ -82,6 +92,8 @@ int main(int argc, char **argv) { SSL_get_certificate(ssl) == client_cert ? "same as" : "differs to"); state(ssl); printf("SSL_get_verify_depth default %d\n", SSL_get_verify_depth(ssl)); + printf("SSL_get_min_proto_version 0x%lx\n", SSL_get_min_proto_version(ssl)); + printf("SSL_get_max_proto_version 0x%lx\n", SSL_get_max_proto_version(ssl)); printf("SSL_get_servername: %s (%d)\n", SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), SSL_get_servername_type(ssl));