Skip to content

Commit

Permalink
Implement SSL_accept and associated server support
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz committed Apr 10, 2024
1 parent 65083f5 commit 0367fe2
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 63 deletions.
2 changes: 1 addition & 1 deletion rustls-libssl/MATRIX.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@
| `SSL_SESSION_up_ref` | | :white_check_mark: | :exclamation: [^stub] |
| `SSL_SRP_CTX_free` [^deprecatedin_3_0] [^srp] | | | |
| `SSL_SRP_CTX_init` [^deprecatedin_3_0] [^srp] | | | |
| `SSL_accept` | | | |
| `SSL_accept` | | | :white_check_mark: |
| `SSL_add1_host` | | | |
| `SSL_add1_to_CA_list` | | | |
| `SSL_add_client_CA` | | | |
Expand Down
1 change: 1 addition & 0 deletions rustls-libssl/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const ENTRYPOINTS: &[&str] = &[
"d2i_SSL_SESSION",
"i2d_SSL_SESSION",
"OPENSSL_init_ssl",
"SSL_accept",
"SSL_alert_desc_string",
"SSL_alert_desc_string_long",
"SSL_CIPHER_description",
Expand Down
16 changes: 16 additions & 0 deletions rustls-libssl/src/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,22 @@ entry! {
}
}

entry! {
pub fn _SSL_accept(ssl: *mut SSL) -> c_int {
let ssl = try_clone_arc!(ssl);

match ssl
.lock()
.map_err(|_| Error::cannot_lock())
.and_then(|mut ssl| ssl.accept())
.map_err(|err| err.raise())
{
Err(e) => e.into(),
Ok(()) => C_INT_SUCCESS,
}
}
}

