Skip to content

Commit

Permalink
Add socket keepalive and timeout options
Browse files Browse the repository at this point in the history
  • Loading branch information
benashford committed Feb 12, 2024
1 parent 245cc34 commit c841a5a
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 32 deletions.
6 changes: 3 additions & 3 deletions examples/monitor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2022 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -21,12 +21,12 @@ async fn main() {
.unwrap_or_else(|| "127.0.0.1".to_string());

#[cfg(not(feature = "tls"))]
let mut connection = client::connect(&addr, 6379)
let mut connection = client::connect(&addr, 6379, None, None)
.await
.expect("Cannot connect to Redis");

#[cfg(feature = "tls")]
let mut connection = client::connect_tls(&addr, 6379)
let mut connection = client::connect_tls(&addr, 6379, None, None)
.await
.expect("Cannot connect to Redis");

Expand Down
20 changes: 18 additions & 2 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2021 Ben Ashford
* Copyright 2020-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -8,7 +8,7 @@
* except according to those terms.
*/

use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use crate::error;

Expand All @@ -21,6 +21,8 @@ pub struct ConnectionBuilder {
pub(crate) password: Option<Arc<str>>,
#[cfg(feature = "tls")]
pub(crate) tls: bool,
pub(crate) socket_keepalive: Option<Duration>,
pub(crate) socket_timeout: Option<Duration>,
}

impl ConnectionBuilder {
Expand All @@ -32,6 +34,8 @@ impl ConnectionBuilder {
password: None,
#[cfg(feature = "tls")]
tls: false,
socket_keepalive: None,
socket_timeout: None,
})
}

Expand All @@ -52,4 +56,16 @@ impl ConnectionBuilder {
self.tls = true;
self
}

/// Set the socket keepalive duration
pub fn socket_keepalive(&mut self, duration: Duration) -> &mut Self {
self.socket_keepalive = Some(duration);
self
}

