Skip to content

Commit

Permalink
Introduce ArcCastPtr; add try macros for BoxCastPtr (rustls#157)
Browse files Browse the repository at this point in the history
Building on @djc's excellent addition of BoxCastPtr in rustls#88, this
introduces ArcCastPtr, which fills the same role for Arcs. Also, this
introduces null checking for both BoxCastPtr and ArcCastPtr, and adds
try macros for both of them.
  • Loading branch information
jsha authored Oct 19, 2021
1 parent 1509d75 commit 5fddc17
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 123 deletions.
22 changes: 16 additions & 6 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};

use crate::error::rustls_result;
use crate::rslice::rustls_slice_bytes;
use crate::{ffi_panic_boundary, try_mut_from_ptr, try_ref_from_ptr, try_slice, CastPtr};
use crate::{
ffi_panic_boundary, try_mut_from_ptr, try_ref_from_ptr, try_slice, ArcCastPtr, BoxCastPtr,
CastConstPtr, CastPtr,
};
use rustls_result::NullParameter;
use std::ops::Deref;

Expand Down Expand Up @@ -112,6 +115,8 @@ impl CastPtr for rustls_certified_key {
type RustType = CertifiedKey;
}

impl ArcCastPtr for rustls_certified_key {}

impl rustls_certified_key {
/// Build a `rustls_certified_key` from a certificate chain and a private key.
/// `cert_chain` must point to a buffer of `cert_chain_len` bytes, containing
Expand Down Expand Up @@ -199,7 +204,7 @@ impl rustls_certified_key {
} else {
new_key.ocsp = None;
}
*cloned_key_out = Arc::into_raw(Arc::new(new_key)) as *const _;
*cloned_key_out = ArcCastPtr::to_const_ptr(new_key);
return rustls_result::Ok
}
}
Expand Down Expand Up @@ -287,6 +292,8 @@ impl CastPtr for rustls_root_cert_store {
type RustType = RootCertStore;
}

impl BoxCastPtr for rustls_root_cert_store {}

