Skip to content

Commit

Permalink
Move the TLS mode into config
Browse files Browse the repository at this point in the history
  • Loading branch information
sfackler committed Jan 13, 2019
1 parent dfc614b commit 2d3b9bb
Show file tree
Hide file tree
Showing 18 changed files with 356 additions and 424 deletions.
8 changes: 4 additions & 4 deletions postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::{self, Read};
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::Error;
#[cfg(feature = "runtime")]
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};

#[cfg(feature = "runtime")]
use crate::Config;
Expand All @@ -15,10 +15,10 @@ impl Client {
#[cfg(feature = "runtime")]
pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error>
where
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T: MakeTlsConnect<Socket> + 'static + Send,
T::TlsConnect: Send,
T::Stream: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
params.parse::<Config>()?.connect(tls_mode)
}
Expand Down
8 changes: 4 additions & 4 deletions postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use log::error;
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;
use tokio_postgres::{Error, MakeTlsMode, Socket, TargetSessionAttrs, TlsMode};
use tokio_postgres::{Error, MakeTlsConnect, Socket, TargetSessionAttrs, TlsConnect};

use crate::{Client, RUNTIME};

Expand Down Expand Up @@ -94,10 +94,10 @@ impl Config {

pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
where
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T: MakeTlsConnect<Socket> + 'static + Send,
T::TlsConnect: Send,
T::Stream: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
let connect = self.0.connect(tls_mode);
let (client, connection) = oneshot::spawn(connect, &RUNTIME.executor()).wait()?;
Expand Down
14 changes: 7 additions & 7 deletions tokio-postgres-native-tls/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use futures::{Future, Stream};
use native_tls::{self, Certificate};
use tokio::net::TcpStream;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
use tokio_postgres::TlsConnect;

use crate::TlsConnector;

fn smoke_test<T>(s: &str, tls: T)
where
T: TlsMode<TcpStream>,
T: TlsConnect<TcpStream>,
T::Stream: 'static,
{
let mut runtime = Runtime::new().unwrap();
Expand Down Expand Up @@ -44,8 +44,8 @@ fn require() {
.build()
.unwrap();
smoke_test(
"user=ssl_user dbname=postgres",
RequireTls(TlsConnector::with_connector(connector, "localhost")),
"user=ssl_user dbname=postgres sslmode=require",
TlsConnector::with_connector(connector, "localhost"),
);
}

Expand All @@ -59,7 +59,7 @@ fn prefer() {
.unwrap();
smoke_test(
"user=ssl_user dbname=postgres",
PreferTls(TlsConnector::with_connector(connector, "localhost")),
TlsConnector::with_connector(connector, "localhost"),
);
}

Expand All @@ -72,7 +72,7 @@ fn scram_user() {
.build()
.unwrap();
smoke_test(
"user=scram_user password=password dbname=postgres",
RequireTls(TlsConnector::with_connector(connector, "localhost")),
"user=scram_user password=password dbname=postgres sslmode=require",
TlsConnector::with_connector(connector, "localhost"),
);
}
18 changes: 9 additions & 9 deletions tokio-postgres-openssl/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use futures::{Future, Stream};
use openssl::ssl::{SslConnector, SslMethod};
use tokio::net::TcpStream;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
use tokio_postgres::TlsConnect;

use super::*;

fn smoke_test<T>(s: &str, tls: T)
where
T: TlsMode<TcpStream>,
T: TlsConnect<TcpStream>,
T::Stream: 'static,
{
let mut runtime = Runtime::new().unwrap();
Expand Down Expand Up @@ -41,8 +41,8 @@ fn require() {
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres",
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
"user=ssl_user dbname=postgres sslmode=require",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}

Expand All @@ -53,7 +53,7 @@ fn prefer() {
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres",
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}

Expand All @@ -63,8 +63,8 @@ fn scram_user() {
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
smoke_test(
"user=scram_user password=password dbname=postgres",
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
"user=scram_user password=password dbname=postgres sslmode=require",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}

Expand All @@ -78,8 +78,8 @@ fn runtime() {
let connector = MakeTlsConnector::new(builder.build());

let connect = tokio_postgres::connect(
"host=localhost port=5433 user=postgres",
RequireTls(connector),
"host=localhost port=5433 user=postgres sslmode=require",
connector,
);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
Expand Down
1 change: 0 additions & 1 deletion tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
state_machine_future = "0.1.7"
tokio-codec = "0.1"
tokio-io = "0.1"
void = "1.0"

tokio-tcp = { version = "0.1", optional = true }
futures-cpupool = { version = "0.1", optional = true }
Expand Down
48 changes: 40 additions & 8 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::ConnectFuture;
use crate::proto::ConnectRawFuture;
#[cfg(feature = "runtime")]
use crate::{Connect, MakeTlsMode, Socket};
use crate::{ConnectRaw, Error, TlsMode};
use crate::{Connect, MakeTlsConnect, Socket};
use crate::{ConnectRaw, Error, TlsConnect};

/// Properties required of a session.
#[cfg(feature = "runtime")]
Expand All @@ -34,6 +34,17 @@ pub enum TargetSessionAttrs {
__NonExhaustive,
}

/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum SslMode {
/// Do not use TLS.
Disable,
/// Attempt to connect with TLS but allow sessions without.
Prefer,
/// Require the use of TLS.
Require,
}

#[cfg(feature = "runtime")]
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Host {
Expand All @@ -49,6 +60,7 @@ pub(crate) struct Inner {
pub(crate) dbname: Option<String>,
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
#[cfg(feature = "runtime")]
pub(crate) host: Vec<Host>,
#[cfg(feature = "runtime")]
Expand Down Expand Up @@ -79,6 +91,8 @@ pub(crate) struct Inner {
/// * `dbname` - The name of the database to connect to. Defaults to the username.
/// * `options` - Command line options used to configure the server.
/// * `application_name` - Sets the `application_name` parameter on the server.
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
Expand Down Expand Up @@ -152,6 +166,7 @@ impl Config {
dbname: None,
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
#[cfg(feature = "runtime")]
host: vec![],
#[cfg(feature = "runtime")]
Expand Down Expand Up @@ -204,6 +219,14 @@ impl Config {
self
}

/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
Arc::make_mut(&mut self.0).ssl_mode = ssl_mode;
self
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
Expand Down Expand Up @@ -320,6 +343,15 @@ impl Config {
"application_name" => {
self.application_name(&value);
}
"sslmode" => {
let mode = match value {
"disable" => SslMode::Disable,
"prefer" => SslMode::Prefer,
"require" => SslMode::Require,
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
};
self.ssl_mode(mode);
}
#[cfg(feature = "runtime")]
"host" => {
for host in value.split(',') {
Expand Down Expand Up @@ -390,22 +422,22 @@ impl Config {
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
pub fn connect<T>(&self, tls: T) -> Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
Connect(ConnectFuture::new(tls, Ok(self.clone())))
}

/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.
pub fn connect_raw<S, T>(&self, stream: S, tls_mode: T) -> ConnectRaw<S, T>
pub fn connect_raw<S, T>(&self, stream: S, tls: T) -> ConnectRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
ConnectRaw(ConnectRawFuture::new(stream, tls_mode, self.clone(), None))
ConnectRaw(ConnectRawFuture::new(stream, tls, self.clone(), None))
}
}

Expand Down
Loading

0 comments on commit 2d3b9bb

Please sign in to comment.