Skip to content

Commit

Permalink
Implement protocol versions configuration
Browse files Browse the repository at this point in the history
This provides:

- `SSL_CTX_set_min_proto_version`
- `SSL_CTX_set_max_proto_version`
- `SSL_CTX_get_min_proto_version`
- `SSL_CTX_get_max_proto_version`
- `SSL_set_min_proto_version`
- `SSL_set_max_proto_version`
- `SSL_get_min_proto_version`
- `SSL_get_max_proto_version`

(All of those functions are routed into `SSL_CTX_ctrl` & `SSL_ctrl`.)
  • Loading branch information
ctz committed Apr 12, 2024
1 parent fb44696 commit 6ddc5e0
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 7 deletions.
35 changes: 31 additions & 4 deletions rustls-libssl/src/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,22 @@ entry! {
log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()");
0
}
Ok(SslCtrl::SetMinProtoVersion) => {
if larg < 0 || larg > u16::MAX.into() {
return 0;
}
inner.set_min_protocol_version(larg as u16);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(),
Ok(SslCtrl::SetMaxProtoVersion) => {
log::warn!("unimplemented SSL_CTX_set_max_proto_version()");
1
if larg < 0 || larg > u16::MAX.into() {
return 0;
}
inner.set_max_protocol_version(larg as u16);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(),
Ok(SslCtrl::SetTlsExtHostname) | Ok(SslCtrl::SetTlsExtServerNameCallback) => {
// not a defined operation in the OpenSSL API
0
Expand Down Expand Up @@ -849,10 +861,22 @@ entry! {
log::warn!("unimplemented SSL_set_msg_callback_arg()");
0
}
Ok(SslCtrl::SetMinProtoVersion) => {
if larg < 0 || larg > u16::MAX.into() {
return 0;
}
inner.set_min_protocol_version(larg as u16);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::GetMinProtoVersion) => inner.get_min_protocol_version().into(),
Ok(SslCtrl::SetMaxProtoVersion) => {
log::warn!("unimplemented SSL_set_max_proto_version()");
1
if larg < 0 || larg > u16::MAX.into() {
return 0;
}
inner.set_max_protocol_version(larg as u16);
C_INT_SUCCESS as c_long
}
Ok(SslCtrl::GetMaxProtoVersion) => inner.get_max_protocol_version().into(),
Ok(SslCtrl::SetTlsExtHostname) => {
let hostname = try_str!(parg as *const c_char);
inner.set_sni_hostname(hostname) as c_long
Expand Down Expand Up @@ -1783,7 +1807,10 @@ num_enum! {
SetTlsExtServerNameArg = 54,
SetTlsExtHostname = 55,
SetChain = 88,
SetMinProtoVersion = 123,
SetMaxProtoVersion = 124,
GetMinProtoVersion = 130,
GetMaxProtoVersion = 131,
}
}

Expand Down
109 changes: 106 additions & 3 deletions rustls-libssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use rustls::crypto::aws_lc_rs as provider;
use rustls::pki_types::{CertificateDer, ServerName};
use rustls::server::{Accepted, Acceptor};
use rustls::{
CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore, ServerConfig,
CipherSuite, ClientConfig, ClientConnection, Connection, ProtocolVersion, RootCertStore,
ServerConfig,
};

mod bio;
Expand Down Expand Up @@ -213,6 +214,7 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher {
pub struct SslContext {
method: &'static SslMethod,
ex_data: ex_data::ExData,
versions: EnabledVersions,
raw_options: u64,
verify_mode: VerifyMode,
verify_depth: c_int,
Expand All @@ -233,6 +235,7 @@ impl SslContext {
Self {
method,
ex_data: ex_data::ExData::default(),
versions: EnabledVersions::default(),
raw_options: 0,
verify_mode: VerifyMode::default(),
verify_depth: -1,
Expand Down Expand Up @@ -275,6 +278,36 @@ impl SslContext {
self.raw_options
}

fn set_min_protocol_version(&mut self, ver: u16) {
self.versions.min = match ver {
0 => None,
_ => Some(ProtocolVersion::from(ver)),
};
}

fn get_min_protocol_version(&self) -> u16 {
self.versions
.min
.as_ref()
.map(|v| u16::from(*v))
.unwrap_or_default()
}

fn set_max_protocol_version(&mut self, ver: u16) {
self.versions.max = match ver {
0 => None,
_ => Some(ProtocolVersion::from(ver)),
};
}

fn get_max_protocol_version(&self) -> u16 {
self.versions
.max
.as_ref()
.map(|v| u16::from(*v))
.unwrap_or_default()
}

fn set_max_early_data(&mut self, max: u32) {
self.max_early_data = max;
}
Expand Down Expand Up @@ -437,6 +470,7 @@ fn encode_alpn<'a>(iter: impl Iterator<Item = &'a [u8]>) -> Vec<u8> {
struct Ssl {
ctx: Arc<Mutex<SslContext>>,
ex_data: ex_data::ExData,
versions: EnabledVersions,
raw_options: u64,
mode: ConnMode,
verify_mode: VerifyMode,
Expand Down Expand Up @@ -472,6 +506,7 @@ impl Ssl {
Ok(Self {
ctx,
ex_data: ex_data::ExData::default(),
versions: inner.versions.clone(),
raw_options: inner.raw_options,
mode: inner.method.mode(),
verify_mode: inner.verify_mode,
Expand Down Expand Up @@ -532,6 +567,36 @@ impl Ssl {
self.raw_options
}

fn set_min_protocol_version(&mut self, ver: u16) {
self.versions.min = match ver {
0 => None,
_ => Some(ProtocolVersion::from(ver)),
};
}

fn get_min_protocol_version(&self) -> u16 {
self.versions
.min
.as_ref()
.map(|v| u16::from(*v))
.unwrap_or_default()
}

fn set_max_protocol_version(&mut self, ver: u16) {
self.versions.max = match ver {
0 => None,
_ => Some(ProtocolVersion::from(ver)),
};
}

fn get_max_protocol_version(&self) -> u16 {
self.versions
.max
.as_ref()
.map(|v| u16::from(*v))
.unwrap_or_default()
}

fn set_alpn_offer(&mut self, alpn: Vec<Vec<u8>>) {
self.alpn = alpn;
}
Expand Down Expand Up @@ -693,8 +758,10 @@ impl Ssl {
&self.verify_server_name,
));

let versions = self.versions.reduce_versions(method.client_versions)?;

let wants_resolver = ClientConfig::builder_with_provider(provider)
.with_protocol_versions(method.client_versions)
.with_protocol_versions(&versions)
.map_err(error::Error::from_rustls)?
.dangerous()
.with_custom_certificate_verifier(verifier.clone());
Expand Down Expand Up @@ -791,8 +858,10 @@ impl Ssl {
.server_resolver()
.ok_or_else(|| error::Error::bad_data("missing server keys"))?;

let versions = self.versions.reduce_versions(method.server_versions)?;

let mut config = ServerConfig::builder_with_provider(provider)
.with_protocol_versions(method.server_versions)
.with_protocol_versions(&versions)
.map_err(error::Error::from_rustls)?
.with_client_cert_verifier(verifier.clone())
.with_cert_resolver(resolver);
Expand Down Expand Up @@ -1245,6 +1314,40 @@ impl From<VerifyMode> for i32 {
}
}

#[derive(Debug, Default, Clone)]
struct EnabledVersions {
min: Option<ProtocolVersion>,
max: Option<ProtocolVersion>,
}

impl EnabledVersions {
fn reduce_versions(
&self,
method_versions: &'static [&'static rustls::SupportedProtocolVersion],
) -> Result<Vec<&'static rustls::SupportedProtocolVersion>, error::Error> {
let acceptable: Vec<&'static rustls::SupportedProtocolVersion> = method_versions
.iter()
.cloned()
.filter(|v| self.satisfies(v.version))
.collect();

if acceptable.is_empty() {
Err(error::Error::bad_data(&format!(
"no versions usable: method enabled {method_versions:?}, filter {self:?}"
)))
} else {
Ok(acceptable)
}
}

fn satisfies(&self, v: ProtocolVersion) -> bool {
let min = self.min.map(u16::from).unwrap_or(0);
let max = self.max.map(u16::from).unwrap_or(0xffff);
let v = u16::from(v);
min <= v && v <= max
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
12 changes: 12 additions & 0 deletions rustls-libssl/tests/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ int main(int argc, char **argv) {
}
printf("SSL_CTX_get_verify_depth default %d\n",
SSL_CTX_get_verify_depth(ctx));
printf("SSL_CTX_get_min_proto_version default 0x%lx\n",
SSL_CTX_get_min_proto_version(ctx));
printf("SSL_CTX_get_max_proto_version default 0x%lx\n",
SSL_CTX_get_max_proto_version(ctx));
TRACE(SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION));
TRACE(SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION));
printf("SSL_CTX_get_min_proto_version 0x%lx\n",
SSL_CTX_get_min_proto_version(ctx));
printf("SSL_CTX_get_max_proto_version 0x%lx\n",
SSL_CTX_get_max_proto_version(ctx));

X509 *client_cert = NULL;
EVP_PKEY *client_key = NULL;
Expand All @@ -110,6 +120,8 @@ int main(int argc, char **argv) {
SSL_get_certificate(ssl) == client_cert ? "same as" : "differs to");
state(ssl);
printf("SSL_get_verify_depth default %d\n", SSL_get_verify_depth(ssl));
printf("SSL_get_min_proto_version 0x%lx\n", SSL_get_min_proto_version(ssl));
printf("SSL_get_max_proto_version 0x%lx\n", SSL_get_max_proto_version(ssl));
printf("SSL_get_servername: %s (%d)\n",
SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name),
SSL_get_servername_type(ssl));
Expand Down

0 comments on commit 6ddc5e0

Please sign in to comment.