diff --git a/rustls-libssl/MATRIX.md b/rustls-libssl/MATRIX.md index ce1816f..f1d8603 100644 --- a/rustls-libssl/MATRIX.md +++ b/rustls-libssl/MATRIX.md @@ -67,7 +67,7 @@ | `SSL_CTX_add_custom_ext` | | | | | `SSL_CTX_add_server_custom_ext` | | | | | `SSL_CTX_add_session` | | | | -| `SSL_CTX_callback_ctrl` | | :white_check_mark: | | +| `SSL_CTX_callback_ctrl` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_check_private_key` | :white_check_mark: | | :white_check_mark: | | `SSL_CTX_clear_options` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_config` | | | | diff --git a/rustls-libssl/build.rs b/rustls-libssl/build.rs index 03b19b8..2ed851f 100644 --- a/rustls-libssl/build.rs +++ b/rustls-libssl/build.rs @@ -61,6 +61,7 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_connect", "SSL_ctrl", "SSL_CTX_add_client_CA", + "SSL_CTX_callback_ctrl", "SSL_CTX_check_private_key", "SSL_CTX_clear_options", "SSL_CTX_ctrl", diff --git a/rustls-libssl/src/callbacks.rs b/rustls-libssl/src/callbacks.rs index c7e15db..9be92d3 100644 --- a/rustls-libssl/src/callbacks.rs +++ b/rustls-libssl/src/callbacks.rs @@ -1,12 +1,13 @@ -use core::ffi::{c_uchar, c_void}; +use core::ffi::{c_int, c_uchar, c_void}; use core::ptr; use std::collections::VecDeque; use openssl_sys::SSL_TLSEXT_ERR_OK; +use rustls::AlertDescription; use crate::entry::{ - SSL_CTX_alpn_select_cb_func, SSL_CTX_cert_cb_func, _internal_SSL_complete_accept, - _internal_SSL_set_alpn_choice, SSL, + SSL_CTX_alpn_select_cb_func, SSL_CTX_cert_cb_func, SSL_CTX_servername_callback_func, + _internal_SSL_complete_accept, _internal_SSL_set_alpn_choice, SSL, }; use crate::error::Error; @@ -174,3 +175,50 @@ impl PendingCallback for CompleteAcceptPendingCallback { _internal_SSL_complete_accept(self.ssl) } } + +/// Configuration needed to create a `ServerNamePendingCallback` later +#[derive(Debug, Clone)] +pub struct ServerNameCallbackConfig { + pub cb: SSL_CTX_servername_callback_func, + pub context: *mut c_void, +} + +impl Default for ServerNameCallbackConfig { + fn default() -> Self { + Self { + cb: None, + context: ptr::null_mut(), + } + } +} + +pub struct ServerNamePendingCallback { + pub config: ServerNameCallbackConfig, + pub ssl: *mut SSL, +} + +impl PendingCallback for ServerNamePendingCallback { + fn call(self: Box) -> Result<(), Error> { + let callback = match self.config.cb { + Some(callback) => callback, + None => { + return Ok(()); + } + }; + + let unrecognised_name = u8::from(AlertDescription::UnrecognisedName) as c_int; + let mut alert = unrecognised_name; + let result = unsafe { callback(self.ssl, &mut alert as *mut c_int, self.config.context) }; + + if alert != unrecognised_name { + log::trace!("NYI: customised alert during servername callback"); + } + + match result { + 1 => Ok(()), + _ => Err(Error::not_supported( + "SSL_CTX_servername_callback_func returned != 1", + )), + } + } +} diff --git a/rustls-libssl/src/entry.rs b/rustls-libssl/src/entry.rs index 38a549a..f1862f3 100644 --- a/rustls-libssl/src/entry.rs +++ b/rustls-libssl/src/entry.rs @@ -161,34 +161,40 @@ entry! { } entry! { - pub fn _SSL_CTX_ctrl( - _ctx: *mut SSL_CTX, - cmd: c_int, - larg: c_long, - _parg: *mut c_void, - ) -> c_long { - match SslCtrl::try_from(cmd) { - Ok(SslCtrl::Mode) => { - log::warn!("unimplemented SSL_CTX_set_mode()"); - 0 - } - Ok(SslCtrl::SetMsgCallbackArg) => { - log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()"); - 0 - } - Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_CTX_set_max_proto_version()"); - 1 - } - Ok(SslCtrl::SetTlsExtHostname) => { - // not a defined operation in the OpenSSL API - 0 - } - Err(()) => { - log::warn!("unimplemented _SSL_CTX_ctrl(..., {cmd}, {larg}, ...)"); - 0 + pub fn _SSL_CTX_ctrl(ctx: *mut SSL_CTX, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long { + let ctx = try_clone_arc!(ctx); + + let result = if let Ok(mut inner) = ctx.lock() { + match SslCtrl::try_from(cmd) { + Ok(SslCtrl::Mode) => { + log::warn!("unimplemented SSL_CTX_set_mode()"); + 0 + } + Ok(SslCtrl::SetMsgCallbackArg) => { + log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()"); + 0 + } + Ok(SslCtrl::SetMaxProtoVersion) => { + log::warn!("unimplemented SSL_CTX_set_max_proto_version()"); + 1 + } + Ok(SslCtrl::SetTlsExtHostname) | Ok(SslCtrl::SetTlsExtServerNameCallback) => { + // not a defined operation in the OpenSSL API + 0 + } + Ok(SslCtrl::SetTlsExtServerNameArg) => { + inner.set_servername_callback_context(parg); + C_INT_SUCCESS as c_long + } + Err(()) => { + log::warn!("unimplemented _SSL_CTX_ctrl(..., {cmd}, {larg}, ...)"); + 0 + } } - } + } else { + 0 + }; + result } } @@ -519,6 +525,36 @@ entry! { } } +// nb. calls into SSL_CTX_callback_ctrl cast away the real function pointer type, +// and then cast back to the real type based on `cmd`. +pub type SSL_CTX_any_func = Option; + +pub type SSL_CTX_servername_callback_func = + Option c_int>; + +entry! { + pub fn _SSL_CTX_callback_ctrl(ctx: *mut SSL_CTX, cmd: c_int, fp: SSL_CTX_any_func) -> c_long { + let ctx = try_clone_arc!(ctx); + + let result = if let Ok(mut inner) = ctx.lock() { + match SslCtrl::try_from(cmd) { + Ok(SslCtrl::SetTlsExtServerNameCallback) => { + // safety: same layout + let fp = unsafe { + mem::transmute::(fp) + }; + inner.set_servername_callback(fp); + C_INT_SUCCESS as c_long + } + _ => 0, + } + } else { + 0 + }; + result + } +} + impl Castable for SSL_CTX { type Ownership = OwnershipArc; type RustType = Mutex; @@ -575,6 +611,8 @@ entry! { .map(|mut ssl| ssl.set_sni_hostname(hostname)) .unwrap_or_default() as c_long } + // not a defined operation in the OpenSSL API + Ok(SslCtrl::SetTlsExtServerNameCallback) | Ok(SslCtrl::SetTlsExtServerNameArg) => 0, Err(()) => { log::warn!("unimplemented _SSL_ctrl(..., {cmd}, {larg}, ...)"); 0 @@ -1354,6 +1392,8 @@ num_enum! { enum SslCtrl { Mode = 33, SetMsgCallbackArg = 16, + SetTlsExtServerNameCallback = 53, + SetTlsExtServerNameArg = 54, SetTlsExtHostname = 55, SetMaxProtoVersion = 124, } diff --git a/rustls-libssl/src/lib.rs b/rustls-libssl/src/lib.rs index c6f6bbc..1ad9585 100644 --- a/rustls-libssl/src/lib.rs +++ b/rustls-libssl/src/lib.rs @@ -215,6 +215,7 @@ pub struct SslContext { alpn: Vec>, alpn_callback: callbacks::AlpnCallbackConfig, cert_callback: callbacks::CertCallbackConfig, + servername_callback: callbacks::ServerNameCallbackConfig, auth_keys: sign::CertifiedKeySet, } @@ -229,6 +230,7 @@ impl SslContext { alpn: vec![], alpn_callback: callbacks::AlpnCallbackConfig::default(), cert_callback: callbacks::CertCallbackConfig::default(), + servername_callback: callbacks::ServerNameCallbackConfig::default(), auth_keys: sign::CertifiedKeySet::default(), } } @@ -294,6 +296,14 @@ impl SslContext { fn get_privatekey(&self) -> *mut EVP_PKEY { self.auth_keys.borrow_current_key() } + + fn set_servername_callback(&mut self, cb: entry::SSL_CTX_servername_callback_func) { + self.servername_callback.cb = cb; + } + + fn set_servername_callback_context(&mut self, context: *mut c_void) { + self.servername_callback.context = context; + } } /// Parse the ALPN wire format (which is used in the openssl API) @@ -360,6 +370,7 @@ struct Ssl { alpn: Vec>, alpn_callback: callbacks::AlpnCallbackConfig, cert_callback: callbacks::CertCallbackConfig, + servername_callback: callbacks::ServerNameCallbackConfig, sni_server_name: Option>, server_name: Option, bio: Option, @@ -391,6 +402,7 @@ impl Ssl { alpn: inner.alpn.clone(), alpn_callback: inner.alpn_callback.clone(), cert_callback: inner.cert_callback.clone(), + servername_callback: inner.servername_callback.clone(), sni_server_name: None, server_name: None, bio: None, @@ -608,6 +620,11 @@ impl Ssl { .server_name() .and_then(|sni| CString::new(sni.as_bytes()).ok()); + callbacks.add(Box::new(callbacks::ServerNamePendingCallback { + config: self.servername_callback.clone(), + ssl: callbacks.ssl_ptr(), + })); + if let Some(alpn_iter) = accepted.client_hello().alpn() { let offer = encode_alpn(alpn_iter); callbacks.add(Box::new(callbacks::AlpnPendingCallback { diff --git a/rustls-libssl/tests/server.c b/rustls-libssl/tests/server.c index 0345367..60d3192 100644 --- a/rustls-libssl/tests/server.c +++ b/rustls-libssl/tests/server.c @@ -75,6 +75,19 @@ static int cert_callback(SSL *ssl, void *arg) { return 1; } +static int sni_cookie = 12345; + +static int sni_callback(SSL *ssl, int *al, void *arg) { + printf("in sni_callback\n"); + assert(ssl != NULL); + assert(arg == &sni_cookie); + assert(*al == SSL_AD_UNRECOGNIZED_NAME); + printf(" SSL_get_servername: %s (%d)\n", + SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), + SSL_get_servername_type(ssl)); + return 1; +} + int main(int argc, char **argv) { if (argc != 5) { printf("%s |unauth\n\n", @@ -117,6 +130,11 @@ int main(int argc, char **argv) { SSL_CTX_set_cert_cb(ctx, cert_callback, &cert_cookie); dump_openssl_error_stack(); + SSL_CTX_set_tlsext_servername_callback(ctx, sni_callback); + dump_openssl_error_stack(); + SSL_CTX_set_tlsext_servername_arg(ctx, &sni_cookie); + dump_openssl_error_stack(); + X509 *client_cert = NULL; EVP_PKEY *client_key = NULL; TRACE(SSL_CTX_use_certificate_chain_file(ctx, certfile));