diff --git a/rustls-libssl/MATRIX.md b/rustls-libssl/MATRIX.md index 6d9dbf8..353c573 100644 --- a/rustls-libssl/MATRIX.md +++ b/rustls-libssl/MATRIX.md @@ -67,7 +67,7 @@ | `SSL_CTX_add_custom_ext` | | | | | `SSL_CTX_add_server_custom_ext` | | | | | `SSL_CTX_add_session` | | | | -| `SSL_CTX_callback_ctrl` | | :white_check_mark: | | +| `SSL_CTX_callback_ctrl` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_check_private_key` | :white_check_mark: | | :white_check_mark: | | `SSL_CTX_clear_options` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_config` | | | | @@ -88,14 +88,14 @@ | `SSL_CTX_get0_security_ex_data` | | | | | `SSL_CTX_get_cert_store` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_ciphers` | | | | -| `SSL_CTX_get_client_CA_list` | | :white_check_mark: | | +| `SSL_CTX_get_client_CA_list` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_CTX_get_client_cert_cb` | | | | | `SSL_CTX_get_default_passwd_cb` | | | | | `SSL_CTX_get_default_passwd_cb_userdata` | | | | -| `SSL_CTX_get_ex_data` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_get_ex_data` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_info_callback` | | | | | `SSL_CTX_get_keylog_callback` | | | | -| `SSL_CTX_get_max_early_data` | | :white_check_mark: | | +| `SSL_CTX_get_max_early_data` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_num_tickets` | | | | | `SSL_CTX_get_options` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_quiet_shutdown` | | | | @@ -104,14 +104,14 @@ | `SSL_CTX_get_security_callback` | | | | | `SSL_CTX_get_security_level` | | | | | `SSL_CTX_get_ssl_method` | | | | -| `SSL_CTX_get_timeout` | | :white_check_mark: | | -| `SSL_CTX_get_verify_callback` | | :white_check_mark: | | -| `SSL_CTX_get_verify_depth` | | :white_check_mark: | | -| `SSL_CTX_get_verify_mode` | | :white_check_mark: | | +| `SSL_CTX_get_timeout` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_get_verify_callback` | | :white_check_mark: | :white_check_mark: | +| `SSL_CTX_get_verify_depth` | | :white_check_mark: | :white_check_mark: | +| `SSL_CTX_get_verify_mode` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_has_client_custom_ext` | | | | | `SSL_CTX_load_verify_dir` | :white_check_mark: | | :white_check_mark: | | `SSL_CTX_load_verify_file` | :white_check_mark: | | :white_check_mark: | -| `SSL_CTX_load_verify_locations` | | :white_check_mark: | | +| `SSL_CTX_load_verify_locations` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_load_verify_store` | | | | | `SSL_CTX_new` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_CTX_new_ex` | | | | @@ -131,16 +131,16 @@ | `SSL_CTX_set1_param` | | | | | `SSL_CTX_set_allow_early_data_cb` | | | | | `SSL_CTX_set_alpn_protos` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_CTX_set_alpn_select_cb` | | :white_check_mark: | | +| `SSL_CTX_set_alpn_select_cb` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_async_callback` | | | | | `SSL_CTX_set_async_callback_arg` | | | | | `SSL_CTX_set_block_padding` | | | | -| `SSL_CTX_set_cert_cb` | | :white_check_mark: | | +| `SSL_CTX_set_cert_cb` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_cert_store` | | | | | `SSL_CTX_set_cert_verify_callback` | | | | -| `SSL_CTX_set_cipher_list` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_set_cipher_list` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_ciphersuites` | :white_check_mark: | | :exclamation: [^stub] | -| `SSL_CTX_set_client_CA_list` | | :white_check_mark: | | +| `SSL_CTX_set_client_CA_list` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_CTX_set_client_cert_cb` | | | | | `SSL_CTX_set_client_cert_engine` [^engine] | | | | | `SSL_CTX_set_client_hello_cb` | | | | @@ -156,14 +156,14 @@ | `SSL_CTX_set_default_verify_file` | | | :white_check_mark: | | `SSL_CTX_set_default_verify_paths` | | | :white_check_mark: | | `SSL_CTX_set_default_verify_store` | | | :exclamation: [^stub] | -| `SSL_CTX_set_ex_data` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_set_ex_data` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_generate_session_id` | | | | -| `SSL_CTX_set_info_callback` | | :white_check_mark: | | +| `SSL_CTX_set_info_callback` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_CTX_set_keylog_callback` | :white_check_mark: | | :exclamation: [^stub] | -| `SSL_CTX_set_max_early_data` | | :white_check_mark: | | +| `SSL_CTX_set_max_early_data` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_msg_callback` | :white_check_mark: | | :exclamation: [^stub] | | `SSL_CTX_set_next_proto_select_cb` [^nextprotoneg] | :white_check_mark: | | :exclamation: [^stub] | -| `SSL_CTX_set_next_protos_advertised_cb` [^nextprotoneg] | | :white_check_mark: | | +| `SSL_CTX_set_next_protos_advertised_cb` [^nextprotoneg] | | :white_check_mark: | :exclamation: [^stub] | | `SSL_CTX_set_not_resumable_session_callback` | | | | | `SSL_CTX_set_num_tickets` | | | | | `SSL_CTX_set_options` | :white_check_mark: | :white_check_mark: | :white_check_mark: | @@ -179,7 +179,7 @@ | `SSL_CTX_set_recv_max_early_data` | | | | | `SSL_CTX_set_security_callback` | | | | | `SSL_CTX_set_security_level` | | | | -| `SSL_CTX_set_session_id_context` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_set_session_id_context` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_session_ticket_cb` | | | | | `SSL_CTX_set_srp_cb_arg` [^deprecatedin_3_0] [^srp] | | | | | `SSL_CTX_set_srp_client_pwd_callback` [^deprecatedin_3_0] [^srp] | | | | @@ -191,14 +191,14 @@ | `SSL_CTX_set_ssl_version` [^deprecatedin_3_0] | | | | | `SSL_CTX_set_stateless_cookie_generate_cb` | | | | | `SSL_CTX_set_stateless_cookie_verify_cb` | | | | -| `SSL_CTX_set_timeout` | | :white_check_mark: | | +| `SSL_CTX_set_timeout` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_CTX_set_tlsext_max_fragment_length` | | | | | `SSL_CTX_set_tlsext_ticket_key_evp_cb` | | | | | `SSL_CTX_set_tlsext_use_srtp` [^srtp] | | | | | `SSL_CTX_set_tmp_dh_callback` [^deprecatedin_3_0] [^dh] | | | | | `SSL_CTX_set_trust` | | | | | `SSL_CTX_set_verify` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_CTX_set_verify_depth` | | :white_check_mark: | | +| `SSL_CTX_set_verify_depth` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_up_ref` | | | :white_check_mark: | | `SSL_CTX_use_PrivateKey` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_CTX_use_PrivateKey_ASN1` | | | | @@ -272,7 +272,7 @@ | `SSL_callback_ctrl` | | | | | `SSL_certs_clear` | | | | | `SSL_check_chain` | | | | -| `SSL_check_private_key` | | | | +| `SSL_check_private_key` | | | :white_check_mark: | | `SSL_clear` | | | | | `SSL_clear_options` | | :white_check_mark: | :white_check_mark: | | `SSL_client_hello_get0_ciphers` | | | | @@ -293,7 +293,7 @@ | `SSL_dane_enable` | | | | | `SSL_dane_set_flags` | | | | | `SSL_dane_tlsa_add` | | | | -| `SSL_do_handshake` | | :white_check_mark: | | +| `SSL_do_handshake` | | :white_check_mark: | :white_check_mark: | | `SSL_dup` | | | | | `SSL_dup_CA_list` | | | | | `SSL_enable_ct` [^ct] | | | | @@ -307,7 +307,7 @@ | `SSL_get0_dane` | | | | | `SSL_get0_dane_authority` | | | | | `SSL_get0_dane_tlsa` | | | | -| `SSL_get0_next_proto_negotiated` [^nextprotoneg] | | :white_check_mark: | | +| `SSL_get0_next_proto_negotiated` [^nextprotoneg] | | :white_check_mark: | :exclamation: [^stub] | | `SSL_get0_param` | | | | | `SSL_get0_peer_CA_list` | | | | | `SSL_get0_peer_certificate` | | | :white_check_mark: | @@ -329,15 +329,15 @@ | `SSL_get_client_ciphers` | | | | | `SSL_get_client_random` | | | | | `SSL_get_current_cipher` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_get_current_compression` | | | | +| `SSL_get_current_compression` | | | :white_check_mark: | | `SSL_get_current_expansion` | | | | | `SSL_get_default_passwd_cb` | | | | | `SSL_get_default_passwd_cb_userdata` | | | | | `SSL_get_default_timeout` | | | | | `SSL_get_early_data_status` | | | | | `SSL_get_error` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_get_ex_data` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | -| `SSL_get_ex_data_X509_STORE_CTX_idx` | | :white_check_mark: | | +| `SSL_get_ex_data` | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| `SSL_get_ex_data_X509_STORE_CTX_idx` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_get_fd` | | | | | `SSL_get_finished` | | | | | `SSL_get_info_callback` | | | | @@ -362,8 +362,8 @@ | `SSL_get_security_level` | | | | | `SSL_get_selected_srtp_profile` [^srtp] | | | | | `SSL_get_server_random` | | | | -| `SSL_get_servername` | | :white_check_mark: | | -| `SSL_get_servername_type` | | | | +| `SSL_get_servername` | | :white_check_mark: | :white_check_mark: | +| `SSL_get_servername_type` | | | :white_check_mark: | | `SSL_get_session` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_get_shared_ciphers` | | | | | `SSL_get_shared_sigalgs` | | | | @@ -378,7 +378,7 @@ | `SSL_get_ssl_method` | | | | | `SSL_get_state` | | | :white_check_mark: | | `SSL_get_verify_callback` | | | | -| `SSL_get_verify_depth` | | | | +| `SSL_get_verify_depth` | | | :white_check_mark: | | `SSL_get_verify_mode` | | | | | `SSL_get_verify_result` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_get_version` | :white_check_mark: | :white_check_mark: | :white_check_mark: | @@ -393,7 +393,7 @@ | `SSL_is_init_finished` | | :white_check_mark: | :white_check_mark: | | `SSL_is_server` | | | :white_check_mark: | | `SSL_key_update` | | | | -| `SSL_load_client_CA_file` | | :white_check_mark: | | +| `SSL_load_client_CA_file` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_load_client_CA_file_ex` | | | | | `SSL_new` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_new_session_ticket` | | | | @@ -401,16 +401,16 @@ | `SSL_peek_ex` | | | | | `SSL_pending` | :white_check_mark: | | :white_check_mark: | | `SSL_read` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_read_early_data` | | :white_check_mark: | | +| `SSL_read_early_data` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_read_ex` | | | | | `SSL_renegotiate` | | | | | `SSL_renegotiate_abbreviated` | | | | | `SSL_renegotiate_pending` | | | | | `SSL_rstate_string` | | | | | `SSL_rstate_string_long` | | | | -| `SSL_select_next_proto` | | :white_check_mark: | | +| `SSL_select_next_proto` | | :white_check_mark: | :white_check_mark: | | `SSL_sendfile` | | | | -| `SSL_session_reused` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_session_reused` | | :white_check_mark: | :white_check_mark: | | `SSL_set0_CA_list` | | | | | `SSL_set0_rbio` | | | :white_check_mark: | | `SSL_set0_security_ex_data` | | | | @@ -436,7 +436,7 @@ | `SSL_set_default_passwd_cb` | | | | | `SSL_set_default_passwd_cb_userdata` | | | | | `SSL_set_default_read_buffer_len` | | | | -| `SSL_set_ex_data` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | +| `SSL_set_ex_data` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_set_fd` [^sock] | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_set_generate_session_id` | | | | | `SSL_set_hostflags` | | | | @@ -452,7 +452,7 @@ | `SSL_set_psk_server_callback` [^psk] | | | | | `SSL_set_psk_use_session_callback` | | | | | `SSL_set_purpose` | | | | -| `SSL_set_quiet_shutdown` | | :white_check_mark: | | +| `SSL_set_quiet_shutdown` | | :white_check_mark: | :white_check_mark: | | `SSL_set_read_ahead` | | | | | `SSL_set_record_padding_callback` | | | | | `SSL_set_record_padding_callback_arg` | | | | @@ -473,8 +473,8 @@ | `SSL_set_tlsext_use_srtp` [^srtp] | | | | | `SSL_set_tmp_dh_callback` [^deprecatedin_3_0] [^dh] | | | | | `SSL_set_trust` | | | | -| `SSL_set_verify` | | :white_check_mark: | | -| `SSL_set_verify_depth` | | :white_check_mark: | | +| `SSL_set_verify` | | :white_check_mark: | :white_check_mark: | +| `SSL_set_verify_depth` | | :white_check_mark: | :white_check_mark: | | `SSL_set_verify_result` | | | | | `SSL_set_wfd` [^sock] | | | | | `SSL_shutdown` | :white_check_mark: | :white_check_mark: | :white_check_mark: | @@ -487,7 +487,7 @@ | `SSL_up_ref` | | | :white_check_mark: | | `SSL_use_PrivateKey` | | :white_check_mark: | :white_check_mark: | | `SSL_use_PrivateKey_ASN1` | | | | -| `SSL_use_PrivateKey_file` | | | | +| `SSL_use_PrivateKey_file` | | | :white_check_mark: | | `SSL_use_RSAPrivateKey` [^deprecatedin_3_0] | | | | | `SSL_use_RSAPrivateKey_ASN1` [^deprecatedin_3_0] | | | | | `SSL_use_RSAPrivateKey_file` [^deprecatedin_3_0] | | | | @@ -502,7 +502,7 @@ | `SSL_waiting_for_async` | | | | | `SSL_want` | | | :white_check_mark: | | `SSL_write` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_write_early_data` | | :white_check_mark: | | +| `SSL_write_early_data` | | :white_check_mark: | :exclamation: [^stub] | | `SSL_write_ex` | | | | | `SSLv3_client_method` [^deprecatedin_1_1_0] [^ssl3_method] | | | | | `SSLv3_method` [^deprecatedin_1_1_0] [^ssl3_method] | | | | diff --git a/rustls-libssl/build.rs b/rustls-libssl/build.rs index 905d4e5..34fb6d2 100644 --- a/rustls-libssl/build.rs +++ b/rustls-libssl/build.rs @@ -49,6 +49,7 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_accept", "SSL_alert_desc_string", "SSL_alert_desc_string_long", + "SSL_check_private_key", "SSL_CIPHER_description", "SSL_CIPHER_find", "SSL_CIPHER_get_bits", @@ -61,6 +62,7 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_connect", "SSL_ctrl", "SSL_CTX_add_client_CA", + "SSL_CTX_callback_ctrl", "SSL_CTX_check_private_key", "SSL_CTX_clear_options", "SSL_CTX_ctrl", @@ -68,18 +70,28 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_CTX_get0_certificate", "SSL_CTX_get0_privatekey", "SSL_CTX_get_cert_store", + "SSL_CTX_get_client_CA_list", "SSL_CTX_get_ex_data", + "SSL_CTX_get_max_early_data", "SSL_CTX_get_options", + "SSL_CTX_get_timeout", + "SSL_CTX_get_verify_callback", + "SSL_CTX_get_verify_depth", + "SSL_CTX_get_verify_mode", "SSL_CTX_load_verify_dir", "SSL_CTX_load_verify_file", + "SSL_CTX_load_verify_locations", "SSL_CTX_new", "SSL_CTX_remove_session", "SSL_CTX_sess_set_get_cb", "SSL_CTX_sess_set_new_cb", "SSL_CTX_sess_set_remove_cb", "SSL_CTX_set_alpn_protos", + "SSL_CTX_set_alpn_select_cb", + "SSL_CTX_set_cert_cb", "SSL_CTX_set_cipher_list", "SSL_CTX_set_ciphersuites", + "SSL_CTX_set_client_CA_list", "SSL_CTX_set_default_passwd_cb", "SSL_CTX_set_default_passwd_cb_userdata", "SSL_CTX_set_default_verify_dir", @@ -87,38 +99,50 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_CTX_set_default_verify_paths", "SSL_CTX_set_default_verify_store", "SSL_CTX_set_ex_data", + "SSL_CTX_set_info_callback", "SSL_CTX_set_keylog_callback", + "SSL_CTX_set_max_early_data", "SSL_CTX_set_msg_callback", "SSL_CTX_set_next_proto_select_cb", + "SSL_CTX_set_next_protos_advertised_cb", "SSL_CTX_set_options", "SSL_CTX_set_post_handshake_auth", "SSL_CTX_set_session_id_context", "SSL_CTX_set_srp_password", "SSL_CTX_set_srp_username", + "SSL_CTX_set_timeout", "SSL_CTX_set_verify", + "SSL_CTX_set_verify_depth", "SSL_CTX_up_ref", "SSL_CTX_use_certificate", "SSL_CTX_use_certificate_chain_file", "SSL_CTX_use_certificate_file", "SSL_CTX_use_PrivateKey", "SSL_CTX_use_PrivateKey_file", + "SSL_do_handshake", "SSL_free", "SSL_get0_alpn_selected", + "SSL_get0_next_proto_negotiated", "SSL_get0_peer_certificate", "SSL_get0_verified_chain", "SSL_get1_peer_certificate", "SSL_get1_session", "SSL_get_certificate", "SSL_get_current_cipher", + "SSL_get_current_compression", "SSL_get_error", "SSL_get_ex_data", + "SSL_get_ex_data_X509_STORE_CTX_idx", "SSL_get_options", "SSL_get_peer_cert_chain", "SSL_get_privatekey", "SSL_get_rbio", + "SSL_get_servername", + "SSL_get_servername_type", "SSL_get_session", "SSL_get_shutdown", "SSL_get_state", + "SSL_get_verify_depth", "SSL_get_verify_result", "SSL_get_version", "SSL_get_wbio", @@ -127,9 +151,12 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_in_init", "SSL_is_init_finished", "SSL_is_server", + "SSL_load_client_CA_file", "SSL_new", "SSL_pending", "SSL_read", + "SSL_read_early_data", + "SSL_select_next_proto", "SSL_SESSION_free", "SSL_SESSION_get_id", "SSL_session_reused", @@ -145,16 +172,21 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_set_fd", "SSL_set_options", "SSL_set_post_handshake_auth", + "SSL_set_quiet_shutdown", "SSL_set_session", "SSL_set_session_id_context", "SSL_set_shutdown", "SSL_set_SSL_CTX", + "SSL_set_verify", + "SSL_set_verify_depth", "SSL_shutdown", "SSL_up_ref", "SSL_use_certificate", "SSL_use_PrivateKey", + "SSL_use_PrivateKey_file", "SSL_want", "SSL_write", + "SSL_write_early_data", "TLS_client_method", "TLS_method", "TLS_server_method", diff --git a/rustls-libssl/src/callbacks.rs b/rustls-libssl/src/callbacks.rs new file mode 100644 index 0000000..05f0b62 --- /dev/null +++ b/rustls-libssl/src/callbacks.rs @@ -0,0 +1,181 @@ +use core::cell::RefCell; +use core::ffi::{c_int, c_uchar, c_void}; +use core::{ptr, slice}; + +use openssl_sys::{SSL_TLSEXT_ERR_NOACK, SSL_TLSEXT_ERR_OK}; +use rustls::AlertDescription; + +use crate::entry::{ + SSL_CTX_alpn_select_cb_func, SSL_CTX_cert_cb_func, SSL_CTX_servername_callback_func, SSL, +}; +use crate::error::Error; + +/// Smuggling SSL* pointers from the outer entrypoint into the +/// callback call site. +pub struct SslCallbackContext; + +impl SslCallbackContext { + /// Register the original SSL* pointer for use in later callbacks. + /// + /// The returned object de-registers itself when dropped. + pub fn new(ssl: *mut SSL) -> Self { + SSL_CALLBACK_CONTEXT.set(Some(ssl)); + Self + } + + /// Get the original SSL* pointer, or else `NULL` + /// + /// This has thread-local semantics: it uses the most recent + /// object of this type created on this thread. + pub fn ssl_ptr() -> *mut SSL { + SSL_CALLBACK_CONTEXT.with_borrow(|holder| { + holder + .as_ref() + .map(|inner| *inner) + .unwrap_or_else(ptr::null_mut) + }) + } +} + +impl Drop for SslCallbackContext { + fn drop(&mut self) { + SSL_CALLBACK_CONTEXT.set(None); + } +} + +thread_local! { + static SSL_CALLBACK_CONTEXT: RefCell> = const { RefCell::new(None) }; +} + +/// Configuration needed to call [`invoke_alpn_callback`] later +#[derive(Debug, Clone)] +pub struct AlpnCallbackConfig { + pub cb: SSL_CTX_alpn_select_cb_func, + pub context: *mut c_void, +} + +impl AlpnCallbackConfig { + /// Call a `SSL_CTX_alpn_select_cb_func` callback + /// + /// Returns the selected ALPN, or None, or an error. + pub fn invoke(&self, offer: &[u8]) -> Result>, Error> { + let callback = match self.cb { + Some(callback) => callback, + None => { + return Ok(None); + } + }; + + let ssl = SslCallbackContext::ssl_ptr(); + + let mut output_ptr: *const c_uchar = ptr::null(); + let mut output_len = 0u8; + let result = unsafe { + callback( + ssl, + &mut output_ptr as *mut *const c_uchar, + &mut output_len as *mut u8, + offer.as_ptr(), + offer.len() as u32, + self.context, + ) + }; + + if result == SSL_TLSEXT_ERR_OK && !output_ptr.is_null() { + let chosen = unsafe { slice::from_raw_parts(output_ptr, output_len as usize) }; + Ok(Some(chosen.to_vec())) + } else if result == SSL_TLSEXT_ERR_NOACK { + Ok(None) + } else { + Err(Error::bad_data("alpn not chosen")) + } + } +} + +impl Default for AlpnCallbackConfig { + fn default() -> Self { + Self { + cb: None, + context: ptr::null_mut(), + } + } +} + +/// Configuration needed to call [`invoke_cert_callback`] later +#[derive(Debug, Clone)] +pub struct CertCallbackConfig { + pub cb: SSL_CTX_cert_cb_func, + pub context: *mut c_void, +} + +impl CertCallbackConfig { + pub fn invoke(&self) -> Result<(), Error> { + let callback = match self.cb { + Some(callback) => callback, + None => { + return Ok(()); + } + }; + let ssl = SslCallbackContext::ssl_ptr(); + + let result = unsafe { callback(ssl, self.context) }; + + match result { + 1 => Ok(()), + _ => Err(Error::not_supported("SSL_CTX_cert_cb_func returned != 1")), + } + } +} + +impl Default for CertCallbackConfig { + fn default() -> Self { + Self { + cb: None, + context: ptr::null_mut(), + } + } +} + +/// Configuration needed to call [`invoke_servername_callback`] later +#[derive(Debug, Clone)] +pub struct ServerNameCallbackConfig { + pub cb: SSL_CTX_servername_callback_func, + pub context: *mut c_void, +} + +impl ServerNameCallbackConfig { + pub fn invoke(&self) -> Result<(), Error> { + let callback = match self.cb { + Some(callback) => callback, + None => { + return Ok(()); + } + }; + + let ssl = SslCallbackContext::ssl_ptr(); + + let unrecognised_name = u8::from(AlertDescription::UnrecognisedName) as c_int; + let mut alert = unrecognised_name; + let result = unsafe { callback(ssl, &mut alert as *mut c_int, self.context) }; + + if alert != unrecognised_name { + log::trace!("NYI: customised alert during servername callback"); + } + + match result { + SSL_TLSEXT_ERR_OK => Ok(()), + _ => Err(Error::not_supported( + "SSL_CTX_servername_callback_func return error", + )), + } + } +} + +impl Default for ServerNameCallbackConfig { + fn default() -> Self { + Self { + cb: None, + context: ptr::null_mut(), + } + } +} diff --git a/rustls-libssl/src/entry.rs b/rustls-libssl/src/entry.rs index 2f89493..c1105d7 100644 --- a/rustls-libssl/src/entry.rs +++ b/rustls-libssl/src/entry.rs @@ -8,15 +8,21 @@ use std::io::{self, Read}; use std::os::raw::{c_char, c_int, c_long, c_uchar, c_uint, c_void}; use std::{fs, path::PathBuf}; -use openssl_sys::{stack_st_X509, OPENSSL_malloc, EVP_PKEY, X509, X509_STORE, X509_STORE_CTX}; +use openssl_sys::{ + stack_st_X509, stack_st_X509_NAME, OPENSSL_malloc, TLSEXT_NAMETYPE_host_name, EVP_PKEY, + OPENSSL_NPN_NEGOTIATED, OPENSSL_NPN_NO_OVERLAP, X509, X509_STORE, X509_STORE_CTX, +}; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use crate::bio::{Bio, BIO, BIO_METHOD}; +use crate::callbacks::SslCallbackContext; use crate::error::{ffi_panic_boundary, Error, MysteriouslyOppositeReturnValue}; use crate::evp_pkey::EvpPkey; +use crate::ex_data::ExData; use crate::ffi::{ - free_arc, str_from_cstring, to_arc_mut_ptr, try_clone_arc, try_from, try_mut_slice_int, - try_ref_from_ptr, try_slice, try_slice_int, try_str, Castable, OwnershipArc, OwnershipRef, + clone_arc, free_arc, str_from_cstring, to_arc_mut_ptr, try_clone_arc, try_from, + try_mut_slice_int, try_ref_from_ptr, try_slice, try_slice_int, try_str, Castable, OwnershipArc, + OwnershipRef, }; use crate::not_thread_safe::NotThreadSafe; use crate::x509::{load_certs, OwnedX509, OwnedX509Stack}; @@ -102,12 +108,23 @@ impl Castable for SSL_METHOD { type RustType = SSL_METHOD; } -type SSL_CTX = crate::SslContext; +pub type SSL_CTX = crate::SslContext; entry! { pub fn _SSL_CTX_new(meth: *const SSL_METHOD) -> *mut SSL_CTX { let method = try_ref_from_ptr!(meth); - to_arc_mut_ptr(NotThreadSafe::new(crate::SslContext::new(method))) + let out = to_arc_mut_ptr(NotThreadSafe::new(crate::SslContext::new(method))); + let ex_data = match ExData::new_ssl_ctx(out) { + None => { + _SSL_CTX_free(out); + return ptr::null_mut(); + } + Some(ex_data) => ex_data, + }; + + // safety: we just made this object, the pointer must be valid + clone_arc(out).unwrap().get_mut().install_ex_data(ex_data); + out } } @@ -125,6 +142,21 @@ entry! { } } +entry! { + pub fn _SSL_CTX_set_ex_data(ctx: *mut SSL_CTX, idx: c_int, data: *mut c_void) -> c_int { + match try_clone_arc!(ctx).get_mut().set_ex_data(idx, data) { + Err(e) => e.raise().into(), + Ok(()) => C_INT_SUCCESS, + } + } +} + +entry! { + pub fn _SSL_CTX_get_ex_data(ctx: *const SSL_CTX, idx: c_int) -> *mut c_void { + try_clone_arc!(ctx).get().get_ex_data(idx) + } +} + entry! { pub fn _SSL_CTX_get_options(ctx: *const SSL_CTX) -> u64 { try_clone_arc!(ctx).get().get_options() @@ -156,11 +188,23 @@ entry! { log::warn!("unimplemented SSL_CTX_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + ctx.get_mut().set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => ctx.get().get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_CTX_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + ctx.get_mut().set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } - Ok(SslCtrl::SetTlsExtHostname) => { + Ok(SslCtrl::GetMaxProtoVersion) => ctx.get().get_max_protocol_version().into(), + Ok(SslCtrl::SetTlsExtHostname) | Ok(SslCtrl::SetTlsExtServerNameCallback) => { // not a defined operation in the OpenSSL API 0 } @@ -180,7 +224,10 @@ entry! { ctx.get_mut().stage_certificate_chain(chain); C_INT_SUCCESS as i64 } - + Ok(SslCtrl::SetTlsExtServerNameArg) => { + ctx.get_mut().set_servername_callback_context(parg); + C_INT_SUCCESS as c_long + } Err(()) => { log::warn!("unimplemented _SSL_CTX_ctrl(..., {cmd}, {larg}, ...)"); 0 @@ -204,6 +251,30 @@ entry! { } } +entry! { + pub fn _SSL_CTX_get_verify_callback(ctx: *const SSL_CTX) -> SSL_verify_cb { + try_clone_arc!(ctx).get().get_verify_callback() + } +} + +entry! { + pub fn _SSL_CTX_get_verify_mode(ctx: *const SSL_CTX) -> c_int { + try_clone_arc!(ctx).get().get_verify_mode().into() + } +} + +entry! { + pub fn _SSL_CTX_set_verify_depth(ctx: *mut SSL_CTX, depth: c_int) { + try_clone_arc!(ctx).get_mut().set_verify_depth(depth) + } +} + +entry! { + pub fn _SSL_CTX_get_verify_depth(ctx: *mut SSL_CTX) -> c_int { + try_clone_arc!(ctx).get().get_verify_depth() + } +} + pub type SSL_verify_cb = Option c_int>; @@ -274,6 +345,28 @@ entry! { } } +entry! { + pub fn _SSL_CTX_load_verify_locations( + ctx: *mut SSL_CTX, + ca_file: *const c_char, + ca_path: *const c_char, + ) -> c_int { + if ca_path.is_null() && ca_path.is_null() { + return 0; + } + + if !ca_file.is_null() && _SSL_CTX_load_verify_file(ctx, ca_file) == 0 { + return 0; + } + + if !ca_path.is_null() && _SSL_CTX_load_verify_dir(ctx, ca_path) == 0 { + return 0; + } + + C_INT_SUCCESS + } +} + entry! { pub fn _SSL_CTX_set_alpn_protos( ctx: *mut SSL_CTX, @@ -351,6 +444,47 @@ entry! { } } +fn use_private_key_file(file_name: &str, file_type: c_int) -> Result { + let der_data = match file_type { + FILETYPE_PEM => { + let mut file_reader = match fs::File::open(file_name) { + Ok(content) => io::BufReader::new(content), + Err(err) => return Err(Error::from_io(err)), + }; + + match rustls_pemfile::private_key(&mut file_reader) { + Ok(Some(key)) => key, + Ok(None) => { + log::trace!("No keys found in {file_name:?}"); + return Err(Error::bad_data("pem file")); + } + Err(err) => { + log::trace!("Failed to read {file_name:?}: {err:?}"); + return Err(Error::from_io(err)); + } + } + } + FILETYPE_DER => { + let mut data = vec![]; + match fs::File::open(file_name).and_then(|mut f| f.read_to_end(&mut data)) { + Ok(_) => PrivatePkcs8KeyDer::from(data).into(), + Err(err) => { + log::trace!("Failed to read {file_name:?}: {err:?}"); + return Err(Error::from_io(err)); + } + } + } + _ => { + return Err(Error::not_supported("file_type not in (PEM, DER)")); + } + }; + + match EvpPkey::new_from_der_bytes(der_data) { + None => Err(Error::not_supported("invalid key format")), + Some(key) => Ok(key), + } +} + entry! { pub fn _SSL_CTX_use_PrivateKey_file( ctx: *mut SSL_CTX, @@ -360,47 +494,13 @@ entry! { let ctx = try_clone_arc!(ctx); let file_name = try_str!(file_name); - let der_data = match file_type { - FILETYPE_PEM => { - let mut file_reader = match fs::File::open(file_name) { - Ok(content) => io::BufReader::new(content), - Err(err) => return Error::from_io(err).raise().into(), - }; - - match rustls_pemfile::private_key(&mut file_reader) { - Ok(Some(key)) => key, - Ok(None) => { - log::trace!("No keys found in {file_name:?}"); - return Error::bad_data("pem file").raise().into(); - } - Err(err) => { - log::trace!("Failed to read {file_name:?}: {err:?}"); - return Error::from_io(err).raise().into(); - } - } - } - FILETYPE_DER => { - let mut data = vec![]; - match fs::File::open(file_name).and_then(|mut f| f.read_to_end(&mut data)) { - Ok(_) => PrivatePkcs8KeyDer::from(data).into(), - Err(err) => { - log::trace!("Failed to read {file_name:?}: {err:?}"); - return Error::from_io(err).raise().into(); - } - } - } - _ => { - return Error::not_supported("file_type not in (PEM, DER)") - .raise() - .into(); + let key = match use_private_key_file(file_name, file_type) { + Ok(key) => key, + Err(err) => { + return err.raise().into(); } }; - let key = match EvpPkey::new_from_der_bytes(der_data) { - None => return Error::not_supported("invalid key format").raise().into(), - Some(key) => key, - }; - match ctx.get_mut().commit_private_key(key) { Err(e) => e.raise().into(), Ok(()) => C_INT_SUCCESS, @@ -444,12 +544,106 @@ entry! { } } +pub type SSL_CTX_alpn_select_cb_func = Option< + unsafe extern "C" fn( + ssl: *mut SSL, + out: *mut *const c_uchar, + outlen: *mut c_uchar, + in_: *const c_uchar, + inlen: c_uint, + arg: *mut c_void, + ) -> c_int, +>; + +entry! { + pub fn _SSL_CTX_set_alpn_select_cb( + ctx: *mut SSL_CTX, + cb: SSL_CTX_alpn_select_cb_func, + arg: *mut c_void, + ) { + let ctx = try_clone_arc!(ctx); + ctx.get_mut().set_alpn_select_cb(cb, arg); + } +} + +pub type SSL_CTX_cert_cb_func = + Option c_int>; + +entry! { + pub fn _SSL_CTX_set_cert_cb(ctx: *mut SSL_CTX, cb: SSL_CTX_cert_cb_func, arg: *mut c_void) { + let ctx = try_clone_arc!(ctx); + ctx.get_mut().set_cert_cb(cb, arg); + } +} + +// nb. calls into SSL_CTX_callback_ctrl cast away the real function pointer type, +// and then cast back to the real type based on `cmd`. +pub type SSL_CTX_any_func = Option; + +pub type SSL_CTX_servername_callback_func = + Option c_int>; + +entry! { + pub fn _SSL_CTX_callback_ctrl(ctx: *mut SSL_CTX, cmd: c_int, fp: SSL_CTX_any_func) -> c_long { + let ctx = try_clone_arc!(ctx); + + match SslCtrl::try_from(cmd) { + Ok(SslCtrl::SetTlsExtServerNameCallback) => { + // safety: same layout + let fp = unsafe { + mem::transmute::(fp) + }; + ctx.get_mut().set_servername_callback(fp); + C_INT_SUCCESS as c_long + } + _ => 0, + } + } +} + +entry! { + pub fn _SSL_CTX_get_max_early_data(ctx: *const SSL_CTX) -> u32 { + try_clone_arc!(ctx).get().get_max_early_data() + } +} + +entry! { + pub fn _SSL_CTX_set_max_early_data(ctx: *mut SSL_CTX, max_early_data: u32) -> c_int { + try_clone_arc!(ctx) + .get_mut() + .set_max_early_data(max_early_data); + C_INT_SUCCESS + } +} + +entry! { + pub fn _SSL_CTX_set_cipher_list(_ctx: *mut SSL_CTX, s: *const c_char) -> c_int { + match try_str!(s) { + "HIGH:!aNULL:!MD5" => C_INT_SUCCESS, + _ => Error::not_supported("SSL_CTX_set_cipher_list") + .raise() + .into(), + } + } +} + +entry! { + pub fn _SSL_CTX_set_session_id_context( + _ctx: *mut SSL_CTX, + _sid_ctx: *const c_uchar, + _sid_ctx_len: c_uint, + ) -> c_int { + log::warn!("SSL_CTX_set_session_id_context not yet implemented"); + C_INT_SUCCESS + } +} + impl Castable for SSL_CTX { type Ownership = OwnershipArc; type RustType = NotThreadSafe; } -type SSL = crate::Ssl; +pub type SSL = crate::Ssl; entry! { pub fn _SSL_new(ctx: *mut SSL_CTX) -> *mut SSL { @@ -460,7 +654,18 @@ entry! { None => return ptr::null_mut(), }; - to_arc_mut_ptr(NotThreadSafe::new(ssl)) + let out = to_arc_mut_ptr(NotThreadSafe::new(ssl)); + let ex_data = match ExData::new_ssl(out) { + None => { + _SSL_free(out); + return ptr::null_mut(); + } + Some(ex_data) => ex_data, + }; + + // safety: we just made this object, the pointer must be valid. + clone_arc(out).unwrap().get_mut().install_ex_data(ex_data); + out } } @@ -478,6 +683,21 @@ entry! { } } +entry! { + pub fn _SSL_set_ex_data(ssl: *mut SSL, idx: c_int, data: *mut c_void) -> c_int { + match try_clone_arc!(ssl).get_mut().set_ex_data(idx, data) { + Err(e) => e.raise().into(), + Ok(()) => C_INT_SUCCESS, + } + } +} + +entry! { + pub fn _SSL_get_ex_data(ssl: *const SSL, idx: c_int) -> *mut c_void { + try_clone_arc!(ssl).get().get_ex_data(idx) + } +} + entry! { pub fn _SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long { let ssl = try_clone_arc!(ssl); @@ -491,10 +711,22 @@ entry! { log::warn!("unimplemented SSL_set_msg_callback_arg()"); 0 } + Ok(SslCtrl::SetMinProtoVersion) => { + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + ssl.get_mut().set_min_protocol_version(larg as u16); + C_INT_SUCCESS as c_long + } + Ok(SslCtrl::GetMinProtoVersion) => ssl.get().get_min_protocol_version().into(), Ok(SslCtrl::SetMaxProtoVersion) => { - log::warn!("unimplemented SSL_set_max_proto_version()"); - 1 + if larg < 0 || larg > u16::MAX.into() { + return 0; + } + ssl.get_mut().set_max_protocol_version(larg as u16); + C_INT_SUCCESS as c_long } + Ok(SslCtrl::GetMaxProtoVersion) => ssl.get().get_max_protocol_version().into(), Ok(SslCtrl::SetTlsExtHostname) => { let hostname = try_str!(parg as *const c_char); ssl.get_mut().set_sni_hostname(hostname) as c_long @@ -515,6 +747,8 @@ entry! { ssl.get_mut().stage_certificate_chain(chain); C_INT_SUCCESS as i64 } + // not a defined operation in the OpenSSL API + Ok(SslCtrl::SetTlsExtServerNameCallback) | Ok(SslCtrl::SetTlsExtServerNameArg) => 0, Err(()) => { log::warn!("unimplemented _SSL_ctrl(..., {cmd}, {larg}, ...)"); 0 @@ -650,6 +884,7 @@ entry! { entry! { pub fn _SSL_accept(ssl: *mut SSL) -> c_int { + let _callbacks = SslCallbackContext::new(ssl); match try_clone_arc!(ssl).get_mut().accept() { Err(e) => e.raise().into(), Ok(()) => C_INT_SUCCESS, @@ -657,6 +892,18 @@ entry! { } } +entry! { + pub fn _SSL_do_handshake(ssl: *mut SSL) -> c_int { + let _callbacks = SslCallbackContext::new(ssl); + let ssl = try_clone_arc!(ssl); + + match ssl.get_mut().handshake() { + Err(e) => e.raise().into(), + Ok(()) => C_INT_SUCCESS, + } + } +} + entry! { pub fn _SSL_write(ssl: *mut SSL, buf: *const c_void, num: c_int) -> c_int { const ERROR: c_int = -1; @@ -739,6 +986,12 @@ entry! { } } +entry! { + pub fn _SSL_set_quiet_shutdown(ssl: *mut SSL, mode: c_int) { + try_clone_arc!(ssl).get_mut().set_quiet_shutdown(mode != 0) + } +} + entry! { pub fn _SSL_pending(ssl: *const SSL) -> c_int { try_clone_arc!(ssl).get_mut().get_pending_plaintext() as c_int @@ -919,6 +1172,95 @@ entry! { } } +entry! { + pub fn _SSL_use_PrivateKey_file( + ssl: *mut SSL, + file_name: *const c_char, + file_type: c_int, + ) -> c_int { + let ssl = try_clone_arc!(ssl); + let file_name = try_str!(file_name); + + let key = match use_private_key_file(file_name, file_type) { + Ok(key) => key, + Err(err) => { + return err.raise().into(); + } + }; + + match ssl.get_mut().commit_private_key(key) { + Err(e) => e.raise().into(), + Ok(()) => C_INT_SUCCESS, + } + } +} + +entry! { + pub fn _SSL_check_private_key(_ssl: *const SSL) -> c_int { + log::trace!("not implemented: _SSL_check_private_key, returning success"); + C_INT_SUCCESS + } +} + +entry! { + pub fn _SSL_get_servername(ssl: *const SSL, ty: c_int) -> *const c_char { + if ty != TLSEXT_NAMETYPE_host_name { + return ptr::null(); + } + + try_clone_arc!(ssl).get_mut().server_name_pointer() + } +} + +entry! { + pub fn _SSL_get_servername_type(ssl: *const SSL) -> c_int { + if _SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name).is_null() { + -1 + } else { + TLSEXT_NAMETYPE_host_name + } + } +} + +entry! { + pub fn _SSL_set_verify(ssl: *mut SSL, mode: c_int, callback: SSL_verify_cb) { + let ssl = try_clone_arc!(ssl); + + if callback.is_some() { + // supporting verify callbacks would mean we need to fully use + // the openssl certificate verifier, because X509_STORE and + // X509_STORE_CTX are both in libcrypto. + return Error::not_supported("verify callback").raise().into(); + } + + ssl.get_mut().set_verify(crate::VerifyMode::from(mode)); + } +} + +entry! { + pub fn _SSL_set_verify_depth(ssl: *mut SSL, depth: c_int) { + try_clone_arc!(ssl).get_mut().set_verify_depth(depth) + } +} + +entry! { + pub fn _SSL_get_verify_depth(ssl: *mut SSL) -> c_int { + try_clone_arc!(ssl).get().get_verify_depth() + } +} + +entry! { + pub fn _SSL_get_current_compression(_ssl: *const SSL) -> *const c_void { + ptr::null() + } +} + +entry! { + pub fn _SSL_session_reused(ssl: *const SSL) -> c_int { + try_clone_arc!(ssl).get().was_session_reused() as c_int + } +} + impl Castable for SSL { type Ownership = OwnershipArc; type RustType = NotThreadSafe; @@ -1024,6 +1366,71 @@ impl Castable for SSL_CIPHER { type RustType = SSL_CIPHER; } +entry! { + pub fn _SSL_select_next_proto( + out: *mut *mut c_uchar, + out_len: *mut c_uchar, + server: *const c_uchar, + server_len: c_uint, + client: *const c_uchar, + client_len: c_uint, + ) -> c_int { + let server = try_slice!(server, server_len); + let client = try_slice!(client, client_len); + + if out.is_null() || out_len.is_null() { + return 0; + } + + // ensure `out` and `out_len` are written even on error. + unsafe { + ptr::write(out, ptr::null_mut()); + ptr::write(out_len, 0); + } + + // ensure `client` is fully validated irrespective of `server` value + for offer in crate::iter_alpn(client) { + if offer.is_none() { + return OPENSSL_NPN_NO_OVERLAP; + } + } + + for supported in crate::iter_alpn(server) { + match supported { + None => { + return OPENSSL_NPN_NO_OVERLAP; + } + + Some(supported) + if crate::iter_alpn(client).any(|offer| offer == Some(supported)) => + { + unsafe { + // safety: + // 1) the openssl API is const-incorrect, we must slice the const from `supported` + // 2) supported.len() must fit inside c_uchar; it was decoded from that + ptr::write(out, supported.as_ptr() as *mut c_uchar); + ptr::write(out_len, supported.len() as c_uchar); + return OPENSSL_NPN_NEGOTIATED; + } + } + + Some(_) => { + continue; + } + } + } + + // fallback: "If no match is found, the first item in client, client_len is returned" + if let Some(Some(fallback)) = crate::iter_alpn(client).next() { + unsafe { + ptr::write(out, fallback.as_ptr() as *mut c_uchar); + ptr::write(out_len, fallback.len() as c_uchar); + } + } + OPENSSL_NPN_NO_OVERLAP + } +} + /// Normal OpenSSL return value convention success indicator. /// /// Compare [`crate::ffi::MysteriouslyOppositeReturnValue`]. @@ -1068,9 +1475,14 @@ num_enum! { enum SslCtrl { Mode = 33, SetMsgCallbackArg = 16, + SetTlsExtServerNameCallback = 53, + SetTlsExtServerNameArg = 54, SetTlsExtHostname = 55, SetChain = 88, + SetMinProtoVersion = 123, SetMaxProtoVersion = 124, + GetMinProtoVersion = 130, + GetMaxProtoVersion = 131, } } @@ -1099,29 +1511,13 @@ macro_rules! entry_stub { // some extent: entry_stub! { - pub fn _SSL_CTX_set_ex_data(_ssl: *mut SSL_CTX, _idx: c_int, _data: *mut c_void) -> c_int; -} - -entry_stub! { - pub fn _SSL_CTX_get_ex_data(_ssl: *const SSL_CTX, _idx: c_int) -> *mut c_void; -} - -entry_stub! { - pub fn _SSL_set_ex_data(_ssl: *mut SSL, _idx: c_int, _data: *mut c_void) -> c_int; -} - -entry_stub! { - pub fn _SSL_get_ex_data(_ssl: *const SSL, _idx: c_int) -> *mut c_void; + pub fn _SSL_get_ex_data_X509_STORE_CTX_idx() -> c_int; } entry_stub! { pub fn _SSL_set_session(_ssl: *mut SSL, _session: *mut SSL_SESSION) -> c_int; } -entry_stub! { - pub fn _SSL_session_reused(_ssl: *const SSL) -> c_int; -} - entry_stub! { pub fn _SSL_get1_session(_ssl: *mut SSL) -> *mut SSL_SESSION; } @@ -1158,15 +1554,6 @@ pub type SSL_CTX_sess_remove_cb = Option; entry_stub! { - pub fn _SSL_CTX_set_session_id_context( - _ctx: *mut SSL_CTX, - _sid_ctx: *const c_uchar, - _sid_ctx_len: c_uint, - ) -> c_int; -} - -entry_stub! { - pub fn _SSL_set_session_id_context( _ssl: *mut SSL, _sid_ctx: *const c_uchar, @@ -1212,10 +1599,6 @@ entry_stub! { pub fn _i2d_SSL_SESSION(_in: *const SSL_SESSION, _pp: *mut *mut c_uchar) -> c_int; } -entry_stub! { - pub fn _SSL_CTX_set_cipher_list(_ctx: *mut SSL_CTX, _s: *const c_char) -> c_int; -} - entry_stub! { pub fn _SSL_CTX_set_ciphersuites(_ctx: *mut SSL_CTX, _s: *const c_char) -> c_int; } @@ -1239,6 +1622,44 @@ entry_stub! { pub fn _SSL_SESSION_free(_sess: *mut SSL_SESSION); } +entry_stub! { + pub fn _SSL_write_early_data( + _ssl: *mut SSL, + _buf: *const c_void, + _num: usize, + _written: *mut usize, + ) -> c_int; +} + +entry_stub! { + pub fn _SSL_read_early_data( + _ssl: *mut SSL, + _buf: *mut c_void, + _num: usize, + _readbytes: *mut usize, + ) -> c_int; +} + +entry_stub! { + pub fn _SSL_CTX_get_timeout(_ctx: *const SSL_CTX) -> c_long; +} + +entry_stub! { + pub fn _SSL_CTX_set_timeout(_ctx: *mut SSL_CTX, _t: c_long) -> c_long; +} + +entry_stub! { + pub fn _SSL_CTX_get_client_CA_list(_ctx: *const SSL_CTX) -> *mut stack_st_X509_NAME; +} + +entry_stub! { + pub fn _SSL_CTX_set_client_CA_list(_ctx: *mut SSL_CTX, _name_list: *mut stack_st_X509_NAME); +} + +entry_stub! { + pub fn _SSL_load_client_CA_file(_file: *const c_char) -> *mut stack_st_X509_NAME; +} + // no individual message logging entry_stub! { @@ -1257,6 +1678,15 @@ pub type SSL_CTX_msg_cb_func = Option< ), >; +// no state machine observation + +entry_stub! { + pub fn _SSL_CTX_set_info_callback( + _ctx: *mut SSL_CTX, + _cb: Option, + ); +} + // no NPN (obsolete precursor to ALPN) entry_stub! { @@ -1278,6 +1708,31 @@ pub type SSL_CTX_npn_select_cb_func = Option< ) -> c_int, >; +entry_stub! { + pub fn _SSL_get0_next_proto_negotiated( + _ssl: *const SSL, + _data: *mut *const c_uchar, + _len: *mut c_uint, + ); +} + +entry_stub! { + pub fn _SSL_CTX_set_next_protos_advertised_cb( + _ctx: *mut SSL_CTX, + _cb: SSL_CTX_npn_advertised_cb_func, + _arg: *mut c_void, + ); +} + +pub type SSL_CTX_npn_advertised_cb_func = Option< + unsafe extern "C" fn( + ssl: *mut SSL, + out: *mut *const c_uchar, + outlen: *mut c_uint, + arg: *mut c_void, + ) -> c_int, +>; + // no password-protected key loading entry_stub! { @@ -1435,4 +1890,124 @@ mod tests { _SSL_free(ssl); _SSL_CTX_free(ctx); } + + #[test] + fn test_SSL_select_next_proto_match() { + let mut output = ptr::null_mut(); + let mut output_len = 0u8; + let client = b"\x05hello\x05world"; + let server = b"\x05uhoh!\x05world"; + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + &mut output_len as *mut u8, + server.as_ptr(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 1i32 + ); + assert_eq!(b"world", &server[7..]); + assert_eq!(output as *const u8, server[7..].as_ptr()); + assert_eq!(output_len, 5); + } + + #[test] + fn test_SSL_select_next_proto_no_overlap() { + let mut output = ptr::null_mut(); + let mut output_len = 0u8; + let client = b"\x05hello\x05world"; + let server = b"\x05uhoh!\x05what!"; + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + &mut output_len as *mut u8, + server.as_ptr(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 2i32 + ); + assert_eq!(b"hello", &client[1..6]); + assert_eq!(output as *const u8, client[1..].as_ptr()); + assert_eq!(output_len, 5); + } + + #[test] + fn test_SSL_select_next_proto_illegal_client() { + let mut output = ptr::null_mut(); + let mut output_len = 0u8; + let client = b"\x09hello"; + let server = b"\x05uhoh!\x05world"; + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + &mut output_len as *mut u8, + server.as_ptr(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 2i32 + ); + assert_eq!(output as *const u8, ptr::null_mut()); + } + + #[test] + fn test_SSL_select_next_proto_null() { + let mut output = ptr::null_mut(); + let mut output_len = 0u8; + let client = b"\x05hello\x05world"; + let server = b"\x05uhoh!\x05world"; + + assert_eq!( + _SSL_select_next_proto( + ptr::null_mut(), + &mut output_len as *mut u8, + server.as_ptr(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 0 + ); + + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + ptr::null_mut(), + server.as_ptr(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 0 + ); + + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + &mut output_len as *mut u8, + ptr::null(), + server.len() as c_uint, + client.as_ptr(), + client.len() as c_uint + ), + 0 + ); + + assert_eq!( + _SSL_select_next_proto( + &mut output as *mut *mut u8, + &mut output_len as *mut u8, + server.as_ptr(), + server.len() as c_uint, + ptr::null(), + client.len() as c_uint + ), + 0 + ); + } } diff --git a/rustls-libssl/src/error.rs b/rustls-libssl/src/error.rs index d6590d1..7e5e0b8 100644 --- a/rustls-libssl/src/error.rs +++ b/rustls-libssl/src/error.rs @@ -215,6 +215,12 @@ impl From for () { } } +impl From for crate::entry::SSL_verify_cb { + fn from(_: Error) -> crate::entry::SSL_verify_cb { + None + } +} + #[macro_export] macro_rules! ffi_panic_boundary { ( $($tt:tt)* ) => { diff --git a/rustls-libssl/src/ex_data.rs b/rustls-libssl/src/ex_data.rs new file mode 100644 index 0000000..a7cfea9 --- /dev/null +++ b/rustls-libssl/src/ex_data.rs @@ -0,0 +1,110 @@ +use core::ffi::{c_int, c_void}; +use core::ptr; + +use crate::entry::{SSL, SSL_CTX}; +use crate::error::Error; + +/// Safe(ish), owning wrapper around an OpenSSL `CRYPTO_EX_DATA`. +/// +/// `ty` and `owner` allow us to drop this object with no extra context. +/// +/// Because this refers to the object that contains it, a two-step +/// construction is needed. +pub struct ExData { + ex_data: CRYPTO_EX_DATA, + ty: c_int, + owner: *mut c_void, +} + +impl ExData { + /// Makes a new CRYPTO_EX_DATA for an SSL object. + pub fn new_ssl(ssl: *mut SSL) -> Option { + let mut ex_data = CRYPTO_EX_DATA::default(); + let owner = ssl as *mut c_void; + let ty = CRYPTO_EX_INDEX_SSL; + let rc = unsafe { CRYPTO_new_ex_data(ty, owner, &mut ex_data) }; + if rc == 1 { + Some(Self { ex_data, ty, owner }) + } else { + None + } + } + + /// Makes a new CRYPTO_EX_DATA for an SSL_CTX object. + pub fn new_ssl_ctx(ctx: *mut SSL_CTX) -> Option { + let mut ex_data = CRYPTO_EX_DATA::default(); + let owner = ctx as *mut c_void; + let ty = CRYPTO_EX_INDEX_SSL_CTX; + let rc = unsafe { CRYPTO_new_ex_data(ty, owner, &mut ex_data) }; + if rc == 1 { + Some(Self { ex_data, ty, owner }) + } else { + None + } + } + + pub fn set(&mut self, idx: c_int, data: *mut c_void) -> Result<(), Error> { + let rc = unsafe { CRYPTO_set_ex_data(&mut self.ex_data, idx, data) }; + if rc == 1 { + Ok(()) + } else { + Err(Error::bad_data("CRYPTO_set_ex_data")) + } + } + + pub fn get(&self, idx: c_int) -> *mut c_void { + unsafe { CRYPTO_get_ex_data(&self.ex_data, idx) } + } +} + +impl Drop for ExData { + fn drop(&mut self) { + if !self.owner.is_null() { + unsafe { + CRYPTO_free_ex_data(self.ty, self.owner, &mut self.ex_data); + }; + self.owner = ptr::null_mut(); + } + } +} + +impl Default for ExData { + fn default() -> Self { + Self { + ex_data: CRYPTO_EX_DATA::default(), + ty: -1, + owner: ptr::null_mut(), + } + } +} + +/// This has the same layout prefix as `struct crypto_ex_data_st` aka +/// `CRYPTO_EX_DATA` -- just two pointers. We don't need to know +/// the types of these; the API lets us treat them opaquely. +/// +/// This is _not_ owning. +#[repr(C)] +struct CRYPTO_EX_DATA { + ctx: *mut c_void, + sk: *mut c_void, +} + +impl Default for CRYPTO_EX_DATA { + fn default() -> Self { + Self { + ctx: ptr::null_mut(), + sk: ptr::null_mut(), + } + } +} + +// See `crypto.h` +const CRYPTO_EX_INDEX_SSL: c_int = 0; +const CRYPTO_EX_INDEX_SSL_CTX: c_int = 1; + +extern "C" { + fn CRYPTO_new_ex_data(class_index: c_int, obj: *mut c_void, ed: *mut CRYPTO_EX_DATA) -> c_int; + fn CRYPTO_set_ex_data(ed: *mut CRYPTO_EX_DATA, index: c_int, data: *mut c_void) -> c_int; + fn CRYPTO_get_ex_data(ed: *const CRYPTO_EX_DATA, index: c_int) -> *mut c_void; + fn CRYPTO_free_ex_data(class_index: c_int, obj: *mut c_void, ed: *mut CRYPTO_EX_DATA); +} diff --git a/rustls-libssl/src/lib.rs b/rustls-libssl/src/lib.rs index 03c35e2..eb62060 100644 --- a/rustls-libssl/src/lib.rs +++ b/rustls-libssl/src/lib.rs @@ -1,5 +1,6 @@ -use core::ffi::{c_int, c_uint, CStr}; +use core::ffi::{c_char, c_int, c_uint, c_void, CStr}; use core::{mem, ptr}; +use std::ffi::CString; use std::fs; use std::io::{ErrorKind, Read, Write}; use std::path::PathBuf; @@ -14,12 +15,14 @@ use rustls::crypto::aws_lc_rs as provider; use rustls::pki_types::{CertificateDer, ServerName}; use rustls::server::{Accepted, Acceptor}; use rustls::{ - CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore, ServerConfig, + CipherSuite, ClientConfig, ClientConnection, Connection, HandshakeKind, ProtocolVersion, + RootCertStore, ServerConfig, SupportedProtocolVersion, }; use not_thread_safe::NotThreadSafe; mod bio; +mod callbacks; #[macro_use] mod constants; #[allow( @@ -33,6 +36,7 @@ mod constants; mod entry; mod error; mod evp_pkey; +mod ex_data; #[macro_use] #[allow(unused_macros, dead_code, unused_imports)] mod ffi; @@ -49,8 +53,8 @@ mod x509; /// # Lifetime /// Functions that return SSL_METHOD, like `TLS_method()`, give static-lifetime pointers. pub struct SslMethod { - client_versions: &'static [&'static rustls::SupportedProtocolVersion], - server_versions: &'static [&'static rustls::SupportedProtocolVersion], + client_versions: &'static [&'static SupportedProtocolVersion], + server_versions: &'static [&'static SupportedProtocolVersion], } impl SslMethod { @@ -212,31 +216,57 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher { pub struct SslContext { method: &'static SslMethod, + ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, verify_mode: VerifyMode, + verify_depth: c_int, verify_roots: RootCertStore, verify_x509_store: x509::OwnedX509Store, alpn: Vec>, default_cert_file: Option, default_cert_dir: Option, + alpn_callback: callbacks::AlpnCallbackConfig, + cert_callback: callbacks::CertCallbackConfig, + servername_callback: callbacks::ServerNameCallbackConfig, auth_keys: sign::CertifiedKeySet, + max_early_data: u32, } impl SslContext { fn new(method: &'static SslMethod) -> Self { Self { method, + ex_data: ex_data::ExData::default(), + versions: EnabledVersions::default(), raw_options: 0, verify_mode: VerifyMode::default(), + verify_depth: -1, verify_roots: RootCertStore::empty(), verify_x509_store: x509::OwnedX509Store::new(), alpn: vec![], default_cert_file: None, default_cert_dir: None, + alpn_callback: callbacks::AlpnCallbackConfig::default(), + cert_callback: callbacks::CertCallbackConfig::default(), + servername_callback: callbacks::ServerNameCallbackConfig::default(), auth_keys: sign::CertifiedKeySet::default(), + max_early_data: 0, } } + fn install_ex_data(&mut self, ex_data: ex_data::ExData) { + self.ex_data = ex_data; + } + + fn set_ex_data(&mut self, idx: c_int, data: *mut c_void) -> Result<(), error::Error> { + self.ex_data.set(idx, data) + } + + fn get_ex_data(&self, idx: c_int) -> *mut c_void { + self.ex_data.get(idx) + } + fn get_options(&self) -> u64 { self.raw_options } @@ -251,6 +281,44 @@ impl SslContext { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_early_data(&mut self, max: u32) { + self.max_early_data = max; + } + + fn get_max_early_data(&self) -> u32 { + self.max_early_data + } + fn set_verify(&mut self, mode: VerifyMode) { self.verify_mode = mode; } @@ -274,6 +342,23 @@ impl SslContext { self.default_cert_file = cert_file; } + fn get_verify_mode(&self) -> VerifyMode { + self.verify_mode + } + + fn get_verify_callback(&self) -> entry::SSL_verify_cb { + // TODO: `SSL_CTX_set_verify` currently rejects non-NULL callback + None + } + + fn set_verify_depth(&mut self, depth: c_int) { + self.verify_depth = depth; + } + + fn get_verify_depth(&self) -> c_int { + self.verify_depth + } + fn add_trusted_certs( &mut self, certs: Vec>, @@ -294,6 +379,14 @@ impl SslContext { self.alpn = alpn; } + fn set_alpn_select_cb(&mut self, cb: entry::SSL_CTX_alpn_select_cb_func, context: *mut c_void) { + self.alpn_callback = callbacks::AlpnCallbackConfig { cb, context }; + } + + fn set_cert_cb(&mut self, cb: entry::SSL_CTX_cert_cb_func, context: *mut c_void) { + self.cert_callback = callbacks::CertCallbackConfig { cb, context }; + } + fn stage_certificate_end_entity(&mut self, end: CertificateDer<'static>) { self.auth_keys.stage_certificate_end_entity(end) } @@ -313,6 +406,14 @@ impl SslContext { fn get_privatekey(&self) -> *mut EVP_PKEY { self.auth_keys.borrow_current_key() } + + fn set_servername_callback(&mut self, cb: entry::SSL_CTX_servername_callback_func) { + self.servername_callback.cb = cb; + } + + fn set_servername_callback_context(&mut self, context: *mut c_void) { + self.servername_callback.context = context; + } } /// Parse the ALPN wire format (which is used in the openssl API) @@ -320,37 +421,66 @@ impl SslContext { /// /// For an empty `slice`, returns `Some(vec![])`. /// For a slice with invalid contents, returns `None`. -pub fn parse_alpn(mut slice: &[u8]) -> Option>> { +pub fn parse_alpn(slice: &[u8]) -> Option>> { let mut out = vec![]; - while !slice.is_empty() { - let len = *slice.first()? as usize; - if len == 0 { - return None; - } - let body = slice.get(1..1 + len)?; - out.push(body.to_vec()); - slice = &slice[1 + len..]; + for item in iter_alpn(slice) { + out.push(item?.to_vec()); } Some(out) } +pub fn iter_alpn(mut slice: &[u8]) -> impl Iterator> { + std::iter::from_fn(move || { + // None => end iteration + // Some(None) => error + // Some(_) => an item + + let len = match slice.first() { + None => { + return None; + } + Some(len) => *len as usize, + }; + + if len == 0 { + return Some(None); + } + + match slice.get(1..1 + len) { + None => Some(None), + Some(body) => { + slice = &slice[1 + len..]; + Some(Some(body)) + } + } + }) +} + struct Ssl { ctx: Arc>, + ex_data: ex_data::ExData, + versions: EnabledVersions, raw_options: u64, mode: ConnMode, verify_mode: VerifyMode, + verify_depth: c_int, verify_roots: RootCertStore, verify_server_name: Option>, alpn: Vec>, + alpn_callback: callbacks::AlpnCallbackConfig, + cert_callback: callbacks::CertCallbackConfig, + servername_callback: callbacks::ServerNameCallbackConfig, sni_server_name: Option>, + server_name: Option, bio: Option, conn: ConnState, peer_cert: Option, peer_cert_chain: Option, shutdown_flags: ShutdownFlags, auth_keys: sign::CertifiedKeySet, + max_early_data: u32, } #[allow(clippy::large_enum_variant)] @@ -366,22 +496,42 @@ impl Ssl { fn new(ctx: Arc>, inner: &SslContext) -> Result { Ok(Self { ctx, + ex_data: ex_data::ExData::default(), + versions: inner.versions.clone(), raw_options: inner.raw_options, mode: inner.method.mode(), verify_mode: inner.verify_mode, + verify_depth: inner.verify_depth, verify_roots: Self::load_verify_certs(inner)?, verify_server_name: None, alpn: inner.alpn.clone(), + alpn_callback: inner.alpn_callback.clone(), + cert_callback: inner.cert_callback.clone(), + servername_callback: inner.servername_callback.clone(), sni_server_name: None, + server_name: None, bio: None, conn: ConnState::Nothing, peer_cert: None, peer_cert_chain: None, shutdown_flags: ShutdownFlags::default(), auth_keys: inner.auth_keys.clone(), + max_early_data: inner.max_early_data, }) } + fn install_ex_data(&mut self, ex_data: ex_data::ExData) { + self.ex_data = ex_data; + } + + fn set_ex_data(&mut self, idx: c_int, data: *mut c_void) -> Result<(), error::Error> { + self.ex_data.set(idx, data) + } + + fn get_ex_data(&self, idx: c_int) -> *mut c_void { + self.ex_data.get(idx) + } + fn set_ctx(&mut self, ctx: Arc>) { // there are no docs for `SSL_set_SSL_CTX`. it seems the only // meaningful reason to use this is key/certificate switching @@ -404,6 +554,36 @@ impl Ssl { self.raw_options } + fn set_min_protocol_version(&mut self, ver: u16) { + self.versions.min = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_min_protocol_version(&self) -> u16 { + self.versions + .min + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + + fn set_max_protocol_version(&mut self, ver: u16) { + self.versions.max = match ver { + 0 => None, + _ => Some(ProtocolVersion::from(ver)), + }; + } + + fn get_max_protocol_version(&self) -> u16 { + self.versions + .max + .as_ref() + .map(|v| u16::from(*v)) + .unwrap_or_default() + } + fn set_alpn_offer(&mut self, alpn: Vec>) { self.alpn = alpn; } @@ -453,6 +633,18 @@ impl Ssl { } } + fn set_verify(&mut self, mode: VerifyMode) { + self.verify_mode = mode; + } + + fn set_verify_depth(&mut self, depth: c_int) { + self.verify_depth = depth; + } + + fn get_verify_depth(&self) -> c_int { + self.verify_depth + } + fn set_sni_hostname(&mut self, hostname: &str) -> bool { match ServerName::try_from(hostname).ok() { Some(server_name) => { @@ -463,6 +655,33 @@ impl Ssl { } } + fn server_name_pointer(&mut self) -> *const c_char { + // This does double duty (see `SSL_get_servername`): + // + // for clients, it is just `sni_server_name` + // (filled in here, lazily) + // + // for servers, it is the client's offered SNI name + // (filled in below in `invoke_accepted_callbacks`) + // + // the remaining annoyance is that the returned pointer has to NUL-terminated. + + match self.mode { + ConnMode::Server => self.server_name.as_ref().map(|cstr| cstr.as_ptr()), + ConnMode::Client | ConnMode::Unknown => match &self.server_name { + Some(existing) => Some(existing.as_ptr()), + None => { + self.server_name = self + .sni_server_name + .as_ref() + .and_then(|name| CString::new(name.to_str().as_bytes()).ok()); + self.server_name.as_ref().map(|cstr| cstr.as_ptr()) + } + }, + } + .unwrap_or_else(ptr::null) + } + fn set_bio(&mut self, bio: bio::Bio) { self.bio = Some(bio); } @@ -489,6 +708,14 @@ impl Ssl { .unwrap_or_else(ptr::null_mut) } + fn handshake(&mut self) -> Result<(), error::Error> { + match self.mode { + ConnMode::Client => self.connect(), + ConnMode::Server => self.accept(), + ConnMode::Unknown => Err(error::Error::bad_data("connection mode required")), + } + } + fn connect(&mut self) -> Result<(), error::Error> { if let ConnMode::Unknown = self.mode { self.set_client_mode(); @@ -518,8 +745,10 @@ impl Ssl { &self.verify_server_name, )); + let versions = self.versions.reduce_versions(method.client_versions)?; + let wants_resolver = ClientConfig::builder_with_provider(provider) - .with_protocol_versions(method.client_versions) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .dangerous() .with_custom_certificate_verifier(verifier.clone()); @@ -548,8 +777,39 @@ impl Ssl { self.conn = ConnState::Accepting(Acceptor::default()); } - self.try_io()?; + self.try_io() + } + + fn invoke_accepted_callbacks(&mut self) -> Result<(), error::Error> { + // called on transition from `Accepting` -> `Accepted` + let accepted = match &self.conn { + ConnState::Accepted(accepted) => accepted, + _ => unreachable!(), + }; + + self.server_name = accepted + .client_hello() + .server_name() + .and_then(|sni| CString::new(sni.as_bytes()).ok()); + + self.servername_callback.invoke()?; + + if let Some(alpn_iter) = accepted.client_hello().alpn() { + let offer = encode_alpn(alpn_iter); + let choice = self.alpn_callback.invoke(&offer)?; + + if let Some(choice) = choice { + self.alpn = vec![choice]; + } + } + + self.cert_callback.invoke()?; + + self.complete_accept() + } + + fn complete_accept(&mut self) -> Result<(), error::Error> { if let ConnState::Accepted(_) = self.conn { self.init_server_conn()?; } @@ -575,12 +835,18 @@ impl Ssl { .server_resolver() .ok_or_else(|| error::Error::bad_data("missing server keys"))?; - let config = ServerConfig::builder_with_provider(provider) - .with_protocol_versions(method.server_versions) + let versions = self.versions.reduce_versions(method.server_versions)?; + + let mut config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&versions) .map_err(error::Error::from_rustls)? .with_client_cert_verifier(verifier.clone()) .with_cert_resolver(resolver); + config.alpn_protocols = mem::take(&mut self.alpn); + config.max_early_data_size = self.max_early_data; + config.send_tls13_tickets = 2; // match OpenSSL default: see `man SSL_CTX_set_num_tickets` + let accepted = match mem::replace(&mut self.conn, ConnState::Nothing) { ConnState::Accepted(accepted) => accepted, _ => unreachable!(), @@ -697,7 +963,7 @@ impl Ssl { Ok(None) => Ok(()), Ok(Some(accepted)) => { self.conn = ConnState::Accepted(accepted); - Ok(()) + self.invoke_accepted_callbacks() } Err((error, mut alert)) => { alert.write_all(bio).map_err(error::Error::from_io)?; @@ -710,6 +976,12 @@ impl Ssl { } fn try_shutdown(&mut self) -> Result { + if self.shutdown_flags.quiet() { + self.shutdown_flags.set_sent(); + self.shutdown_flags.set_received(); + return Ok(ShutdownResult::Received); + } + if !self.shutdown_flags.is_sent() { if let Some(conn) = self.conn_mut() { conn.send_close_notify(); @@ -727,13 +999,17 @@ impl Ssl { } fn get_shutdown(&self) -> i32 { - self.shutdown_flags.0 + self.shutdown_flags.get() } fn set_shutdown(&mut self, flags: i32) { self.shutdown_flags.set(flags); } + fn set_quiet_shutdown(&mut self, enabled: bool) { + self.shutdown_flags.set_quiet(enabled); + } + fn get_pending_plaintext(&mut self) -> usize { self.conn_mut() .as_mut() @@ -892,6 +1168,25 @@ impl Ssl { None => HandshakeState::Before, } } + + fn was_session_reused(&self) -> bool { + match self.conn() { + Some(conn) => conn.handshake_kind() == Some(HandshakeKind::Resumed), + None => false, + } + } +} + +/// Encode rustls's internal representation in the wire format. +fn encode_alpn<'a>(iter: impl Iterator) -> Vec { + let mut out = vec![]; + + for item in iter { + out.push(item.len() as u8); + out.extend_from_slice(item); + } + + out } /// This is a reduced-fidelity version of `OSSL_HANDSHAKE_STATE`. @@ -961,6 +1256,9 @@ struct ShutdownFlags(i32); impl ShutdownFlags { const SENT: i32 = 1; const RECEIVED: i32 = 2; + const PUBLIC: i32 = Self::SENT | Self::RECEIVED; + + const PRIV_QUIET: i32 = 4; fn is_sent(&self) -> bool { self.0 & ShutdownFlags::SENT == ShutdownFlags::SENT @@ -979,7 +1277,23 @@ impl ShutdownFlags { } fn set(&mut self, flags: i32) { - self.0 |= flags & (ShutdownFlags::SENT | ShutdownFlags::RECEIVED); + self.0 |= flags & ShutdownFlags::PUBLIC; + } + + fn get(&self) -> i32 { + self.0 & ShutdownFlags::PUBLIC + } + + fn set_quiet(&mut self, enabled: bool) { + if enabled { + self.0 |= ShutdownFlags::PRIV_QUIET; + } else { + self.0 &= !ShutdownFlags::PRIV_QUIET; + } + } + + fn quiet(&self) -> bool { + self.0 & ShutdownFlags::PRIV_QUIET == ShutdownFlags::PRIV_QUIET } } @@ -1016,6 +1330,46 @@ impl From for VerifyMode { } } +impl From for i32 { + fn from(v: VerifyMode) -> Self { + v.0 + } +} + +#[derive(Debug, Default, Clone)] +struct EnabledVersions { + min: Option, + max: Option, +} + +impl EnabledVersions { + fn reduce_versions( + &self, + method_versions: &'static [&'static SupportedProtocolVersion], + ) -> Result, error::Error> { + let acceptable: Vec<&'static SupportedProtocolVersion> = method_versions + .iter() + .cloned() + .filter(|v| self.satisfies(v.version)) + .collect(); + + if acceptable.is_empty() { + Err(error::Error::bad_data(&format!( + "no versions usable: method enabled {method_versions:?}, filter {self:?}" + ))) + } else { + Ok(acceptable) + } + } + + fn satisfies(&self, v: ProtocolVersion) -> bool { + let min = self.min.map(u16::from).unwrap_or(0); + let max = self.max.map(u16::from).unwrap_or(0xffff); + let v = u16::from(v); + min <= v && v <= max + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rustls-libssl/src/miri.rs b/rustls-libssl/src/miri.rs index 9429525..b6dcece 100644 --- a/rustls-libssl/src/miri.rs +++ b/rustls-libssl/src/miri.rs @@ -1,5 +1,6 @@ +use core::ptr; /// Shims for functions we call, written in rust so they are visible to miri. -use std::ffi::{c_char, c_int, CStr}; +use std::ffi::{c_char, c_int, c_void, CStr}; pub struct X509_STORE(()); @@ -27,3 +28,24 @@ pub extern "C" fn ERR_set_error(lib: c_int, reason: c_int, message: *const c_cha CStr::from_ptr(message) }); } + +#[no_mangle] +pub extern "C" fn CRYPTO_new_ex_data( + ty: c_int, + owner: *mut c_void, + out: *mut [*mut c_void; 2], +) -> c_int { + eprintln!("CRYPTO_new_ex_data({ty}, {owner:?});"); + let marker = [owner, owner]; + unsafe { + ptr::write(out, marker); + }; + 1 +} + +#[no_mangle] +pub extern "C" fn CRYPTO_free_ex_data(ty: c_int, owner: *mut c_void, ed: *mut [*mut c_void; 2]) { + let marker: [*mut c_void; 2] = unsafe { ptr::read(ed) }; + assert!(marker[0] == owner); + assert!(marker[1] == owner); +} diff --git a/rustls-libssl/tests/client.c b/rustls-libssl/tests/client.c index 4d57886..ab8416c 100644 --- a/rustls-libssl/tests/client.c +++ b/rustls-libssl/tests/client.c @@ -54,9 +54,23 @@ int main(int argc, char **argv) { } else { SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); dump_openssl_error_stack(); + assert(SSL_CTX_get_verify_mode(ctx) == SSL_VERIFY_PEER); + assert(SSL_CTX_get_verify_callback(ctx) == NULL); TRACE(SSL_CTX_load_verify_file(ctx, cacert)); dump_openssl_error_stack(); } + printf("SSL_CTX_get_verify_depth default %d\n", + SSL_CTX_get_verify_depth(ctx)); + printf("SSL_CTX_get_min_proto_version default 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version default 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); + TRACE(SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION)); + TRACE(SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION)); + printf("SSL_CTX_get_min_proto_version 0x%lx\n", + SSL_CTX_get_min_proto_version(ctx)); + printf("SSL_CTX_get_max_proto_version 0x%lx\n", + SSL_CTX_get_max_proto_version(ctx)); X509 *client_cert = NULL; EVP_PKEY *client_key = NULL; @@ -77,6 +91,17 @@ int main(int argc, char **argv) { printf("SSL_new: SSL_get_certificate %s SSL_CTX_get0_certificate\n", SSL_get_certificate(ssl) == client_cert ? "same as" : "differs to"); state(ssl); + printf("SSL_get_verify_depth default %d\n", SSL_get_verify_depth(ssl)); + printf("SSL_get_min_proto_version 0x%lx\n", SSL_get_min_proto_version(ssl)); + printf("SSL_get_max_proto_version 0x%lx\n", SSL_get_max_proto_version(ssl)); + printf("SSL_get_servername: %s (%d)\n", + SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), + SSL_get_servername_type(ssl)); + TRACE(SSL_set_tlsext_host_name(ssl, "localhost")); + dump_openssl_error_stack(); + printf("SSL_get_servername: %s (%d)\n", + SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), + SSL_get_servername_type(ssl)); TRACE(SSL_set1_host(ssl, host)); dump_openssl_error_stack(); TRACE(SSL_set_fd(ssl, sock)); diff --git a/rustls-libssl/tests/nginx.conf b/rustls-libssl/tests/nginx.conf new file mode 100644 index 0000000..4acc9ed --- /dev/null +++ b/rustls-libssl/tests/nginx.conf @@ -0,0 +1,47 @@ +daemon off; +master_process off; +pid nginx.pid; + +events { +} + +http { + ssl_protocols TLSv1.2 TLSv1.3; + access_log access.log; + + server { + listen 8443 ssl; + server_name localhost; + ssl_certificate ../../../test-ca/rsa/server.cert; + ssl_certificate_key ../../../test-ca/rsa/server.key; + + location = / { + return 200 "hello world\n"; + } + + location /ssl-agreed { + return 200 "protocol:$ssl_protocol,cipher:$ssl_cipher\n"; + } + + location /ssl-server-name { + return 200 "server-name:$ssl_server_name\n"; + } + + location /ssl-was-reused { + return 200 "reused:$ssl_session_reused\n"; + } + + # not currently implemented: + location /ssl-offer { + return 200 "ciphers:$ssl_ciphers,curves:$ssl_curves\n"; + } + + location /ssl-early-data { + return 200 "early-data:$ssl_early_data\n"; + } + + location /ssl-client-auth { + return 200 "s-dn:$ssl_client_s_dn\ni-dn:$ssl_client_i_dn\nserial:$ssl_client_serial\nfp:$ssl_client_fingerprint\nverify:$ssl_client_verify\nv-start:$ssl_client_v_start\nv-end:$ssl_client_v_end\nv-remain:$ssl_client_v_remain\ncert:\n$ssl_client_cert\n"; + } + } +} diff --git a/rustls-libssl/tests/runner.rs b/rustls-libssl/tests/runner.rs index ff3792a..03e6075 100644 --- a/rustls-libssl/tests/runner.rs +++ b/rustls-libssl/tests/runner.rs @@ -1,6 +1,6 @@ use std::io::Read; use std::process::{Child, Command, Output, Stdio}; -use std::{net, thread, time}; +use std::{fs, net, thread, time}; /* Note: * @@ -327,6 +327,71 @@ fn server() { assert_eq!(openssl_output, rustls_output); } +const NGINX_LOG_LEVEL: &str = "info"; + +#[test] +#[ignore] +fn nginx() { + fs::create_dir_all("target/nginx-tmp/basic/html").unwrap(); + fs::write( + "target/nginx-tmp/basic/server.conf", + include_str!("nginx.conf"), + ) + .unwrap(); + + let big_file = vec![b'a'; 5 * 1024 * 1024]; + fs::write("target/nginx-tmp/basic/html/large.html", &big_file).unwrap(); + + let nginx_server = KillOnDrop(Some( + Command::new("tests/maybe-valgrind.sh") + .args([ + "nginx", + "-g", + &format!("error_log stderr {NGINX_LOG_LEVEL};"), + "-p", + "./target/nginx-tmp/basic", + "-c", + "server.conf", + ]) + .spawn() + .unwrap(), + )); + wait_for_port(8443); + + // basic single request + assert_eq!( + Command::new("curl") + .env("LD_LIBRARY_PATH", "") + .args(["--cacert", "test-ca/rsa/ca.cert", "https://localhost:8443/"]) + .stdout(Stdio::piped()) + .output() + .map(print_output) + .unwrap() + .stdout, + b"hello world\n" + ); + + // big download (throttled by curl to ensure non-blocking writes work) + assert_eq!( + Command::new("curl") + .env("LD_LIBRARY_PATH", "") + .args([ + "--cacert", + "test-ca/rsa/ca.cert", + "--limit-rate", + "1M", + "https://localhost:8443/large.html" + ]) + .stdout(Stdio::piped()) + .output() + .unwrap() + .stdout, + big_file + ); + + drop(nginx_server); +} + struct KillOnDrop(Option); impl KillOnDrop { diff --git a/rustls-libssl/tests/server.c b/rustls-libssl/tests/server.c index 3fbf152..d329960 100644 --- a/rustls-libssl/tests/server.c +++ b/rustls-libssl/tests/server.c @@ -19,6 +19,56 @@ #include "helpers.h" +static int ssl_ctx_ex_data_idx_message; +static int ssl_ex_data_idx_message; + +static int alpn_cookie = 12345; + +static int alpn_callback(SSL *ssl, const uint8_t **out, uint8_t *outlen, + const uint8_t *in, unsigned int inlen, void *arg) { + printf("in alpn_callback:\n"); + assert(ssl != NULL); + assert(arg == &alpn_cookie); + printf(" ssl_ex_data_idx_message: %s\n", + (const char *)SSL_get_ex_data(ssl, ssl_ex_data_idx_message)); + hexdump(" in", in, (int)inlen); + if (SSL_select_next_proto((uint8_t **)out, outlen, + (const uint8_t *)"\x08http/1.1", 9, in, + inlen) == OPENSSL_NPN_NEGOTIATED) { + hexdump(" select", *out, (int)*outlen); + return SSL_TLSEXT_ERR_OK; + } else { + printf(" alpn failed\n"); + return SSL_TLSEXT_ERR_ALERT_FATAL; + } +} + +static int cert_cookie = 12345; + +static int cert_callback(SSL *ssl, void *arg) { + printf("in cert_callback\n"); + assert(ssl != NULL); + assert(arg == &cert_cookie); + printf(" ssl_ex_data_idx_message: %s\n", + (const char *)SSL_get_ex_data(ssl, ssl_ex_data_idx_message)); + return 1; +} + +static int sni_cookie = 12345; + +static int sni_callback(SSL *ssl, int *al, void *arg) { + printf("in sni_callback\n"); + assert(ssl != NULL); + assert(arg == &sni_cookie); + assert(*al == SSL_AD_UNRECOGNIZED_NAME); + printf(" SSL_get_servername: %s (%d)\n", + SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), + SSL_get_servername_type(ssl)); + printf(" ssl_ex_data_idx_message: %s\n", + (const char *)SSL_get_ex_data(ssl, ssl_ex_data_idx_message)); + return SSL_TLSEXT_ERR_OK; +} + int main(int argc, char **argv) { if (argc != 5) { printf("%s |unauth\n\n", @@ -55,6 +105,24 @@ int main(int argc, char **argv) { printf("client auth disabled\n"); } + ssl_ctx_ex_data_idx_message = + SSL_CTX_get_ex_new_index(0, NULL, NULL, NULL, NULL); + TRACE(SSL_CTX_set_ex_data(ctx, ssl_ctx_ex_data_idx_message, + "hello from SSL_CTX!")); + printf("ssl_ctx_ex_data_idx_message: %s\n", + (const char *)SSL_CTX_get_ex_data(ctx, ssl_ctx_ex_data_idx_message)); + + SSL_CTX_set_alpn_select_cb(ctx, alpn_callback, &alpn_cookie); + dump_openssl_error_stack(); + + SSL_CTX_set_cert_cb(ctx, cert_callback, &cert_cookie); + dump_openssl_error_stack(); + + SSL_CTX_set_tlsext_servername_callback(ctx, sni_callback); + dump_openssl_error_stack(); + SSL_CTX_set_tlsext_servername_arg(ctx, &sni_cookie); + dump_openssl_error_stack(); + X509 *server_cert = NULL; EVP_PKEY *server_key = NULL; TRACE(SSL_CTX_use_certificate_chain_file(ctx, certfile)); @@ -62,8 +130,8 @@ int main(int argc, char **argv) { server_key = SSL_CTX_get0_privatekey(ctx); server_cert = SSL_CTX_get0_certificate(ctx); - TRACE(SSL_CTX_set_alpn_protos(ctx, (const uint8_t *)"\x08http/1.1", 9)); - dump_openssl_error_stack(); + printf("SSL_CTX_get_max_early_data default %lu\n", + (unsigned long)SSL_CTX_get_max_early_data(ctx)); SSL *ssl = SSL_new(ctx); dump_openssl_error_stack(); @@ -71,6 +139,12 @@ int main(int argc, char **argv) { SSL_get_privatekey(ssl) == server_key ? "same as" : "differs to"); printf("SSL_new: SSL_get_certificate %s SSL_CTX_get0_certificate\n", SSL_get_certificate(ssl) == server_cert ? "same as" : "differs to"); + + ssl_ex_data_idx_message = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + TRACE(SSL_set_ex_data(ssl, ssl_ex_data_idx_message, "hello from SSL!")); + printf("ssl_ex_data_idx_message: %s\n", + (const char *)SSL_get_ex_data(ssl, ssl_ex_data_idx_message)); + state(ssl); TRACE(SSL_set_fd(ssl, sock)); dump_openssl_error_stack(); @@ -92,6 +166,9 @@ int main(int argc, char **argv) { printf("version: %s\n", SSL_get_version(ssl)); printf("verify-result: %ld\n", SSL_get_verify_result(ssl)); printf("cipher: %s\n", SSL_CIPHER_standard_name(SSL_get_current_cipher(ssl))); + printf("SSL_get_servername: %s (%d)\n", + SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), + SSL_get_servername_type(ssl)); show_peer_certificate("client", ssl); @@ -125,6 +202,10 @@ int main(int argc, char **argv) { TRACE(SSL_shutdown(ssl)); dump_openssl_error_stack(); + printf("ssl_ex_data_idx_message: %s\n", + (const char *)SSL_get_ex_data(ssl, ssl_ex_data_idx_message)); + printf("ssl_ctx_ex_data_idx_message: %s\n", + (const char *)SSL_CTX_get_ex_data(ctx, ssl_ctx_ex_data_idx_message)); close(sock); close(listener); SSL_free(ssl);