diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 7c5ff34..3b2d4f9 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -3,7 +3,7 @@ use std::io::{self, Read}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::Error; #[cfg(feature = "runtime")] -use tokio_postgres::{MakeTlsMode, Socket, TlsMode}; +use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect}; #[cfg(feature = "runtime")] use crate::Config; @@ -15,10 +15,10 @@ impl Client { #[cfg(feature = "runtime")] pub fn connect(params: &str, tls_mode: T) -> Result where - T: MakeTlsMode + 'static + Send, - T::TlsMode: Send, + T: MakeTlsConnect + 'static + Send, + T::TlsConnect: Send, T::Stream: Send, - >::Future: Send, + >::Future: Send, { params.parse::()?.connect(tls_mode) } diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 4c6147b..dab5f53 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -4,7 +4,7 @@ use log::error; use std::path::Path; use std::str::FromStr; use std::time::Duration; -use tokio_postgres::{Error, MakeTlsMode, Socket, TargetSessionAttrs, TlsMode}; +use tokio_postgres::{Error, MakeTlsConnect, Socket, TargetSessionAttrs, TlsConnect}; use crate::{Client, RUNTIME}; @@ -94,10 +94,10 @@ impl Config { pub fn connect(&self, tls_mode: T) -> Result where - T: MakeTlsMode + 'static + Send, - T::TlsMode: Send, + T: MakeTlsConnect + 'static + Send, + T::TlsConnect: Send, T::Stream: Send, - >::Future: Send, + >::Future: Send, { let connect = self.0.connect(tls_mode); let (client, connection) = oneshot::spawn(connect, &RUNTIME.executor()).wait()?; diff --git a/tokio-postgres-native-tls/src/test.rs b/tokio-postgres-native-tls/src/test.rs index cbb1218..2843118 100644 --- a/tokio-postgres-native-tls/src/test.rs +++ b/tokio-postgres-native-tls/src/test.rs @@ -2,13 +2,13 @@ use futures::{Future, Stream}; use native_tls::{self, Certificate}; use tokio::net::TcpStream; use tokio::runtime::current_thread::Runtime; -use tokio_postgres::{self, PreferTls, RequireTls, TlsMode}; +use tokio_postgres::TlsConnect; use crate::TlsConnector; fn smoke_test(s: &str, tls: T) where - T: TlsMode, + T: TlsConnect, T::Stream: 'static, { let mut runtime = Runtime::new().unwrap(); @@ -44,8 +44,8 @@ fn require() { .build() .unwrap(); smoke_test( - "user=ssl_user dbname=postgres", - RequireTls(TlsConnector::with_connector(connector, "localhost")), + "user=ssl_user dbname=postgres sslmode=require", + TlsConnector::with_connector(connector, "localhost"), ); } @@ -59,7 +59,7 @@ fn prefer() { .unwrap(); smoke_test( "user=ssl_user dbname=postgres", - PreferTls(TlsConnector::with_connector(connector, "localhost")), + TlsConnector::with_connector(connector, "localhost"), ); } @@ -72,7 +72,7 @@ fn scram_user() { .build() .unwrap(); smoke_test( - "user=scram_user password=password dbname=postgres", - RequireTls(TlsConnector::with_connector(connector, "localhost")), + "user=scram_user password=password dbname=postgres sslmode=require", + TlsConnector::with_connector(connector, "localhost"), ); } diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs index aa0e380..2dc336c 100644 --- a/tokio-postgres-openssl/src/test.rs +++ b/tokio-postgres-openssl/src/test.rs @@ -2,13 +2,13 @@ use futures::{Future, Stream}; use openssl::ssl::{SslConnector, SslMethod}; use tokio::net::TcpStream; use tokio::runtime::current_thread::Runtime; -use tokio_postgres::{self, PreferTls, RequireTls, TlsMode}; +use tokio_postgres::TlsConnect; use super::*; fn smoke_test(s: &str, tls: T) where - T: TlsMode, + T: TlsConnect, T::Stream: 'static, { let mut runtime = Runtime::new().unwrap(); @@ -41,8 +41,8 @@ fn require() { builder.set_ca_file("../test/server.crt").unwrap(); let ctx = builder.build(); smoke_test( - "user=ssl_user dbname=postgres", - RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), + "user=ssl_user dbname=postgres sslmode=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), ); } @@ -53,7 +53,7 @@ fn prefer() { let ctx = builder.build(); smoke_test( "user=ssl_user dbname=postgres", - PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), + TlsConnector::new(ctx.configure().unwrap(), "localhost"), ); } @@ -63,8 +63,8 @@ fn scram_user() { builder.set_ca_file("../test/server.crt").unwrap(); let ctx = builder.build(); smoke_test( - "user=scram_user password=password dbname=postgres", - RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), + "user=scram_user password=password dbname=postgres sslmode=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), ); } @@ -78,8 +78,8 @@ fn runtime() { let connector = MakeTlsConnector::new(builder.build()); let connect = tokio_postgres::connect( - "host=localhost port=5433 user=postgres", - RequireTls(connector), + "host=localhost port=5433 user=postgres sslmode=require", + connector, ); let (mut client, connection) = runtime.block_on(connect).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 35a121e..d47a105 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -49,7 +49,6 @@ postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } state_machine_future = "0.1.7" tokio-codec = "0.1" tokio-io = "0.1" -void = "1.0" tokio-tcp = { version = "0.1", optional = true } futures-cpupool = { version = "0.1", optional = true } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index f290e5f..b38c1a0 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -19,8 +19,8 @@ use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::ConnectFuture; use crate::proto::ConnectRawFuture; #[cfg(feature = "runtime")] -use crate::{Connect, MakeTlsMode, Socket}; -use crate::{ConnectRaw, Error, TlsMode}; +use crate::{Connect, MakeTlsConnect, Socket}; +use crate::{ConnectRaw, Error, TlsConnect}; /// Properties required of a session. #[cfg(feature = "runtime")] @@ -34,6 +34,17 @@ pub enum TargetSessionAttrs { __NonExhaustive, } +/// TLS configuration. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum SslMode { + /// Do not use TLS. + Disable, + /// Attempt to connect with TLS but allow sessions without. + Prefer, + /// Require the use of TLS. + Require, +} + #[cfg(feature = "runtime")] #[derive(Debug, Clone, PartialEq)] pub(crate) enum Host { @@ -49,6 +60,7 @@ pub(crate) struct Inner { pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, + pub(crate) ssl_mode: SslMode, #[cfg(feature = "runtime")] pub(crate) host: Vec, #[cfg(feature = "runtime")] @@ -79,6 +91,8 @@ pub(crate) struct Inner { /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting @@ -152,6 +166,7 @@ impl Config { dbname: None, options: None, application_name: None, + ssl_mode: SslMode::Prefer, #[cfg(feature = "runtime")] host: vec![], #[cfg(feature = "runtime")] @@ -204,6 +219,14 @@ impl Config { self } + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + Arc::make_mut(&mut self.0).ssl_mode = ssl_mode; + self + } + /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix @@ -320,6 +343,15 @@ impl Config { "application_name" => { self.application_name(&value); } + "sslmode" => { + let mode = match value { + "disable" => SslMode::Disable, + "prefer" => SslMode::Prefer, + "require" => SslMode::Require, + _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), + }; + self.ssl_mode(mode); + } #[cfg(feature = "runtime")] "host" => { for host in value.split(',') { @@ -390,22 +422,22 @@ impl Config { /// /// Requires the `runtime` Cargo feature (enabled by default). #[cfg(feature = "runtime")] - pub fn connect(&self, make_tls_mode: T) -> Connect + pub fn connect(&self, tls: T) -> Connect where - T: MakeTlsMode, + T: MakeTlsConnect, { - Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone()))) + Connect(ConnectFuture::new(tls, Ok(self.clone()))) } /// Connects to a PostgreSQL database over an arbitrary stream. /// /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored. - pub fn connect_raw(&self, stream: S, tls_mode: T) -> ConnectRaw + pub fn connect_raw(&self, stream: S, tls: T) -> ConnectRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { - ConnectRaw(ConnectRawFuture::new(stream, tls_mode, self.clone(), None)) + ConnectRaw(ConnectRawFuture::new(stream, tls, self.clone(), None)) } } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 6e4ee3c..0019337 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -127,11 +127,11 @@ fn next_portal() -> String { /// /// [`Config`]: ./Config.t.html #[cfg(feature = "runtime")] -pub fn connect(config: &str, tls_mode: T) -> Connect +pub fn connect(config: &str, tls: T) -> Connect where - T: MakeTlsMode, + T: MakeTlsConnect, { - Connect(proto::ConnectFuture::new(tls_mode, config.parse())) + Connect(proto::ConnectFuture::new(tls, config.parse())) } /// An asynchronous PostgreSQL client. @@ -250,7 +250,7 @@ impl Client { #[cfg(feature = "runtime")] pub fn cancel_query(&mut self, make_tls_mode: T) -> CancelQuery where - T: MakeTlsMode, + T: MakeTlsConnect, { CancelQuery(self.0.cancel_query(make_tls_mode)) } @@ -260,7 +260,7 @@ impl Client { pub fn cancel_query_raw(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode)) } @@ -291,11 +291,12 @@ impl Client { /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has /// occurred, or because its associated `Client` has dropped and all outstanding work has completed. #[must_use = "futures do nothing unless polled"] -pub struct Connection(proto::Connection); +pub struct Connection(proto::Connection>); -impl Connection +impl Connection where S: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite, { /// Returns the value of a runtime parameter for this connection. pub fn parameter(&self, name: &str) -> Option<&str> { @@ -311,9 +312,10 @@ where } } -impl Future for Connection +impl Future for Connection where S: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite, { type Item = (); type Error = Error; @@ -342,12 +344,12 @@ pub enum AsyncMessage { pub struct CancelQueryRaw(proto::CancelQueryRawFuture) where S: AsyncRead + AsyncWrite, - T: TlsMode; + T: TlsConnect; impl Future for CancelQueryRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { type Item = (); type Error = Error; @@ -361,12 +363,12 @@ where #[must_use = "futures do nothing unless polled"] pub struct CancelQuery(proto::CancelQueryFuture) where - T: MakeTlsMode; + T: MakeTlsConnect; #[cfg(feature = "runtime")] impl Future for CancelQuery where - T: MakeTlsMode, + T: MakeTlsConnect, { type Item = (); type Error = Error; @@ -380,17 +382,17 @@ where pub struct ConnectRaw(proto::ConnectRawFuture) where S: AsyncRead + AsyncWrite, - T: TlsMode; + T: TlsConnect; impl Future for ConnectRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { - type Item = (Client, Connection); + type Item = (Client, Connection); type Error = Error; - fn poll(&mut self) -> Poll<(Client, Connection), Error> { + fn poll(&mut self) -> Poll<(Client, Connection), Error> { let (client, connection) = try_ready!(self.0.poll()); Ok(Async::Ready((Client(client), Connection(connection)))) @@ -401,17 +403,17 @@ where #[must_use = "futures do nothing unless polled"] pub struct Connect(proto::ConnectFuture) where - T: MakeTlsMode; + T: MakeTlsConnect; #[cfg(feature = "runtime")] impl Future for Connect where - T: MakeTlsMode, + T: MakeTlsConnect, { - type Item = (Client, Connection); + type Item = (Client, Connection); type Error = Error; - fn poll(&mut self) -> Poll<(Client, Connection), Error> { + fn poll(&mut self) -> Poll<(Client, Connection), Error> { let (client, connection) = try_ready!(self.0.poll()); Ok(Async::Ready((Client(client), Connection(connection)))) diff --git a/tokio-postgres/src/proto/cancel_query.rs b/tokio-postgres/src/proto/cancel_query.rs index 6148480..1a7377c 100644 --- a/tokio-postgres/src/proto/cancel_query.rs +++ b/tokio-postgres/src/proto/cancel_query.rs @@ -3,16 +3,16 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::io; use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture}; -use crate::{Config, Error, Host, MakeTlsMode, Socket}; +use crate::{Config, Error, Host, MakeTlsConnect, Socket, SslMode}; #[derive(StateMachineFuture)] pub enum CancelQuery where - T: MakeTlsMode, + T: MakeTlsConnect, { #[state_machine_future(start, transitions(ConnectingSocket))] Start { - make_tls_mode: T, + tls: T, idx: Option, config: Config, process_id: i32, @@ -21,13 +21,14 @@ where #[state_machine_future(transitions(Canceling))] ConnectingSocket { future: ConnectSocketFuture, - tls_mode: T::TlsMode, + mode: SslMode, + tls: T::TlsConnect, process_id: i32, secret_key: i32, }, #[state_machine_future(transitions(Finished))] Canceling { - future: CancelQueryRawFuture, + future: CancelQueryRawFuture, }, #[state_machine_future(ready)] Finished(()), @@ -37,7 +38,7 @@ where impl PollCancelQuery for CancelQuery where - T: MakeTlsMode, + T: MakeTlsConnect, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let mut state = state.take(); @@ -52,14 +53,15 @@ where #[cfg(unix)] Host::Unix(_) => "", }; - let tls_mode = state - .make_tls_mode - .make_tls_mode(hostname) + let tls = state + .tls + .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; transition!(ConnectingSocket { + mode: state.config.0.ssl_mode, future: ConnectSocketFuture::new(state.config, idx), - tls_mode, + tls, process_id: state.process_id, secret_key: state.secret_key, }) @@ -74,7 +76,8 @@ where transition!(Canceling { future: CancelQueryRawFuture::new( socket, - state.tls_mode, + state.mode, + state.tls, state.process_id, state.secret_key ), @@ -91,15 +94,15 @@ where impl CancelQueryFuture where - T: MakeTlsMode, + T: MakeTlsConnect, { pub fn new( - make_tls_mode: T, + tls: T, idx: Option, config: Config, process_id: i32, secret_key: i32, ) -> CancelQueryFuture { - CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key) + CancelQuery::start(tls, idx, config, process_id, secret_key) } } diff --git a/tokio-postgres/src/proto/cancel_query_raw.rs b/tokio-postgres/src/proto/cancel_query_raw.rs index ae2aee4..522fe31 100644 --- a/tokio-postgres/src/proto/cancel_query_raw.rs +++ b/tokio-postgres/src/proto/cancel_query_raw.rs @@ -5,14 +5,14 @@ use tokio_io::io::{self, Flush, WriteAll}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::error::Error; -use crate::proto::TlsFuture; -use crate::TlsMode; +use crate::proto::{MaybeTlsStream, TlsFuture}; +use crate::{SslMode, TlsConnect}; #[derive(StateMachineFuture)] pub enum CancelQueryRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { #[state_machine_future(start, transitions(SendingCancel))] Start { @@ -22,10 +22,12 @@ where }, #[state_machine_future(transitions(FlushingCancel))] SendingCancel { - future: WriteAll>, + future: WriteAll, Vec>, }, #[state_machine_future(transitions(Finished))] - FlushingCancel { future: Flush }, + FlushingCancel { + future: Flush>, + }, #[state_machine_future(ready)] Finished(()), #[state_machine_future(error)] @@ -35,7 +37,7 @@ where impl PollCancelQueryRaw for CancelQueryRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let (stream, _) = try_ready!(state.future.poll()); @@ -69,14 +71,15 @@ where impl CancelQueryRawFuture where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { pub fn new( stream: S, - tls_mode: T, + mode: SslMode, + tls: T, process_id: i32, secret_key: i32, ) -> CancelQueryRawFuture { - CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key) + CancelQueryRaw::start(TlsFuture::new(stream, mode, tls), process_id, secret_key) } } diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index 6708e52..a9e0d53 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -25,9 +25,9 @@ use crate::proto::statement::Statement; use crate::proto::CancelQueryFuture; use crate::proto::CancelQueryRawFuture; use crate::types::{IsNull, Oid, ToSql, Type}; -use crate::{Config, Error, TlsMode}; +use crate::{Config, Error, TlsConnect}; #[cfg(feature = "runtime")] -use crate::{MakeTlsMode, Socket}; +use crate::{MakeTlsConnect, Socket}; pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>); @@ -247,7 +247,7 @@ impl Client { #[cfg(feature = "runtime")] pub fn cancel_query(&self, make_tls_mode: T) -> CancelQueryFuture where - T: MakeTlsMode, + T: MakeTlsConnect, { CancelQueryFuture::new( make_tls_mode, @@ -258,12 +258,18 @@ impl Client { ) } - pub fn cancel_query_raw(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture + pub fn cancel_query_raw(&self, stream: S, mode: T) -> CancelQueryRawFuture where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { - CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key) + CancelQueryRawFuture::new( + stream, + self.0.config.0.ssl_mode, + mode, + self.0.process_id, + self.0.secret_key, + ) } fn close(&self, ty: u8, name: &str) { diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/connect.rs index afc1d01..1bc3f48 100644 --- a/tokio-postgres/src/proto/connect.rs +++ b/tokio-postgres/src/proto/connect.rs @@ -1,35 +1,35 @@ use futures::{Async, Future, Poll}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; -use crate::proto::{Client, ConnectOnceFuture, Connection}; -use crate::{Config, Error, Host, MakeTlsMode, Socket}; +use crate::proto::{Client, ConnectOnceFuture, Connection, MaybeTlsStream}; +use crate::{Config, Error, Host, MakeTlsConnect, Socket}; #[derive(StateMachineFuture)] pub enum Connect where - T: MakeTlsMode, + T: MakeTlsConnect, { #[state_machine_future(start, transitions(Connecting))] Start { - make_tls_mode: T, + tls: T, config: Result, }, #[state_machine_future(transitions(Finished))] Connecting { - future: ConnectOnceFuture, + future: ConnectOnceFuture, idx: usize, - make_tls_mode: T, + tls: T, config: Config, }, #[state_machine_future(ready)] - Finished((Client, Connection)), + Finished((Client, Connection>)), #[state_machine_future(error)] Failed(Error), } impl PollConnect for Connect where - T: MakeTlsMode, + T: MakeTlsConnect, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let mut state = state.take(); @@ -50,15 +50,15 @@ where #[cfg(unix)] Host::Unix(_) => "", }; - let tls_mode = state - .make_tls_mode - .make_tls_mode(hostname) + let tls = state + .tls + .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; transition!(Connecting { - future: ConnectOnceFuture::new(0, tls_mode, config.clone()), + future: ConnectOnceFuture::new(0, tls, config.clone()), idx: 0, - make_tls_mode: state.make_tls_mode, + tls: state.tls, config, }) } @@ -84,13 +84,12 @@ where #[cfg(unix)] Host::Unix(_) => "", }; - let tls_mode = state - .make_tls_mode - .make_tls_mode(hostname) + let tls = state + .tls + .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - state.future = - ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone()); + state.future = ConnectOnceFuture::new(state.idx, tls, state.config.clone()); } } } @@ -99,9 +98,9 @@ where impl ConnectFuture where - T: MakeTlsMode, + T: MakeTlsConnect, { - pub fn new(make_tls_mode: T, config: Result) -> ConnectFuture { - Connect::start(make_tls_mode, config) + pub fn new(tls: T, config: Result) -> ConnectFuture { + Connect::start(tls, config) } } diff --git a/tokio-postgres/src/proto/connect_once.rs b/tokio-postgres/src/proto/connect_once.rs index 2f73df1..750c9ba 100644 --- a/tokio-postgres/src/proto/connect_once.rs +++ b/tokio-postgres/src/proto/connect_once.rs @@ -4,25 +4,23 @@ use futures::{try_ready, Async, Future, Poll, Stream}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::io; -use crate::proto::{Client, ConnectRawFuture, ConnectSocketFuture, Connection, SimpleQueryStream}; -use crate::{Config, Error, Socket, TargetSessionAttrs, TlsMode}; +use crate::proto::{ + Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream, +}; +use crate::{Config, Error, Socket, TargetSessionAttrs, TlsConnect}; #[derive(StateMachineFuture)] pub enum ConnectOnce where - T: TlsMode, + T: TlsConnect, { #[state_machine_future(start, transitions(ConnectingSocket))] - Start { - idx: usize, - tls_mode: T, - config: Config, - }, + Start { idx: usize, tls: T, config: Config }, #[state_machine_future(transitions(ConnectingRaw))] ConnectingSocket { future: ConnectSocketFuture, idx: usize, - tls_mode: T, + tls: T, config: Config, }, #[state_machine_future(transitions(CheckingSessionAttrs, Finished))] @@ -34,17 +32,17 @@ where CheckingSessionAttrs { stream: SimpleQueryStream, client: Client, - connection: Connection, + connection: Connection>, }, #[state_machine_future(ready)] - Finished((Client, Connection)), + Finished((Client, Connection>)), #[state_machine_future(error)] Failed(Error), } impl PollConnectOnce for ConnectOnce where - T: TlsMode, + T: TlsConnect, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let state = state.take(); @@ -52,7 +50,7 @@ where transition!(ConnectingSocket { future: ConnectSocketFuture::new(state.config.clone(), state.idx), idx: state.idx, - tls_mode: state.tls_mode, + tls: state.tls, config: state.config, }) } @@ -65,7 +63,7 @@ where transition!(ConnectingRaw { target_session_attrs: state.config.0.target_session_attrs, - future: ConnectRawFuture::new(socket, state.tls_mode, state.config, Some(state.idx)), + future: ConnectRawFuture::new(socket, state.tls, state.config, Some(state.idx)), }) } @@ -111,9 +109,9 @@ where impl ConnectOnceFuture where - T: TlsMode, + T: TlsConnect, { - pub fn new(idx: usize, tls_mode: T, config: Config) -> ConnectOnceFuture { - ConnectOnce::start(idx, tls_mode, config) + pub fn new(idx: usize, tls: T, config: Config) -> ConnectOnceFuture { + ConnectOnce::start(idx, tls, config) } } diff --git a/tokio-postgres/src/proto/connect_raw.rs b/tokio-postgres/src/proto/connect_raw.rs index 23f1456..12efae3 100644 --- a/tokio-postgres/src/proto/connect_raw.rs +++ b/tokio-postgres/src/proto/connect_raw.rs @@ -11,14 +11,14 @@ use std::collections::HashMap; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::{Client, Connection, PostgresCodec, TlsFuture}; -use crate::{ChannelBinding, Config, Error, TlsMode}; +use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture}; +use crate::{ChannelBinding, Config, Error, TlsConnect}; #[derive(StateMachineFuture)] pub enum ConnectRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { #[state_machine_future(start, transitions(SendingStartup))] Start { @@ -28,47 +28,47 @@ where }, #[state_machine_future(transitions(ReadingAuth))] SendingStartup { - future: sink::Send>, + future: sink::Send, PostgresCodec>>, config: Config, idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] ReadingAuth { - stream: Framed, + stream: Framed, PostgresCodec>, config: Config, idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingAuthCompletion))] SendingPassword { - future: sink::Send>, + future: sink::Send, PostgresCodec>>, config: Config, idx: Option, }, #[state_machine_future(transitions(ReadingSasl))] SendingSasl { - future: sink::Send>, + future: sink::Send, PostgresCodec>>, scram: ScramSha256, config: Config, idx: Option, }, #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] ReadingSasl { - stream: Framed, + stream: Framed, PostgresCodec>, scram: ScramSha256, config: Config, idx: Option, }, #[state_machine_future(transitions(ReadingInfo))] ReadingAuthCompletion { - stream: Framed, + stream: Framed, PostgresCodec>, config: Config, idx: Option, }, #[state_machine_future(transitions(Finished))] ReadingInfo { - stream: Framed, + stream: Framed, PostgresCodec>, process_id: i32, secret_key: i32, parameters: HashMap, @@ -76,7 +76,7 @@ where idx: Option, }, #[state_machine_future(ready)] - Finished((Client, Connection)), + Finished((Client, Connection>)), #[state_machine_future(error)] Failed(Error), } @@ -84,7 +84,7 @@ where impl PollConnectRaw for ConnectRaw where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let (stream, channel_binding) = try_ready!(state.future.poll()); @@ -377,14 +377,9 @@ where impl ConnectRawFuture where S: AsyncRead + AsyncWrite, - T: TlsMode, + T: TlsConnect, { - pub fn new( - stream: S, - tls_mode: T, - config: Config, - idx: Option, - ) -> ConnectRawFuture { - ConnectRaw::start(TlsFuture::new(stream, tls_mode), config, idx) + pub fn new(stream: S, tls: T, config: Config, idx: Option) -> ConnectRawFuture { + ConnectRaw::start(TlsFuture::new(stream, config.0.ssl_mode, tls), config, idx) } } diff --git a/tokio-postgres/src/proto/maybe_tls_stream.rs b/tokio-postgres/src/proto/maybe_tls_stream.rs new file mode 100644 index 0000000..928674f --- /dev/null +++ b/tokio-postgres/src/proto/maybe_tls_stream.rs @@ -0,0 +1,88 @@ +use bytes::{Buf, BufMut}; +use futures::Poll; +use std::io::{self, Read, Write}; +use tokio_io::{AsyncRead, AsyncWrite}; + +pub enum MaybeTlsStream { + Raw(T), + Tls(U), +} + +impl Read for MaybeTlsStream +where + T: Read, + U: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + MaybeTlsStream::Raw(s) => s.read(buf), + MaybeTlsStream::Tls(s) => s.read(buf), + } + } +} + +impl AsyncRead for MaybeTlsStream +where + T: AsyncRead, + U: AsyncRead, +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match self { + MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf), + MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf), + } + } + + fn read_buf(&mut self, buf: &mut B) -> Poll + where + B: BufMut, + { + match self { + MaybeTlsStream::Raw(s) => s.read_buf(buf), + MaybeTlsStream::Tls(s) => s.read_buf(buf), + } + } +} + +impl Write for MaybeTlsStream +where + T: Write, + U: Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + MaybeTlsStream::Raw(s) => s.write(buf), + MaybeTlsStream::Tls(s) => s.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + MaybeTlsStream::Raw(s) => s.flush(), + MaybeTlsStream::Tls(s) => s.flush(), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + T: AsyncWrite, + U: AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self { + MaybeTlsStream::Raw(s) => s.shutdown(), + MaybeTlsStream::Tls(s) => s.shutdown(), + } + } + + fn write_buf(&mut self, buf: &mut B) -> Poll + where + B: Buf, + { + match self { + MaybeTlsStream::Raw(s) => s.write_buf(buf), + MaybeTlsStream::Tls(s) => s.write_buf(buf), + } + } +} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 796ded1..7667901 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -36,6 +36,7 @@ mod copy_in; mod copy_out; mod execute; mod idle; +mod maybe_tls_stream; mod portal; mod prepare; mod query; @@ -64,6 +65,7 @@ pub use crate::proto::connection::Connection; pub use crate::proto::copy_in::CopyInFuture; pub use crate::proto::copy_out::CopyOutStream; pub use crate::proto::execute::ExecuteFuture; +pub use crate::proto::maybe_tls_stream::MaybeTlsStream; pub use crate::proto::portal::Portal; pub use crate::proto::prepare::PrepareFuture; pub use crate::proto::query::QueryStream; diff --git a/tokio-postgres/src/proto/tls.rs b/tokio-postgres/src/proto/tls.rs index 64a2172..e4274a6 100644 --- a/tokio-postgres/src/proto/tls.rs +++ b/tokio-postgres/src/proto/tls.rs @@ -4,54 +4,65 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use tokio_io::io::{self, ReadExact, WriteAll}; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{ChannelBinding, Error, TlsMode}; +use crate::proto::MaybeTlsStream; +use crate::tls::private::ForcePrivateApi; +use crate::{ChannelBinding, Error, SslMode, TlsConnect}; #[derive(StateMachineFuture)] pub enum Tls where - T: TlsMode, + T: TlsConnect, S: AsyncRead + AsyncWrite, { - #[state_machine_future(start, transitions(SendingTls, ConnectingTls))] - Start { stream: S, tls_mode: T }, + #[state_machine_future(start, transitions(SendingTls, Ready))] + Start { stream: S, mode: SslMode, tls: T }, #[state_machine_future(transitions(ReadingTls))] SendingTls { future: WriteAll>, - tls_mode: T, + mode: SslMode, + tls: T, }, - #[state_machine_future(transitions(ConnectingTls))] + #[state_machine_future(transitions(ConnectingTls, Ready))] ReadingTls { future: ReadExact, - tls_mode: T, + mode: SslMode, + tls: T, }, #[state_machine_future(transitions(Ready))] ConnectingTls { future: T::Future }, #[state_machine_future(ready)] - Ready((T::Stream, ChannelBinding)), + Ready((MaybeTlsStream, ChannelBinding)), #[state_machine_future(error)] Failed(Error), } impl PollTls for Tls where - T: TlsMode, + T: TlsConnect, S: AsyncRead + AsyncWrite, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let state = state.take(); - if state.tls_mode.request_tls() { - let mut buf = vec![]; - frontend::ssl_request(&mut buf); + match state.mode { + SslMode::Disable => transition!(Ready(( + MaybeTlsStream::Raw(state.stream), + ChannelBinding::none() + ))), + SslMode::Prefer if !state.tls.can_connect(ForcePrivateApi) => transition!(Ready(( + MaybeTlsStream::Raw(state.stream), + ChannelBinding::none() + ))), + SslMode::Prefer | SslMode::Require => { + let mut buf = vec![]; + frontend::ssl_request(&mut buf); - transition!(SendingTls { - future: io::write_all(state.stream, buf), - tls_mode: state.tls_mode, - }) - } else { - transition!(ConnectingTls { - future: state.tls_mode.handle_tls(false, state.stream), - }) + transition!(SendingTls { + future: io::write_all(state.stream, buf), + mode: state.mode, + tls: state.tls, + }) + } } } @@ -62,7 +73,8 @@ where let state = state.take(); transition!(ReadingTls { future: io::read_exact(stream, [0]), - tls_mode: state.tls_mode, + mode: state.mode, + tls: state.tls, }) } @@ -72,26 +84,32 @@ where let (stream, buf) = try_ready!(state.future.poll().map_err(Error::io)); let state = state.take(); - let use_tls = buf[0] == b'S'; - transition!(ConnectingTls { - future: state.tls_mode.handle_tls(use_tls, stream) - }) + if buf[0] == b'S' { + transition!(ConnectingTls { + future: state.tls.connect(stream), + }) + } else if state.mode == SslMode::Require { + Err(Error::tls("server does not support TLS".into())) + } else { + transition!(Ready((MaybeTlsStream::Raw(stream), ChannelBinding::none()))) + } } fn poll_connecting_tls<'a>( state: &'a mut RentToOwn<'a, ConnectingTls>, ) -> Poll, Error> { - let t = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into()))); - transition!(Ready(t)) + let (stream, channel_binding) = + try_ready!(state.future.poll().map_err(|e| Error::tls(e.into()))); + transition!(Ready((MaybeTlsStream::Tls(stream), channel_binding))) } } impl TlsFuture where - T: TlsMode, + T: TlsConnect, S: AsyncRead + AsyncWrite, { - pub fn new(stream: S, tls_mode: T) -> TlsFuture { - Tls::start(stream, tls_mode) + pub fn new(stream: S, mode: SslMode, tls: T) -> TlsFuture { + Tls::start(stream, mode, tls) } } diff --git a/tokio-postgres/src/tls.rs b/tokio-postgres/src/tls.rs index a5170dd..7eabfab 100644 --- a/tokio-postgres/src/tls.rs +++ b/tokio-postgres/src/tls.rs @@ -1,11 +1,13 @@ -use bytes::{Buf, BufMut}; use futures::future::{self, FutureResult}; -use futures::{try_ready, Async, Future, Poll}; +use futures::{Future, Poll}; use std::error::Error; use std::fmt; use std::io::{self, Read, Write}; use tokio_io::{AsyncRead, AsyncWrite}; -use void::Void; + +pub(crate) mod private { + pub struct ForcePrivateApi; +} pub struct ChannelBinding { pub(crate) tls_server_end_point: Option>, @@ -25,25 +27,6 @@ impl ChannelBinding { } } -#[cfg(feature = "runtime")] -pub trait MakeTlsMode { - type Stream: AsyncRead + AsyncWrite; - type TlsMode: TlsMode; - type Error: Into>; - - fn make_tls_mode(&mut self, domain: &str) -> Result; -} - -pub trait TlsMode { - type Stream: AsyncRead + AsyncWrite; - type Error: Into>; - type Future: Future; - - fn request_tls(&self) -> bool; - - fn handle_tls(self, use_tls: bool, stream: S) -> Self::Future; -} - #[cfg(feature = "runtime")] pub trait MakeTlsConnect { type Stream: AsyncRead + AsyncWrite; @@ -59,271 +42,74 @@ pub trait TlsConnect { type Future: Future; fn connect(self, stream: S) -> Self::Future; -} - -#[derive(Debug, Copy, Clone)] -pub struct NoTls; - -#[cfg(feature = "runtime")] -impl MakeTlsMode for NoTls -where - S: AsyncRead + AsyncWrite, -{ - type Stream = S; - type TlsMode = NoTls; - type Error = Void; - - fn make_tls_mode(&mut self, _: &str) -> Result { - Ok(NoTls) - } -} - -impl TlsMode for NoTls -where - S: AsyncRead + AsyncWrite, -{ - type Stream = S; - type Error = Void; - type Future = FutureResult<(S, ChannelBinding), Void>; - - fn request_tls(&self) -> bool { - false - } - - fn handle_tls(self, use_tls: bool, stream: S) -> FutureResult<(S, ChannelBinding), Void> { - debug_assert!(!use_tls); - future::ok((stream, ChannelBinding::none())) + #[doc(hidden)] + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + true } } #[derive(Debug, Copy, Clone)] -pub struct PreferTls(pub T); +pub struct NoTls; #[cfg(feature = "runtime")] -impl MakeTlsMode for PreferTls -where - T: MakeTlsConnect, - S: AsyncRead + AsyncWrite, -{ - type Stream = MaybeTlsStream; - type TlsMode = PreferTls; - type Error = T::Error; +impl MakeTlsConnect for NoTls where { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; - fn make_tls_mode(&mut self, domain: &str) -> Result, T::Error> { - self.0.make_tls_connect(domain).map(PreferTls) + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) } } -impl TlsMode for PreferTls -where - T: TlsConnect, - S: AsyncRead + AsyncWrite, -{ - type Stream = MaybeTlsStream; - type Error = T::Error; - type Future = PreferTlsFuture; - - fn request_tls(&self) -> bool { - true - } - - fn handle_tls(self, use_tls: bool, stream: S) -> PreferTlsFuture { - let f = if use_tls { - PreferTlsFutureInner::Tls(self.0.connect(stream)) - } else { - PreferTlsFutureInner::Raw(Some(stream)) - }; +impl TlsConnect for NoTls { + type Stream = NoTlsStream; + type Error = NoTlsError; + type Future = FutureResult<(NoTlsStream, ChannelBinding), NoTlsError>; - PreferTlsFuture(f) + fn connect(self, _: S) -> FutureResult<(NoTlsStream, ChannelBinding), NoTlsError> { + future::err(NoTlsError(())) } -} - -enum PreferTlsFutureInner { - Tls(F), - Raw(Option), -} -pub struct PreferTlsFuture(PreferTlsFutureInner); - -impl Future for PreferTlsFuture -where - F: Future, -{ - type Item = (MaybeTlsStream, ChannelBinding); - type Error = F::Error; - - fn poll(&mut self) -> Poll<(MaybeTlsStream, ChannelBinding), F::Error> { - match &mut self.0 { - PreferTlsFutureInner::Tls(f) => { - let (stream, channel_binding) = try_ready!(f.poll()); - Ok(Async::Ready((MaybeTlsStream::Tls(stream), channel_binding))) - } - PreferTlsFutureInner::Raw(s) => Ok(Async::Ready(( - MaybeTlsStream::Raw(s.take().expect("future polled after completion")), - ChannelBinding::none(), - ))), - } + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + false } } -pub enum MaybeTlsStream { - Tls(T), - Raw(U), -} +pub enum NoTlsStream {} -impl Read for MaybeTlsStream -where - T: Read, - U: Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - MaybeTlsStream::Tls(s) => s.read(buf), - MaybeTlsStream::Raw(s) => s.read(buf), - } +impl Read for NoTlsStream { + fn read(&mut self, _: &mut [u8]) -> io::Result { + match *self {} } } -impl AsyncRead for MaybeTlsStream -where - T: AsyncRead, - U: AsyncRead, -{ - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - match self { - MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf), - MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf), - } - } - - fn read_buf(&mut self, buf: &mut B) -> Poll - where - B: BufMut, - { - match self { - MaybeTlsStream::Tls(s) => s.read_buf(buf), - MaybeTlsStream::Raw(s) => s.read_buf(buf), - } - } -} +impl AsyncRead for NoTlsStream {} -impl Write for MaybeTlsStream -where - T: Write, - U: Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - MaybeTlsStream::Tls(s) => s.write(buf), - MaybeTlsStream::Raw(s) => s.write(buf), - } +impl Write for NoTlsStream { + fn write(&mut self, _: &[u8]) -> io::Result { + match *self {} } fn flush(&mut self) -> io::Result<()> { - match self { - MaybeTlsStream::Tls(s) => s.flush(), - MaybeTlsStream::Raw(s) => s.flush(), - } + match *self {} } } -impl AsyncWrite for MaybeTlsStream -where - T: AsyncWrite, - U: AsyncWrite, -{ +impl AsyncWrite for NoTlsStream { fn shutdown(&mut self) -> Poll<(), io::Error> { - match self { - MaybeTlsStream::Tls(s) => s.shutdown(), - MaybeTlsStream::Raw(s) => s.shutdown(), - } - } - - fn write_buf(&mut self, buf: &mut B) -> Poll - where - B: Buf, - { - match self { - MaybeTlsStream::Tls(s) => s.write_buf(buf), - MaybeTlsStream::Raw(s) => s.write_buf(buf), - } - } -} - -#[derive(Debug, Copy, Clone)] -pub struct RequireTls(pub T); - -#[cfg(feature = "runtime")] -impl MakeTlsMode for RequireTls -where - T: MakeTlsConnect, -{ - type Stream = T::Stream; - type TlsMode = RequireTls; - type Error = T::Error; - - fn make_tls_mode(&mut self, domain: &str) -> Result, T::Error> { - self.0.make_tls_connect(domain).map(RequireTls) - } -} - -impl TlsMode for RequireTls -where - T: TlsConnect, -{ - type Stream = T::Stream; - type Error = Box; - type Future = RequireTlsFuture; - - fn request_tls(&self) -> bool { - true - } - - fn handle_tls(self, use_tls: bool, stream: S) -> RequireTlsFuture { - let f = if use_tls { - Ok(self.0.connect(stream)) - } else { - Err(TlsUnsupportedError(()).into()) - }; - - RequireTlsFuture { f: Some(f) } + match *self {} } } #[derive(Debug)] -pub struct TlsUnsupportedError(()); +pub struct NoTlsError(()); -impl fmt::Display for TlsUnsupportedError { +impl fmt::Display for NoTlsError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.write_str("TLS was required but not supported by the server") + fmt.write_str("no TLS implementation configured") } } -impl Error for TlsUnsupportedError {} - -pub struct RequireTlsFuture { - f: Option>>, -} - -impl Future for RequireTlsFuture -where - T: Future, - T::Error: Into>, -{ - type Item = T::Item; - type Error = Box; - - fn poll(&mut self) -> Poll> { - match self.f.take().expect("future polled after completion") { - Ok(mut f) => match f.poll().map_err(Into::into)? { - Async::Ready(r) => Ok(Async::Ready(r)), - Async::NotReady => { - self.f = Some(Ok(f)); - Ok(Async::NotReady) - } - }, - Err(e) => Err(e), - } - } -} +impl Error for NoTlsError {} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index c770c9f..b7001dd 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -12,7 +12,7 @@ use tokio::runtime::current_thread::Runtime; use tokio::timer::Delay; use tokio_postgres::error::SqlState; use tokio_postgres::types::{Kind, Type}; -use tokio_postgres::{AsyncMessage, Client, Connection, NoTls}; +use tokio_postgres::{AsyncMessage, Client, Connection, NoTls, NoTlsStream}; mod parse; #[cfg(feature = "runtime")] @@ -21,7 +21,8 @@ mod types; fn connect( s: &str, -) -> impl Future), Error = tokio_postgres::Error> { +) -> impl Future), Error = tokio_postgres::Error> +{ let builder = s.parse::().unwrap(); TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e))