From 1c7b8b28ef9773ce7d8e15f69fdd82684a814eac Mon Sep 17 00:00:00 2001 From: Evan Rittenhouse Date: Sun, 1 Sep 2024 14:00:43 -0500 Subject: [PATCH] Expose SSL_CTX_set_info_callback Model callback arguments as structs --- boring/src/ssl/callbacks.rs | 34 ++++++++++++++++-- boring/src/ssl/mod.rs | 72 +++++++++++++++++++++++++++++++++++++ boring/src/ssl/test/mod.rs | 14 ++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index 7841950c..f592d9d2 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -2,8 +2,9 @@ use super::{ AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError, - SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, - SslSessionRef, SslSignatureAlgorithm, SslVerifyError, SESSION_CTX_INDEX, + SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslInfoCallbackAlert, + SslInfoCallbackMode, SslInfoCallbackValue, SslRef, SslSession, SslSessionRef, + SslSignatureAlgorithm, SslVerifyError, SESSION_CTX_INDEX, }; use crate::error::ErrorStack; use crate::ffi; @@ -521,3 +522,32 @@ where Err(err) => err.0, } } + +pub(super) unsafe extern "C" fn raw_info_callback( + ssl: *const ffi::SSL, + mode: c_int, + value: c_int, +) where + F: Fn(&SslRef, SslInfoCallbackMode, SslInfoCallbackValue) + Send + Sync + 'static, +{ + // Due to FFI signature requirements we have to pass a *const SSL into this function, but + // foreign-types requires a *mut SSL to get the Rust SslRef + let mut_ref = ssl as *mut ffi::SSL; + + // SAFETY: boring provides valid inputs. + let ssl = unsafe { SslRef::from_ptr(mut_ref) }; + let ssl_context = ssl.ssl_context(); + + let callback = ssl_context + .ex_data(SslContext::cached_ex_index::()) + .expect("BUG: info callback missing"); + + let value = match mode { + ffi::SSL_CB_READ_ALERT | ffi::SSL_CB_WRITE_ALERT => { + SslInfoCallbackValue::Alert(SslInfoCallbackAlert(value)) + } + _ => SslInfoCallbackValue::Unit, + }; + + callback(ssl, SslInfoCallbackMode(mode), value); +} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 12ffedf5..adde108b 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -834,6 +834,66 @@ pub fn select_next_proto<'a>(server: &[u8], client: &'a [u8]) -> Option<&'a [u8] } } +/// Options controlling the behavior of the info callback. +#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, Hash)] +pub struct SslInfoCallbackMode(i32); + +impl SslInfoCallbackMode { + /// Signaled for each alert received, warning or fatal. + pub const READ_ALERT: Self = Self(ffi::SSL_CB_READ_ALERT); + + /// Signaled for each alert sent, warning or fatal. + pub const WRITE_ALERT: Self = Self(ffi::SSL_CB_WRITE_ALERT); + + /// Signaled when a handshake begins. + pub const HANDSHAKE_START: Self = Self(ffi::SSL_CB_HANDSHAKE_START); + + /// Signaled when a handshake completes successfully. + pub const HANDSHAKE_DONE: Self = Self(ffi::SSL_CB_HANDSHAKE_DONE); + + /// Signaled when a handshake progresses to a new state. + pub const ACCEPT_LOOP: Self = Self(ffi::SSL_CB_ACCEPT_LOOP); +} + +/// The `value` argument to an info callback. The most-significant byte is the alert level, while +/// the least significant byte is the alert itself. +#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord, Hash)] +pub enum SslInfoCallbackValue { + /// The unit value (1). Some BoringSSL info callback modes, like ACCEPT_LOOP, always call the + /// callback with `value` set to the unit value. If the [`SslInfoCallbackValue`] is a + /// `Unit`, it can safely be disregarded. + Unit, + /// An alert. See [`SslInfoCallbackAlert`] for details on how to manipulate the alert. This + /// variant should only be present if the info callback was called with a `READ_ALERT` or + /// `WRITE_ALERT` mode. + Alert(SslInfoCallbackAlert), +} + +#[derive(Hash, Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug)] +pub struct SslInfoCallbackAlert(c_int); + +impl SslInfoCallbackAlert { + /// The level of the SSL alert. + pub fn alert_level(&self) -> Ssl3AlertLevel { + let value = self.0 >> 8; + Ssl3AlertLevel(value) + } + + /// The value of the SSL alert. + pub fn alert(&self) -> SslAlert { + let value = self.0 & (u8::MAX as i32); + SslAlert(value) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct Ssl3AlertLevel(c_int); + +impl Ssl3AlertLevel { + pub const WARNING: Ssl3AlertLevel = Self(ffi::SSL3_AL_WARNING); + pub const FATAL: Ssl3AlertLevel = Self(ffi::SSL3_AL_FATAL); +} + #[cfg(feature = "rpk")] extern "C" fn rpk_verify_failure_callback( _ssl: *mut ffi::SSL, @@ -1820,6 +1880,18 @@ impl SslContextBuilder { unsafe { cvt_0i(ffi::SSL_CTX_set_compliance_policy(self.as_ptr(), policy.0)).map(|_| ()) } } + /// Sets the context's info callback. + #[corresponds(SSL_CTX_set_info_callback)] + pub fn set_info_callback(&mut self, callback: F) + where + F: Fn(&SslRef, SslInfoCallbackMode, SslInfoCallbackValue) + Send + Sync + 'static, + { + unsafe { + self.replace_ex_data(SslContext::cached_ex_index::(), callback); + ffi::SSL_CTX_set_info_callback(self.as_ptr(), Some(callbacks::raw_info_callback::)); + } + } + /// Consumes the builder, returning a new `SslContext`. pub fn build(self) -> SslContext { self.ctx diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 131b1127..f3b0fd29 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -1052,3 +1052,17 @@ fn drop_ex_data_in_ssl() { assert_eq!(ssl.replace_ex_data(index, "camembert"), Some("comté")); assert_eq!(ssl.replace_ex_data(index, "raclette"), Some("camembert")); } + +#[test] +fn test_info_callback() { + static CALLED_BACK: AtomicBool = AtomicBool::new(false); + + let server = Server::builder().build(); + let mut client = server.client(); + client.ctx().set_info_callback(move |_, _, _| { + CALLED_BACK.store(true, Ordering::Relaxed); + }); + + client.connect(); + assert!(CALLED_BACK.load(Ordering::Relaxed)); +}