/// Set the socket timeout duration
pub fn socket_timeout(&mut self, duration: Duration) -> &mut Self {
self.socket_timeout = Some(duration);
self
}
}
56 changes: 38 additions & 18 deletions src/client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ impl AsyncRead for RespConnectionInner {

pub type RespConnection = Framed<RespConnectionInner, RespCodec>;

const DEFAULT_KEEPALIVE_DURATION: Duration = Duration::from_secs(5);

/// Connect to a Redis server and return a Future that resolves to a
/// `RespConnection` for reading and writing asynchronously.
///
Expand All @@ -114,22 +112,32 @@ const DEFAULT_KEEPALIVE_DURATION: Duration = Duration::from_secs(5);
///
/// But since most Redis usages involve issue commands that result in one
/// single result, this library also implements `paired_connect`.
pub async fn connect(host: &str, port: u16) -> Result<RespConnection, error::Error> {
pub async fn connect(
host: &str,
port: u16,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
let tcp_stream = TcpStream::connect((host, port)).await?;
apply_keepalive(&tcp_stream, DEFAULT_KEEPALIVE_DURATION)?;
apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
}

#[cfg(feature = "with-rustls")]
pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error::Error> {
pub async fn connect_tls(
host: &str,
port: u16,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
use std::sync::Arc;
use tokio_rustls::{
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore},
TlsConnector,
};

let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
Expand All @@ -149,7 +157,7 @@ pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error:
error::ConnectionReason::ConnectionFailed,
))?;
let tcp_stream = TcpStream::connect(addr).await?;
apply_keepalive(&tcp_stream, DEFAULT_KEEPALIVE_DURATION)?;
apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

let stream = connector
.connect(
Expand Down Expand Up @@ -185,15 +193,17 @@ pub async fn connect_with_auth(
username: Option<&str>,
password: Option<&str>,
#[allow(unused_variables)] tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
#[cfg(feature = "tls")]
let mut connection = if tls {
connect_tls(host, port).await?
connect_tls(host, port, socket_keepalive, socket_timeout).await?
} else {
connect(host, port).await?
connect(host, port, socket_keepalive, socket_timeout).await?
};
#[cfg(not(feature = "tls"))]
let mut connection = connect(host, port).await?;
let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;

if let Some(password) = password {
let mut auth = resp_array!["AUTH"];
Expand Down Expand Up @@ -223,15 +233,25 @@ pub async fn connect_with_auth(
}

/// Apply a custom keep-alive value to the connection
fn apply_keepalive(stream: &TcpStream, interval: Duration) -> Result<(), error::Error> {
fn apply_keepalive_and_timeouts(
stream: &TcpStream,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<(), error::Error> {
let sock_ref = socket2::SockRef::from(stream);

let keep_alive = socket2::TcpKeepalive::new()
.with_time(interval)
.with_interval(interval)
.with_retries(1);
if let Some(interval) = socket_keepalive {
let keep_alive = socket2::TcpKeepalive::new()
.with_time(interval)
.with_interval(interval)
.with_retries(1);
sock_ref.set_tcp_keepalive(&keep_alive)?;
}

sock_ref.set_tcp_keepalive(&keep_alive)?;
if let Some(timeout) = socket_timeout {
sock_ref.set_read_timeout(Some(timeout))?;
sock_ref.set_write_timeout(Some(timeout))?;
}

Ok(())
}
Expand All @@ -247,7 +267,7 @@ mod test {

#[tokio::test]
async fn can_connect() {
let mut connection = super::connect("127.0.0.1", 6379)
let mut connection = super::connect("127.0.0.1", 6379, None, None)
.await
.expect("Cannot connect");
connection
Expand All @@ -266,7 +286,7 @@ mod test {

#[tokio::test]
async fn complex_test() {
let mut connection = super::connect("127.0.0.1", 6379)
let mut connection = super::connect("127.0.0.1", 6379, None, None)
.await
.expect("Cannot connect");
let mut ops = Vec::new();
Expand Down
29 changes: 26 additions & 3 deletions src/client/paired.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2021 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -15,6 +15,7 @@ use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_channel::{mpsc, oneshot};
use futures_sink::Sink;
Expand Down Expand Up @@ -210,10 +211,21 @@ async fn inner_conn_fn(
username: Option<Arc<str>>,
password: Option<Arc<str>>,
tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<mpsc::UnboundedSender<SendPayload>, error::Error> {
let username = username.as_ref().map(|u| u.as_ref());
let password = password.as_ref().map(|p| p.as_ref());
let connection = connect_with_auth(&host, port, username, password, tls).await?;
let connection = connect_with_auth(
&host,
port,
username,
password,
tls,
socket_keepalive,
socket_timeout,
)
.await?;
let (out_tx, out_rx) = mpsc::unbounded();
let paired_connection_inner = PairedConnectionInner::new(connection, out_rx);
tokio::spawn(paired_connection_inner);
Expand All @@ -236,8 +248,19 @@ impl ConnectionBuilder {
#[cfg(not(feature = "tls"))]
let tls = false;

let socket_keepalive = self.socket_keepalive;
let socket_timeout = self.socket_timeout;

let conn_fn = move || {
let con_f = inner_conn_fn(host.clone(), port, username.clone(), password.clone(), tls);
let con_f = inner_conn_fn(
host.clone(),
port,
username.clone(),
password.clone(),
tls,
socket_keepalive,
socket_timeout,
);
Box::pin(con_f) as Pin<Box<dyn Future<Output = Result<_, error::Error>> + Send + Sync>>
};

Expand Down
30 changes: 26 additions & 4 deletions src/client/pubsub/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2023 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -14,6 +14,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_channel::{mpsc, oneshot};
use futures_util::{
Expand Down Expand Up @@ -59,11 +60,22 @@ async fn inner_conn_fn(
username: Option<Arc<str>>,
password: Option<Arc<str>>,
tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
let username = username.as_deref();
let password = password.as_deref();

let connection = connect_with_auth(&host, port, username, password, tls).await?;
let connection = connect_with_auth(
&host,
port,
username,
password,
tls,
socket_keepalive,
socket_timeout,
)
.await?;
let (out_tx, out_rx) = mpsc::unbounded();
tokio::spawn(async {
match PubsubConnectionInner::new(connection, out_rx).await {
Expand All @@ -87,13 +99,23 @@ impl ConnectionBuilder {
let host = self.host.clone();
let port = self.port;

let socket_keepalive = self.socket_keepalive;
let socket_timeout = self.socket_timeout;

let reconnecting_f = reconnect(
|con: &mpsc::UnboundedSender<PubsubEvent>, act| {
con.unbounded_send(act).map_err(|e| e.into())
},
move || {
let con_f =
inner_conn_fn(host.clone(), port, username.clone(), password.clone(), tls);
let con_f = inner_conn_fn(
host.clone(),
port,
username.clone(),
password.clone(),
tls,
socket_keepalive,
socket_timeout,
);
Box::pin(con_f)
},
);
Expand Down
4 changes: 2 additions & 2 deletions src/resp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2023 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand Down Expand Up @@ -767,7 +767,7 @@ mod tests {

let vals = vec![String::from("a"), String::from("b")];
#[allow(clippy::needless_borrow)]
let resp_object = resp_array!["RPUSH", "xyz"].append(&vals);
let resp_object = resp_array!["RPUSH", "xyz"].append(vals);
let bytes = obj_to_bytes(resp_object);
assert_eq!(
&b"*4\r\n$5\r\nRPUSH\r\n$3\r\nxyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..],
Expand Down

0 comments on commit c841a5a

Please sign in to comment.