impl rustls_root_cert_store {
/// Create a rustls_root_cert_store. Caller owns the memory and must
/// eventually call rustls_root_cert_store_free. The store starts out empty.
Expand All @@ -296,8 +303,7 @@ impl rustls_root_cert_store {
pub extern "C" fn rustls_root_cert_store_new() -> *mut rustls_root_cert_store {
ffi_panic_boundary! {
let store = rustls::RootCertStore::empty();
let s = Box::new(store);
Box::into_raw(s) as *mut _
BoxCastPtr::to_mut_ptr(store)
}
}

Expand Down Expand Up @@ -362,10 +368,12 @@ pub struct rustls_client_cert_verifier {
_private: [u8; 0],
}

impl CastPtr for rustls_client_cert_verifier {
impl CastConstPtr for rustls_client_cert_verifier {
type RustType = AllowAnyAuthenticatedClient;
}

impl ArcCastPtr for rustls_client_cert_verifier {}

impl rustls_client_cert_verifier {
/// Create a new client certificate verifier for the root store. The verifier
/// can be used in several rustls_server_config instances. Must be freed by
Expand Down Expand Up @@ -412,10 +420,12 @@ pub struct rustls_client_cert_verifier_optional {
_private: [u8; 0],
}

impl CastPtr for rustls_client_cert_verifier_optional {
impl CastConstPtr for rustls_client_cert_verifier_optional {
type RustType = AllowAnyAnonymousOrAuthenticatedClient;
}

impl ArcCastPtr for rustls_client_cert_verifier_optional {}

impl rustls_client_cert_verifier_optional {
/// Create a new rustls_client_cert_verifier_optional for the root store. The
/// verifier can be used in several rustls_server_config instances. Must be
Expand Down
89 changes: 40 additions & 49 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::error::{self, result_to_error, rustls_result};
use crate::rslice::NulByte;
use crate::rslice::{rustls_slice_bytes, rustls_slice_slice_bytes, rustls_str};
use crate::{
arc_with_incref_from_raw, ffi_panic_boundary, try_mut_from_ptr, try_ref_from_ptr, try_slice,
userdata_get, BoxCastPtr, CastPtr,
ffi_panic_boundary, try_arc_from_ptr, try_box_from_ptr, try_mut_from_ptr, try_ref_from_ptr,
try_slice, userdata_get, ArcCastPtr, BoxCastPtr, CastConstPtr, CastPtr,
};

/// A client config being constructed. A builder can be modified by,
Expand Down Expand Up @@ -64,10 +64,12 @@ pub struct rustls_client_config {
_private: [u8; 0],
}

impl CastPtr for rustls_client_config {
impl CastConstPtr for rustls_client_config {
type RustType = ClientConfig;
}

impl ArcCastPtr for rustls_client_config {}

impl rustls_client_config_builder {
/// Create a rustls_client_config_builder. Caller owns the memory and must
/// eventually call rustls_client_config_builder_build, then free the
Expand Down Expand Up @@ -284,7 +286,7 @@ impl rustls_client_config_builder {
None => return rustls_result::InvalidParameter,
};

let new = *BoxCastPtr::to_box(wants_verifier);
let new = *try_box_from_ptr!(wants_verifier);
let verifier: Verifier = Verifier{callback: callback};
// TODO: no client authentication support for now
let config = new.with_custom_certificate_verifier(Arc::new(verifier)).with_no_client_auth();
Expand All @@ -307,7 +309,7 @@ impl rustls_client_config_builder {
) -> rustls_result {
ffi_panic_boundary! {
let root_store: &RootCertStore = try_ref_from_ptr!(roots);
let prev = *BoxCastPtr::to_box(wants_verifier);
let prev = *try_box_from_ptr!(wants_verifier);
let config = prev.with_root_certificates(root_store.clone()).with_no_client_auth();
BoxCastPtr::set_mut_ptr(builder, config);
rustls_result::Ok
Expand All @@ -323,12 +325,13 @@ impl rustls_client_config_builder {
builder: *mut *mut rustls_client_config_builder,
) -> rustls_result {
ffi_panic_boundary! {
let filename: &CStr = unsafe {
if filename.is_null() {
return rustls_result::NullParameter;
}
CStr::from_ptr(filename)
};
let prev = *try_box_from_ptr!(wants_verifier);
let filename: &CStr = unsafe {
if filename.is_null() {
return rustls_result::NullParameter;
}
CStr::from_ptr(filename)
};

let filename: &[u8] = filename.to_bytes();
let filename: &str = match std::str::from_utf8(filename) {
Expand All @@ -353,7 +356,6 @@ impl rustls_client_config_builder {
return rustls_result::CertificateParseError;
}

let prev = *BoxCastPtr::to_box(wants_verifier);
// TODO: no client authentication support for now
let config = prev.with_root_certificates(roots).with_no_client_auth();
BoxCastPtr::set_mut_ptr(builder, config);
Expand Down Expand Up @@ -443,13 +445,7 @@ impl rustls_client_config_builder {
let keys_ptrs: &[*const rustls_certified_key] = try_slice!(certified_keys, certified_keys_len);
let mut keys: Vec<Arc<CertifiedKey>> = Vec::new();
for &key_ptr in keys_ptrs {
let key_ptr: &CertifiedKey = try_ref_from_ptr!(key_ptr);
let certified_key: Arc<CertifiedKey> = unsafe {
match (key_ptr as *const CertifiedKey).as_ref() {
Some(c) => arc_with_incref_from_raw(c),
None => return NullParameter,
}
};
let certified_key: Arc<CertifiedKey> = try_arc_from_ptr!(key_ptr);
keys.push(certified_key);
}
config.client_auth_cert_resolver = Arc::new(ResolvesClientCertFromChoices { keys });
Expand Down Expand Up @@ -490,8 +486,8 @@ impl rustls_client_config_builder {
builder: *mut rustls_client_config_builder,
) -> *const rustls_client_config {
ffi_panic_boundary! {
let b = BoxCastPtr::to_box(builder);
Arc::into_raw(Arc::new(*b)) as *const _
let b = try_box_from_ptr!(builder);
ArcCastPtr::to_const_ptr(*b)
}
}

Expand Down Expand Up @@ -539,34 +535,29 @@ impl rustls_client_config {
conn_out: *mut *mut rustls_connection,
) -> rustls_result {
ffi_panic_boundary! {
let hostname: &CStr = unsafe {
if hostname.is_null() {
return NullParameter;
}
CStr::from_ptr(hostname)
};
let config: Arc<ClientConfig> = unsafe {
match (config as *const ClientConfig).as_ref() {
Some(c) => arc_with_incref_from_raw(c),
None => return NullParameter,
}
};
let hostname: &str = match hostname.to_str() {
Ok(s) => s,
Err(std::str::Utf8Error { .. }) => return rustls_result::InvalidDnsNameError,
};
let server_name: rustls::ServerName = match hostname.try_into() {
Ok(sn) => sn,
Err(_) => return rustls_result::InvalidDnsNameError,
};
let client = ClientConnection::new(config, server_name).unwrap();

// We've succeeded. Put the client on the heap, and transfer ownership
// to the caller. After this point, we must return CRUSTLS_OK so the
// caller knows it is responsible for this memory.
let c = Connection::from_client(client);
BoxCastPtr::set_mut_ptr(conn_out, c);
rustls_result::Ok
let hostname: &CStr = unsafe {
if hostname.is_null() {
return NullParameter;
}
CStr::from_ptr(hostname)
};
let config: Arc<ClientConfig> = try_arc_from_ptr!(config);
let hostname: &str = match hostname.to_str() {
Ok(s) => s,
Err(std::str::Utf8Error { .. }) => return rustls_result::InvalidDnsNameError,
};
let server_name: rustls::ServerName = match hostname.try_into() {
Ok(sn) => sn,
Err(_) => return rustls_result::InvalidDnsNameError,
};
let client = ClientConnection::new(config, server_name).unwrap();

// We've succeeded. Put the client on the heap, and transfer ownership
// to the caller. After this point, we must return CRUSTLS_OK so the
// caller knows it is responsible for this memory.
let c = Connection::from_client(client);
BoxCastPtr::set_mut_ptr(conn_out, c);
rustls_result::Ok
}
}
}
119 changes: 93 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,40 @@ const RUSTLS_CRATE_VERSION: &str = "0.20.0";
pub(crate) trait CastPtr {
type RustType;

fn cast_mut_ptr(ptr: *mut Self) -> *mut Self::RustType {
ptr as *mut _
}
}

/// CastConstPtr represents a subset of CastPtr, for when we can only treat
/// something as a const (for instance when dealing with Arc).
pub(crate) trait CastConstPtr {
type RustType;

fn cast_const_ptr(ptr: *const Self) -> *const Self::RustType {
ptr as *const _
}
}

fn cast_mut_ptr(ptr: *mut Self) -> *mut Self::RustType {
ptr as *mut _
}
/// Anything that qualifies for CastPtr also automatically qualifies for
/// CastConstPtr. Splitting out CastPtr vs CastConstPtr allows us to ensure
/// that Arcs are never cast to a mutable pointer.
impl<T, R> CastConstPtr for T
where
T: CastPtr<RustType = R>,
{
type RustType = R;
}

// An implementation of BoxCastPtr means that when we give C code a pointer to the relevant type,
// it is actually a Box.
pub(crate) trait BoxCastPtr: CastPtr + Sized {
fn to_box(ptr: *mut Self) -> Box<Self::RustType> {
fn to_box(ptr: *mut Self) -> Option<Box<Self::RustType>> {
if ptr.is_null() {
return None;
}
let rs_typed = Self::cast_mut_ptr(ptr);
unsafe { Box::from_raw(rs_typed) }
unsafe { Some(Box::from_raw(rs_typed)) }
}

fn to_mut_ptr(src: Self::RustType) -> *mut Self {
Expand All @@ -313,6 +334,38 @@ pub(crate) trait BoxCastPtr: CastPtr + Sized {
}
}

// An implementation of ArcCastPtr means that when we give C code a pointer to the relevant type,
// it is actually a Arc.
pub(crate) trait ArcCastPtr: CastConstPtr + Sized {
/// Sometimes we create an Arc, then call `into_raw` and return the resulting raw pointer
/// to C. C can then call rustls_server_session_new multiple times using that
/// same raw pointer. On each call, we need to reconstruct the Arc. But once we reconstruct the Arc,
/// its reference count will be decremented on drop. We need to reference count to stay at 1,
/// because the C code is holding a copy. This function turns the raw pointer back into an Arc,
/// clones it to increment the reference count (which will make it 2 in this particular case), and
/// mem::forgets the clone. The mem::forget prevents the reference count from being decremented when
/// we exit this function, so it will stay at 2 as long as we are in Rust code. Once the caller
/// drops its Arc, the reference count will go back down to 1, indicating the C code's copy.
///
/// Unsafety:
///
/// v must be a non-null pointer that resulted from previously calling `Arc::into_raw`.
fn to_arc(ptr: *const Self) -> Option<Arc<Self::RustType>> {
if ptr.is_null() {
return None;
}
let rs_typed = Self::cast_const_ptr(ptr);
let r = unsafe { Arc::from_raw(rs_typed) };
let val = Arc::clone(&r);
mem::forget(r);
Some(val)
}

fn to_const_ptr(src: Self::RustType) -> *const Self {
Arc::into_raw(Arc::new(src)) as *const _
}
}

#[doc(hidden)]
#[macro_export]
macro_rules! try_slice {
Expand Down Expand Up @@ -343,7 +396,7 @@ macro_rules! try_mut_slice {
/// against T (the cast-to type) rather than across F (the from type).
pub(crate) fn try_from<'a, F, T>(from: *const F) -> Option<&'a T>
where
F: CastPtr<RustType = T>,
F: CastConstPtr<RustType = T>,
{
unsafe { F::cast_const_ptr(from).as_ref() }
}
Expand All @@ -356,6 +409,20 @@ where
unsafe { F::cast_mut_ptr(from).as_mut() }
}

pub(crate) fn try_box_from<'a, F, T>(from: *mut F) -> Option<Box<T>>
where
F: BoxCastPtr<RustType = T>,
{
F::to_box(from)
}

pub(crate) fn try_arc_from<'a, F, T>(from: *const F) -> Option<Arc<T>>
where
F: ArcCastPtr<RustType = T>,
{
F::to_arc(from)
}

impl CastPtr for size_t {
type RustType = size_t;
}
Expand Down Expand Up @@ -392,6 +459,26 @@ macro_rules! try_mut_from_ptr {
}

#[doc(hidden)]
#[macro_export]
macro_rules! try_box_from_ptr {
( $var:ident ) => {
match crate::try_box_from($var) {
Some(c) => c,
None => return crate::panic::NullParameterOrDefault::value(),
}
};
}

#[macro_export]
macro_rules! try_arc_from_ptr {
( $var:ident ) => {
match crate::try_arc_from($var) {
Some(c) => c,
None => return crate::panic::NullParameterOrDefault::value(),
}
};
}

#[macro_export]
macro_rules! try_callback {
( $var:ident ) => {
Expand Down Expand Up @@ -425,23 +512,3 @@ pub extern "C" fn rustls_version(buf: *mut c_char, len: size_t) -> size_t {
len
}
}

/// Sometimes we create an Arc, then call `into_raw` and return the resulting raw pointer
/// to C. C can then call back into rustls multiple times using that same raw pointer.
/// On each call, we need to reconstruct the Arc. But once we reconstruct the Arc,
/// its reference count will be decremented on drop. We need to reference count to stay at 1,
/// because the C code is holding a copy. This function turns the raw pointer back into an Arc,
/// clones it to increment the reference count (which will make it 2 in this particular case), and
/// mem::forgets the clone. The mem::forget prevents the reference count from being decremented when
/// we exit this function, so it will stay at 2 as long as we are in Rust code. Once the caller
/// drops its Arc, the reference count will go back down to 1, indicating the C code's copy.
///
/// Unsafety:
///
/// v must be a non-null pointer that resulted from previously calling `Arc::into_raw`.
unsafe fn arc_with_incref_from_raw<T>(v: *const T) -> Arc<T> {
let r = Arc::from_raw(v);
let val = Arc::clone(&r);
mem::forget(r);
val
}
Loading

0 comments on commit 5fddc17

Please sign in to comment.