diff --git a/rustls-libssl/src/entry.rs b/rustls-libssl/src/entry.rs index 588b36f..2349f8c 100644 --- a/rustls-libssl/src/entry.rs +++ b/rustls-libssl/src/entry.rs @@ -218,10 +218,22 @@ entry! { log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_CTX_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } + Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(), Ok(SslCtrl::SetTlsExtHostname) | Ok(SslCtrl::SetTlsExtServerNameCallback) => { // not a defined operation in the OpenSSL API 0 @@ -849,10 +861,22 @@ entry! { log::warn!("unimplemented SSL_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + inner.set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } + Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(), Ok(SslCtrl::SetTlsExtHostname) => { let hostname = try_str!(parg as *const c_char); inner.set_sni_hostname(hostname) as c_long @@ -1783,7 +1807,10 @@ num_enum! { SetTlsExtServerNameArg = 54, SetTlsExtHostname = 55, SetChain = 88, + SetMinProtoVersion = 123, SetMaxProtoVersion = 124, + GetMinProtoVersion = 130, + GetMaxProtoVersion = 131, } } diff --git a/rustls-libssl/src/lib.rs b/rustls-libssl/src/lib.rs index d31b504..c5d1d3e 100644 --- a/rustls-libssl/src/lib.rs +++ b/rustls-libssl/src/lib.rs @@ -15,7 +15,8 @@ use rustls::crypto::aws_lc_rs as provider; use rustls::pki_types::{CertificateDer, ServerName}; use rustls::server::{Accepted, Acceptor}; use rustls::{ - CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore, ServerConfig, + CipherSuite, ClientConfig, ClientConnection, Connection, ProtocolVersion, RootCertStore, + ServerConfig, }; mod bio; @@ -213,6 +214,7 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher { pub struct SslContext { method: &'static SslMethod, ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, verify_mode: VerifyMode, verify_depth: c_int, @@ -233,6 +235,7 @@ impl SslContext { Self { method, ex_data: ex_data::ExData::default(), + versions: EnabledVersions::default(), raw_options: 0, verify_mode: VerifyMode::default(), verify_depth: -1, @@ -275,6 +278,36 @@ impl SslContext { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + fn set_max_early_data(&mut self, max: u32) { self.max_early_data = max; } @@ -437,6 +470,7 @@ fn encode_alpn<'a>(iter: impl Iterator) -> Vec { struct Ssl { ctx: Arc>, ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, mode: ConnMode, verify_mode: VerifyMode, @@ -472,6 +506,7 @@ impl Ssl { Ok(Self { ctx, ex_data: ex_data::ExData::default(), + versions: inner.versions.clone(), raw_options: inner.raw_options, mode: inner.method.mode(), verify_mode: inner.verify_mode, @@ -532,6 +567,36 @@ impl Ssl { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + fn set_alpn_offer(&mut self, alpn: Vec>) { self.alpn = alpn; } @@ -693,8 +758,10 @@ impl Ssl { &self.verify_server_name, )); + let versions = self.versions.reduce_versions(method.client_versions)?; + let wants_resolver = ClientConfig::builder_with_provider(provider) - .with_protocol_versions(method.client_versions) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .dangerous() .with_custom_certificate_verifier(verifier.clone()); @@ -791,8 +858,10 @@ impl Ssl { .server_resolver() .ok_or_else(|| error::Error::bad_data("missing server keys"))?; + let versions = self.versions.reduce_versions(method.server_versions)?; + let mut config = ServerConfig::builder_with_provider(provider) - .with_protocol_versions(method.server_versions) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .with_client_cert_verifier(verifier.clone()) .with_cert_resolver(resolver); @@ -1245,6 +1314,40 @@ impl From for i32 { } } +#[derive(Debug, Default, Clone)] +struct EnabledVersions { + min: Option, + max: Option, +} + +impl EnabledVersions { + fn reduce_versions( + &self, + method_versions: &'static [&'static rustls::SupportedProtocolVersion], + ) -> Result, error::Error> { + let acceptable: Vec<&'static rustls::SupportedProtocolVersion> = method_versions + .iter() + .cloned() + .filter(|v| self.satisfies(v.version)) + .collect(); + + if acceptable.is_empty() { + Err(error::Error::bad_data(&format!( + "no versions usable: method enabled {method_versions:?}, filter {self:?}" + ))) + } else { + Ok(acceptable) + } + } + + fn satisfies(&self, v: ProtocolVersion) -> bool { + let min = self.min.map(u16::from).unwrap_or(0); + let max = self.max.map(u16::from).unwrap_or(0xffff); + let v = u16::from(v); + min <= v && v <= max + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rustls-libssl/tests/client.c b/rustls-libssl/tests/client.c index 4ec8713..e6e0450 100644 --- a/rustls-libssl/tests/client.c +++ b/rustls-libssl/tests/client.c @@ -89,6 +89,16 @@ int main(int argc, char **argv) { } printf("SSL_CTX_get_verify_depth default %d\n", SSL_CTX_get_verify_depth(ctx)); + printf("SSL_CTX_get_min_proto_version default 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version default 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); + TRACE(SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION)); + TRACE(SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION)); + printf("SSL_CTX_get_min_proto_version 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); X509 *client_cert = NULL; EVP_PKEY *client_key = NULL; @@ -110,6 +120,8 @@ int main(int argc, char **argv) { SSL_get_certificate(ssl) == client_cert ? "same as" : "differs to"); state(ssl); printf("SSL_get_verify_depth default %d\n", SSL_get_verify_depth(ssl)); + printf("SSL_get_min_proto_version 0x%lx\n", SSL_get_min_proto_version(ssl)); + printf("SSL_get_max_proto_version 0x%lx\n", SSL_get_max_proto_version(ssl)); printf("SSL_get_servername: %s (%d)\n", SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), SSL_get_servername_type(ssl));