Skip to content

Commit

Permalink
Allow returning GetSessionPendingError from get session callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
nox committed Oct 24, 2023
1 parent 2d04d9a commit ea0a1b1
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 15 deletions.
22 changes: 13 additions & 9 deletions boring/src/ssl/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![forbid(unsafe_op_in_unsafe_fn)]

use super::{
AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError,
Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
SslSignatureAlgorithm, SESSION_CTX_INDEX,
AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError,
SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession,
SslSessionRef, SslSignatureAlgorithm, SESSION_CTX_INDEX,
};
use crate::error::ErrorStack;
use crate::ffi;
Expand All @@ -13,7 +13,6 @@ use foreign_types::ForeignTypeRef;
use libc::c_char;
use libc::{c_int, c_uchar, c_uint, c_void};
use std::ffi::CStr;
use std::mem;
use std::ptr;
use std::slice;
use std::str;
Expand Down Expand Up @@ -323,7 +322,10 @@ pub(super) unsafe extern "C" fn raw_get_session<F>(
copy: *mut c_int,
) -> *mut ffi::SSL_SESSION
where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{
// SAFETY: boring provides valid inputs.
let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
Expand All @@ -342,13 +344,15 @@ where
let callback = unsafe { &*(callback as *const F) };

match callback(ssl, data) {
Some(session) => {
let p = session.as_ptr();
mem::forget(session);
Ok(Some(session)) => {
let p = session.into_ptr();

*copy = 0;

p
}
None => ptr::null_mut(),
Ok(None) => ptr::null_mut(),
Err(GetSessionPendingError) => unsafe { ffi::SSL_magic_pending_session_ptr() },
}
}

Expand Down
14 changes: 12 additions & 2 deletions boring/src/ssl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1599,12 +1599,15 @@ impl SslContextBuilder {
///
/// # Safety
///
/// The returned `SslSession` must not be associated with a different `SslContext`.
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
///
/// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html
pub unsafe fn set_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{
self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::<F>));
Expand Down Expand Up @@ -1978,6 +1981,13 @@ impl SslContextRef {
}
}

/// Error returned by the callback to get a session when operation
/// could not complete and should be retried later.
///
/// See [`SslContextBuilder::set_get_session_callback`].
#[derive(Debug)]
pub struct GetSessionPendingError;

#[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))]
type ProtosLen = usize;
#[cfg(any(feature = "fips", feature = "fips-link-precompiled"))]
Expand Down
57 changes: 53 additions & 4 deletions boring/src/ssl/test/session.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;

use crate::ssl::test::server::Server;
use crate::ssl::{
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode,
SslVersion,
ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder,
SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion,
};

#[test]
Expand Down Expand Up @@ -53,7 +54,7 @@ fn new_get_session_callback() {
unsafe {
server.ctx().set_get_session_callback(|_, id| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
return Ok(None);
};

let session = SslSession::from_der(der).unwrap();
Expand All @@ -62,7 +63,7 @@ fn new_get_session_callback() {

assert_eq!(id, session.id());

Some(session)
Ok(Some(session))
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
Expand Down Expand Up @@ -100,6 +101,54 @@ fn new_get_session_callback() {
assert!(FOUND_SESSION.load(Ordering::SeqCst));
}

#[test]
fn new_get_session_callback_pending() {
static CALLED_SERVER_CALLBACK: AtomicBool = AtomicBool::new(false);

let mut server = Server::builder();

server
.ctx()
.set_max_proto_version(Some(SslVersion::TLS1_2))
.unwrap();
server.ctx().set_options(SslOptions::NO_TICKET);
server
.ctx()
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
unsafe {
server.ctx().set_get_session_callback(|_, _| {
if !CALLED_SERVER_CALLBACK.swap(true, Ordering::SeqCst) {
return Err(GetSessionPendingError);
}

Ok(None)
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
server.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};

assert!(mid_handshake.error().would_block());
assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION);

let mut socket = mid_handshake.handshake().unwrap();

socket.write_all(&[0]).unwrap();
});

let server = server.build();

let mut client = server.client();

client
.ctx()
.set_session_cache_mode(SslSessionCacheMode::CLIENT);

client.connect();
}

#[test]
fn new_session_callback_swapped_ctx() {
static CALLED_BACK: AtomicBool = AtomicBool::new(false);
Expand Down

0 comments on commit ea0a1b1

Please sign in to comment.