diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs index 8562584..2deafa9 100644 --- a/src/client/legacy/client.rs +++ b/src/client/legacy/client.rs @@ -57,7 +57,7 @@ pub struct Error { kind: ErrorKind, source: Option>, #[cfg(any(feature = "http1", feature = "http2"))] - connect_info: Option, + connect_info: Option, } #[derive(Debug)] @@ -71,6 +71,34 @@ enum ErrorKind { SendRequest, } +/// Extra information about a failed connection. +pub struct ErroredConnectInfo { + conn_info: Connected, + is_reused: bool, +} + +impl ErroredConnectInfo { + /// Determines if the connected transport is to an HTTP proxy. + pub fn is_proxied(&self) -> bool { + self.conn_info.is_proxied() + } + + /// Copies the extra connection information into an `Extensions` map. + pub fn get_extras(&self, extensions: &mut http::Extensions) { + self.conn_info.get_extras(extensions); + } + + /// Determines if the connected transport negotiated HTTP/2 as its next protocol. + pub fn is_negotiated_h2(&self) -> bool { + self.conn_info.is_negotiated_h2() + } + + /// Determines if the connection is a reused one from the connection pool. + pub fn is_reused(&self) -> bool { + self.is_reused + } +} + macro_rules! e { ($kind:ident) => { Error { @@ -282,7 +310,7 @@ where if req.version() == Version::HTTP_2 { warn!("Connection is HTTP/1, but request requires HTTP/2"); return Err(TrySendError::Nope( - e!(UserUnsupportedVersion).with_connect_info(pooled.conn_info.clone()), + e!(UserUnsupportedVersion).with_connect_info(&pooled), )); } @@ -317,14 +345,12 @@ where Err(mut err) => { return if let Some(req) = err.take_message() { Err(TrySendError::Retryable { - error: e!(Canceled, err.into_error()) - .with_connect_info(pooled.conn_info.clone()), + error: e!(Canceled, err.into_error()).with_connect_info(&pooled), req, }) } else { Err(TrySendError::Nope( - e!(SendRequest, err.into_error()) - .with_connect_info(pooled.conn_info.clone()), + e!(SendRequest, err.into_error()).with_connect_info(&pooled), )) } } @@ -1619,14 +1645,20 @@ impl Error { /// Returns the info of the client connection on which this error occurred. #[cfg(any(feature = "http1", feature = "http2"))] - pub fn connect_info(&self) -> Option<&Connected> { + pub fn connect_info(&self) -> Option<&ErroredConnectInfo> { self.connect_info.as_ref() } #[cfg(any(feature = "http1", feature = "http2"))] - fn with_connect_info(self, connect_info: Connected) -> Self { + fn with_connect_info(self, pooled: &pool::Pooled, PoolKey>) -> Self + where + B: Send + 'static, + { Self { - connect_info: Some(connect_info), + connect_info: Some(ErroredConnectInfo { + conn_info: pooled.conn_info.clone(), + is_reused: pooled.is_reused(), + }), ..self } } diff --git a/tests/legacy_client.rs b/tests/legacy_client.rs index 0f11d77..9135bc6 100644 --- a/tests/legacy_client.rs +++ b/tests/legacy_client.rs @@ -20,7 +20,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use hyper::body::Bytes; use hyper::body::Frame; use hyper::Request; -use hyper_util::client::legacy::connect::{capture_connection, HttpConnector}; +use hyper_util::client::legacy::connect::{capture_connection, HttpConnector, HttpInfo}; use hyper_util::client::legacy::Client; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -978,3 +978,91 @@ fn connection_poisoning() { assert_eq!(num_conns.load(Ordering::SeqCst), 2); assert_eq!(num_requests.load(Ordering::SeqCst), 5); } + +#[cfg(not(miri))] +#[tokio::test] +async fn connect_info_on_error() { + let client = Client::builder(TokioExecutor::new()).build(HttpConnector::new()); + + // srv1 accepts one connection, and cancel it after reading the second request. + let tcp1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr1 = tcp1.local_addr().unwrap(); + let srv1 = tokio::spawn(async move { + let (mut sock, _addr) = tcp1.accept().await.unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).await.expect("read 1"); + let body = Bytes::from("Hello, world!"); + sock.write_all( + format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).as_bytes(), + ) + .await + .expect("write header"); + sock.write_all(&body).await.expect("write body"); + + sock.read(&mut buf).await.expect("read 2"); + drop(sock); + }); + + // Makes a first request to srv1, which should succeed. + { + let req = Request::builder() + .uri(format!("http://{addr1}")) + .body(Empty::::new()) + .unwrap(); + let res = client.request(req).await.unwrap(); + let http_info = res.extensions().get::().unwrap(); + assert_eq!(http_info.remote_addr(), addr1); + let res_body = String::from_utf8(res.collect().await.unwrap().to_bytes().into()).unwrap(); + assert_eq!(res_body, "Hello, world!"); + } + + // Makes a second request to srv1, which should use the same connection and fail. + { + let req = Request::builder() + .uri(format!("http://{addr1}")) + .body(Empty::::new()) + .unwrap(); + let err = client.request(req).await.unwrap_err(); + let conn_info = err.connect_info().unwrap(); + assert!(!conn_info.is_proxied()); + assert!(!conn_info.is_negotiated_h2()); + assert!(conn_info.is_reused()); + + let mut exts = http::Extensions::new(); + conn_info.get_extras(&mut exts); + let http_info = exts.get::().unwrap(); + assert_eq!(http_info.remote_addr(), addr1); + } + + srv1.await.unwrap(); + + // srv2 accepts one connection, reads a request, and immediately closes it. + let tcp2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr2 = tcp2.local_addr().unwrap(); + let srv2 = tokio::spawn(async move { + let (mut sock, _addr) = tcp2.accept().await.unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).await.expect("read"); + drop(sock); + }); + + // Makes a first request to srv2, which should use a fresh connection and fail. + { + let req = Request::builder() + .uri(format!("http://{addr2}")) + .body(Empty::::new()) + .unwrap(); + let err = client.request(req).await.unwrap_err(); + let conn_info = err.connect_info().unwrap(); + assert!(!conn_info.is_proxied()); + assert!(!conn_info.is_negotiated_h2()); + assert!(!conn_info.is_reused()); + + let mut exts = http::Extensions::new(); + conn_info.get_extras(&mut exts); + let http_info = exts.get::().unwrap(); + assert_eq!(http_info.remote_addr(), addr2); + } + + srv2.await.unwrap(); +}