entry! {
pub fn _SSL_write(ssl: *mut SSL, buf: *const c_void, num: c_int) -> c_int {
const ERROR: c_int = -1;
Expand Down
183 changes: 147 additions & 36 deletions rustls-libssl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::ffi::{c_int, c_uint, CStr};
use core::ptr;
use core::{mem, ptr};
use std::fs;
use std::io::{ErrorKind, Read, Write};
use std::path::PathBuf;
Expand All @@ -12,7 +12,10 @@ use openssl_sys::{
};
use rustls::crypto::aws_lc_rs as provider;
use rustls::pki_types::{CertificateDer, ServerName};
use rustls::{CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore};
use rustls::server::{Accepted, Acceptor};
use rustls::{
CipherSuite, ClientConfig, ClientConnection, Connection, RootCertStore, ServerConfig,
};

mod bio;
#[macro_use]
Expand Down Expand Up @@ -340,14 +343,22 @@ struct Ssl {
alpn: Vec<Vec<u8>>,
sni_server_name: Option<ServerName<'static>>,
bio: Option<bio::Bio>,
conn: Option<Connection>,
verifier: Option<Arc<verifier::ServerVerifier>>,
conn: ConnState,
peer_cert: Option<x509::OwnedX509>,
peer_cert_chain: Option<x509::OwnedX509Stack>,
shutdown_flags: ShutdownFlags,
auth_keys: sign::CertifiedKeySet,
}

#[allow(clippy::large_enum_variant)]
enum ConnState {
Nothing,
Client(Connection, Arc<verifier::ServerVerifier>),
Accepting(Acceptor),
Accepted(Accepted),
Server(Connection, Arc<verifier::ClientVerifier>),
}

impl Ssl {
fn new(ctx: Arc<Mutex<SslContext>>, inner: &SslContext) -> Result<Self, error::Error> {
Ok(Self {
Expand All @@ -360,8 +371,7 @@ impl Ssl {
alpn: inner.alpn.clone(),
sni_server_name: None,
bio: None,
conn: None,
verifier: None,
conn: ConnState::Nothing,
peer_cert: None,
peer_cert_chain: None,
shutdown_flags: ShutdownFlags::default(),
Expand Down Expand Up @@ -482,7 +492,7 @@ impl Ssl {

fn connect(&mut self) -> Result<(), error::Error> {
self.set_client_mode();
if self.conn.is_none() {
if matches!(self.conn, ConnState::Nothing) {
self.init_client_conn()?;
}
self.try_io()
Expand All @@ -508,15 +518,14 @@ impl Ssl {
self.verify_mode,
&self.verify_server_name,
));
self.verifier = Some(verifier.clone());

let wants_resolver = ClientConfig::builder_with_provider(provider)
.with_protocol_versions(method.client_versions)
.map_err(error::Error::from_rustls)?
.dangerous()
.with_custom_certificate_verifier(verifier);
.with_custom_certificate_verifier(verifier.clone());

let mut config = if let Some(resolver) = self.auth_keys.resolver() {
let mut config = if let Some(resolver) = self.auth_keys.client_resolver() {
wants_resolver.with_client_cert_resolver(resolver)
} else {
wants_resolver.with_no_client_auth()
Expand All @@ -527,23 +536,99 @@ impl Ssl {
let client_conn = ClientConnection::new(Arc::new(config), sni_server_name.clone())
.map_err(error::Error::from_rustls)?;

self.conn = Some(client_conn.into());
self.conn = ConnState::Client(client_conn.into(), verifier);
Ok(())
}

fn accept(&mut self) -> Result<(), error::Error> {
self.set_server_mode();

if matches!(self.conn, ConnState::Nothing) {
self.conn = ConnState::Accepting(Acceptor::default());
}

self.try_io()?;

if let ConnState::Accepted(_) = self.conn {
self.init_server_conn()?;
}

self.try_io()
}

fn init_server_conn(&mut self) -> Result<(), error::Error> {
let method = self
.ctx
.lock()
.map(|ctx| ctx.method)
.map_err(|_| error::Error::cannot_lock())?;

let provider = Arc::new(provider::default_provider());
let verifier = Arc::new(
verifier::ClientVerifier::new(
self.verify_roots.clone().into(),
provider.clone(),
self.verify_mode,
)
.map_err(error::Error::from_rustls)?,
);

let resolver = self
.auth_keys
.server_resolver()
.ok_or_else(|| error::Error::bad_data("missing server keys"))?;

let config = ServerConfig::builder_with_provider(provider)
.with_protocol_versions(method.server_versions)
.map_err(error::Error::from_rustls)?
.with_client_cert_verifier(verifier.clone())
.with_cert_resolver(resolver);

let accepted = match mem::replace(&mut self.conn, ConnState::Nothing) {
ConnState::Accepted(accepted) => accepted,
_ => unreachable!(),
};

// TODO: send alert
let server_conn = accepted
.into_connection(Arc::new(config))
.map_err(|(err, _alert)| error::Error::from_rustls(err))?;

self.conn = ConnState::Server(server_conn.into(), verifier);
Ok(())
}

fn conn(&self) -> Option<&Connection> {
match &self.conn {
ConnState::Client(conn, _) | ConnState::Server(conn, _) => Some(conn),
_ => None,
}
}

fn conn_mut(&mut self) -> Option<&mut Connection> {
match &mut self.conn {
ConnState::Client(conn, _) | ConnState::Server(conn, _) => Some(conn),
_ => None,
}
}

fn want(&self) -> Want {
match &self.conn {
Some(conn) => Want {
ConnState::Client(conn, _) | ConnState::Server(conn, _) => Want {
read: conn.wants_read(),
write: conn.wants_write(),
},
None => Want::default(),
ConnState::Accepting(_) => Want {
read: true,
write: false,
},
_ => Want::default(),
}
}

fn write(&mut self, slice: &[u8]) -> Result<usize, error::Error> {
let written = match &mut self.conn {
Some(ref mut conn) => conn.writer().write(slice).map_err(error::Error::from_io)?,
let written = match self.conn_mut() {
Some(conn) => conn.writer().write(slice).map_err(error::Error::from_io)?,
None => 0,
};
self.try_io()?;
Expand All @@ -554,8 +639,8 @@ impl Ssl {
let (late_err, read_count) = loop {
let late_err = self.try_io();

match &mut self.conn {
Some(ref mut conn) => match conn.reader().read(slice) {
match self.conn_mut() {
Some(conn) => match conn.reader().read(slice) {
Ok(read) => break (late_err, read),
Err(err) if err.kind() == ErrorKind::WouldBlock && late_err.is_ok() => {
// no data available, go around again.
Expand Down Expand Up @@ -586,7 +671,7 @@ impl Ssl {
};

match &mut self.conn {
Some(ref mut conn) => {
ConnState::Client(conn, _) | ConnState::Server(conn, _) => {
match conn.complete_io(bio) {
Ok(_) => {}
Err(e) => {
Expand All @@ -606,15 +691,34 @@ impl Ssl {
}
Ok(())
}
None => Ok(()),
ConnState::Accepting(acceptor) => {
match acceptor.read_tls(bio) {
Ok(_) => {}
Err(e) => {
return Err(error::Error::from_io(e));
}
};

match acceptor.accept() {
Ok(None) => Ok(()),
Ok(Some(accepted)) => {
self.conn = ConnState::Accepted(accepted);
Ok(())
}
Err((error, mut alert)) => {
alert.write_all(bio).map_err(error::Error::from_io)?;
Err(error::Error::from_rustls(error))
}
}
}
_ => Ok(()),
}
}

fn try_shutdown(&mut self) -> Result<ShutdownResult, error::Error> {
if !self.shutdown_flags.is_sent() {
match &mut self.conn {
Some(ref mut conn) => conn.send_close_notify(),
None => (),
if let Some(ref mut conn) = self.conn_mut() {
conn.send_close_notify();
};

self.shutdown_flags.set_sent();
Expand All @@ -637,7 +741,7 @@ impl Ssl {
}

fn get_pending_plaintext(&mut self) -> usize {
self.conn
self.conn_mut()
.as_mut()
.and_then(|conn| {
let io_state = conn.process_new_packets().ok()?;
Expand All @@ -647,11 +751,11 @@ impl Ssl {
}

fn get_agreed_alpn(&mut self) -> Option<&[u8]> {
self.conn.as_ref().and_then(|conn| conn.alpn_protocol())
self.conn().and_then(|conn| conn.alpn_protocol())
}

fn init_peer_cert(&mut self) {
let conn = match &self.conn {
let conn = match self.conn() {
Some(conn) => conn,
None => return,
};
Expand All @@ -662,6 +766,8 @@ impl Ssl {
};

let mut stack = x509::OwnedX509Stack::empty();
let mut peer_cert = None;

for (i, cert) in certs.iter().enumerate() {
let converted = match x509::OwnedX509::parse_der(cert.as_ref()) {
Some(converted) => converted,
Expand All @@ -676,12 +782,13 @@ impl Ssl {
// certificate must be obtained separately"
stack.push(&converted);
}
self.peer_cert = Some(converted);
peer_cert = Some(converted);
} else {
stack.push(&converted);
}
}

self.peer_cert = peer_cert;
self.peer_cert_chain = Some(stack);
}

Expand All @@ -700,22 +807,21 @@ impl Ssl {
}

fn get_negotiated_cipher_suite_id(&self) -> Option<CipherSuite> {
self.conn
.as_ref()
self.conn()
.and_then(|conn| conn.negotiated_cipher_suite())
.map(|suite| suite.suite())
}

fn get_last_verification_result(&self) -> i64 {
if let Some(verifier) = &self.verifier {
verifier.last_result()
} else {
X509_V_ERR_UNSPECIFIED as i64
match &self.conn {
ConnState::Client(_, verifier) => verifier.last_result(),
ConnState::Server(_, verifier) => verifier.last_result(),
_ => X509_V_ERR_UNSPECIFIED as i64,
}
}

fn get_error(&mut self) -> c_int {
match &mut self.conn {
match self.conn_mut() {
Some(ref mut conn) => {
if let Err(e) = conn.process_new_packets() {
error::Error::from_rustls(e).raise();
Expand Down Expand Up @@ -775,13 +881,14 @@ impl Ssl {
}

fn handshake_state(&mut self) -> HandshakeState {
match &mut self.conn {
let mode = self.mode;
match self.conn_mut() {
Some(ref mut conn) => {
if conn.process_new_packets().is_err() {
return HandshakeState::Error;
}

match (&self.mode, conn.is_handshaking()) {
match (mode, conn.is_handshaking()) {
(ConnMode::Server, true) => HandshakeState::ServerAwaitingClientHello,
(ConnMode::Client, true) => HandshakeState::ClientAwaitingServerHello,
(ConnMode::Unknown, true) => HandshakeState::Before,
Expand Down Expand Up @@ -841,7 +948,7 @@ struct Want {
write: bool,
}

#[derive(PartialEq, Debug)]
#[derive(PartialEq, Debug, Clone, Copy)]
enum ConnMode {
Unknown,
Client,
Expand Down Expand Up @@ -895,6 +1002,10 @@ impl VerifyMode {
self.0 & VerifyMode::PEER == VerifyMode::PEER
}

pub fn server_must_attempt_client_auth(&self) -> bool {
self.0 & VerifyMode::PEER == VerifyMode::PEER
}

pub fn server_must_verify_client(&self) -> bool {
let bitmap = VerifyMode::PEER | VerifyMode::FAIL_IF_NO_PEER_CERT;
self.0 & bitmap == bitmap
Expand Down
Loading

0 comments on commit 0367fe2

Please sign in to comment.