From 9c27f92203bc663dbede748ed96220f0d057e6c5 Mon Sep 17 00:00:00 2001 From: Zanie Blue Date: Fri, 15 Mar 2024 12:07:38 -0500 Subject: [PATCH] Introduce a `BaseClient` for construction of canonical configured client (#2431) In preparation for support of https://github.com/astral-sh/uv/issues/2357 (see https://github.com/astral-sh/uv/pull/2434) --- crates/uv-client/src/base_client.rs | 188 ++++++++++++++++++ crates/uv-client/src/cached_client.rs | 10 +- crates/uv-client/src/flat_index.rs | 11 +- crates/uv-client/src/lib.rs | 2 + crates/uv-client/src/registry_client.rs | 126 +++--------- crates/uv-client/tests/netrc_auth.rs | 3 +- crates/uv-client/tests/user_agent_version.rs | 3 +- .../src/distribution_database.rs | 6 +- crates/uv-distribution/src/source/mod.rs | 6 +- 9 files changed, 240 insertions(+), 115 deletions(-) create mode 100644 crates/uv-client/src/base_client.rs diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs new file mode 100644 index 000000000000..b0fab037f42e --- /dev/null +++ b/crates/uv-client/src/base_client.rs @@ -0,0 +1,188 @@ +use reqwest::{Client, ClientBuilder}; +use reqwest_middleware::ClientWithMiddleware; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; +use std::env; +use std::fmt::Debug; +use std::ops::Deref; +use std::path::Path; +use tracing::debug; +use uv_auth::{AuthMiddleware, KeyringProvider}; +use uv_fs::Simplified; +use uv_version::version; +use uv_warnings::warn_user_once; + +use crate::middleware::OfflineMiddleware; +use crate::tls::Roots; +use crate::{tls, Connectivity}; + +/// A builder for an [`RegistryClient`]. +#[derive(Debug, Clone)] +pub struct BaseClientBuilder { + keyring_provider: KeyringProvider, + native_tls: bool, + retries: u32, + connectivity: Connectivity, + client: Option, +} + +impl BaseClientBuilder { + pub fn new() -> Self { + Self { + keyring_provider: KeyringProvider::default(), + native_tls: false, + connectivity: Connectivity::Online, + retries: 3, + client: None, + } + } +} + +impl BaseClientBuilder { + #[must_use] + pub fn keyring_provider(mut self, keyring_provider: KeyringProvider) -> Self { + self.keyring_provider = keyring_provider; + self + } + + #[must_use] + pub fn connectivity(mut self, connectivity: Connectivity) -> Self { + self.connectivity = connectivity; + self + } + + #[must_use] + pub fn retries(mut self, retries: u32) -> Self { + self.retries = retries; + self + } + + #[must_use] + pub fn native_tls(mut self, native_tls: bool) -> Self { + self.native_tls = native_tls; + self + } + + #[must_use] + pub fn client(mut self, client: Client) -> Self { + self.client = Some(client); + self + } + + pub fn build(self) -> BaseClient { + // Create user agent. + let user_agent_string = format!("uv/{}", version()); + + // Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout + // `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6 + let default_timeout = 5 * 60; + let timeout = env::var("UV_HTTP_TIMEOUT") + .or_else(|_| env::var("UV_REQUEST_TIMEOUT")) + .or_else(|_| env::var("HTTP_TIMEOUT")) + .and_then(|value| { + value.parse::() + .or_else(|_| { + // On parse error, warn and use the default timeout + warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\"."); + Ok(default_timeout) + }) + }) + .unwrap_or(default_timeout); + debug!("Using registry request timeout of {}s", timeout); + + // Initialize the base client. + let client = self.client.unwrap_or_else(|| { + // Check for the presence of an `SSL_CERT_FILE`. + let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| { + let path_exists = Path::new(&path).exists(); + if !path_exists { + warn_user_once!( + "Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.", + path.simplified_display() + ); + } + path_exists + }); + // Load the TLS configuration. + let tls = tls::load(if self.native_tls || ssl_cert_file_exists { + Roots::Native + } else { + Roots::Webpki + }) + .expect("Failed to load TLS configuration."); + + let client_core = ClientBuilder::new() + .user_agent(user_agent_string) + .pool_max_idle_per_host(20) + .timeout(std::time::Duration::from_secs(timeout)) + .use_preconfigured_tls(tls); + + client_core.build().expect("Failed to build HTTP client.") + }); + + // Wrap in any relevant middleware. + let client = match self.connectivity { + Connectivity::Online => { + let client = reqwest_middleware::ClientBuilder::new(client.clone()); + + // Initialize the retry strategy. + let retry_policy = + ExponentialBackoff::builder().build_with_max_retries(self.retries); + let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy); + let client = client.with(retry_strategy); + + // Initialize the authentication middleware to set headers. + let client = client.with(AuthMiddleware::new(self.keyring_provider)); + + client.build() + } + Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone()) + .with(OfflineMiddleware) + .build(), + }; + + BaseClient { + connectivity: self.connectivity, + client, + timeout, + } + } +} + +/// A base client for HTTP requests +#[derive(Debug, Clone)] +pub struct BaseClient { + /// The underlying HTTP client. + client: ClientWithMiddleware, + /// The connectivity mode to use. + connectivity: Connectivity, + /// Configured client timeout, in seconds. + timeout: u64, +} + +impl BaseClient { + /// The underyling [`ClientWithMiddleware`]. + pub fn client(&self) -> ClientWithMiddleware { + self.client.clone() + } + + /// The configured client timeout, in seconds. + pub fn timeout(&self) -> u64 { + self.timeout + } + + /// The configured connectivity mode. + pub fn connectivity(&self) -> Connectivity { + self.connectivity + } +} + +// To avoid excessively verbose call chains, as the [`BaseClient`] is often nested within other client types. +impl Deref for BaseClient { + type Target = ClientWithMiddleware; + + /// Deference to the underlying [`ClientWithMiddleware`]. + fn deref(&self) -> &Self::Target { + &self.client + } +} diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs index cbcebf12c0ee..a80726e8bb95 100644 --- a/crates/uv-client/src/cached_client.rs +++ b/crates/uv-client/src/cached_client.rs @@ -2,7 +2,6 @@ use std::{borrow::Cow, future::Future, path::Path}; use futures::FutureExt; use reqwest::{Request, Response}; -use reqwest_middleware::ClientWithMiddleware; use rkyv::util::AlignedVec; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -11,6 +10,7 @@ use tracing::{debug, info_span, instrument, trace, warn, Instrument}; use uv_cache::{CacheEntry, Freshness}; use uv_fs::write_atomic; +use crate::BaseClient; use crate::{ httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder}, rkyvutil::OwnedArchive, @@ -158,15 +158,15 @@ impl From for CacheControl { /// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption /// that it's a file. #[derive(Debug, Clone)] -pub struct CachedClient(ClientWithMiddleware); +pub struct CachedClient(BaseClient); impl CachedClient { - pub fn new(client: ClientWithMiddleware) -> Self { + pub fn new(client: BaseClient) -> Self { Self(client) } - /// The middleware is the retry strategy - pub fn uncached(&self) -> ClientWithMiddleware { + /// The base client + pub fn uncached(&self) -> BaseClient { self.0.clone() } diff --git a/crates/uv-client/src/flat_index.rs b/crates/uv-client/src/flat_index.rs index e827542fb9e0..a042f4f10bd4 100644 --- a/crates/uv-client/src/flat_index.rs +++ b/crates/uv-client/src/flat_index.rs @@ -143,10 +143,9 @@ impl<'a> FlatIndexClient<'a> { Connectivity::Offline => CacheControl::AllowStale, }; - let cached_client = self.client.cached_client(); - - let flat_index_request = cached_client - .uncached() + let flat_index_request = self + .client + .uncached_client() .get(url.clone()) .header("Accept-Encoding", "gzip") .header("Accept", "text/html") @@ -180,7 +179,9 @@ impl<'a> FlatIndexClient<'a> { .boxed() .instrument(info_span!("parse_flat_index_html", url = % url)) }; - let response = cached_client + let response = self + .client + .cached_client() .get_serde( flat_index_request, &cache_entry, diff --git a/crates/uv-client/src/lib.rs b/crates/uv-client/src/lib.rs index bf0a083f6d61..c74b92ac57db 100644 --- a/crates/uv-client/src/lib.rs +++ b/crates/uv-client/src/lib.rs @@ -1,3 +1,4 @@ +pub use base_client::BaseClient; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use error::{BetterReqwestError, Error, ErrorKind}; pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError}; @@ -7,6 +8,7 @@ pub use registry_client::{ }; pub use rkyvutil::OwnedArchive; +mod base_client; mod cached_client; mod error; mod flat_index; diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 6930cc93fd82..23e500bb381c 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::env; use std::fmt::Debug; use std::path::Path; use std::str::FromStr; @@ -7,13 +6,11 @@ use std::str::FromStr; use async_http_range_reader::AsyncHttpRangeReader; use futures::{FutureExt, TryStreamExt}; use http::HeaderMap; -use reqwest::{Client, ClientBuilder, Response, StatusCode}; -use reqwest_retry::policies::ExponentialBackoff; -use reqwest_retry::RetryTransientMiddleware; +use reqwest::{Client, Response, StatusCode}; use serde::{Deserialize, Serialize}; use tokio::io::AsyncReadExt; use tokio_util::compat::FuturesAsyncReadCompatExt; -use tracing::{debug, info_span, instrument, trace, warn, Instrument}; +use tracing::{info_span, instrument, trace, warn, Instrument}; use url::Url; use distribution_filename::{DistFilename, SourceDistFilename, WheelFilename}; @@ -21,20 +18,16 @@ use distribution_types::{BuiltDist, File, FileLocation, IndexUrl, IndexUrls, Nam use install_wheel_rs::metadata::{find_archive_dist_info, is_metadata_entry}; use pep440_rs::Version; use pypi_types::{Metadata23, SimpleJson}; -use uv_auth::{AuthMiddleware, KeyringProvider}; +use uv_auth::KeyringProvider; use uv_cache::{Cache, CacheBucket, WheelCache}; -use uv_fs::Simplified; use uv_normalize::PackageName; -use uv_version::version; -use uv_warnings::warn_user_once; +use crate::base_client::{BaseClient, BaseClientBuilder}; use crate::cached_client::CacheControl; use crate::html::SimpleHtml; -use crate::middleware::OfflineMiddleware; use crate::remote_metadata::wheel_metadata_from_remote_zip; use crate::rkyvutil::OwnedArchive; -use crate::tls::Roots; -use crate::{tls, CachedClient, CachedClientError, Error, ErrorKind}; +use crate::{CachedClient, CachedClientError, Error, ErrorKind}; /// A builder for an [`RegistryClient`]. #[derive(Debug, Clone)] @@ -106,76 +99,22 @@ impl RegistryClientBuilder { } pub fn build(self) -> RegistryClient { - // Create user agent. - let user_agent_string = format!("uv/{}", version()); - - // Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout - // `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6 - let default_timeout = 5 * 60; - let timeout = env::var("UV_HTTP_TIMEOUT") - .or_else(|_| env::var("UV_REQUEST_TIMEOUT")) - .or_else(|_| env::var("HTTP_TIMEOUT")) - .and_then(|value| { - value.parse::() - .or_else(|_| { - // On parse error, warn and use the default timeout - warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\"."); - Ok(default_timeout) - }) - }) - .unwrap_or(default_timeout); - debug!("Using registry request timeout of {}s", timeout); - - // Initialize the base client. - let client = self.client.unwrap_or_else(|| { - // Check for the presence of an `SSL_CERT_FILE`. - let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| { - let path_exists = Path::new(&path).exists(); - if !path_exists { - warn_user_once!( - "Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.", - path.simplified_display() - ); - } - path_exists - }); - // Load the TLS configuration. - let tls = tls::load(if self.native_tls || ssl_cert_file_exists { - Roots::Native - } else { - Roots::Webpki - }) - .expect("Failed to load TLS configuration."); - - let client_core = ClientBuilder::new() - .user_agent(user_agent_string) - .pool_max_idle_per_host(20) - .timeout(std::time::Duration::from_secs(timeout)) - .use_preconfigured_tls(tls); - - client_core.build().expect("Failed to build HTTP client.") - }); - - // Wrap in any relevant middleware. - let client = match self.connectivity { - Connectivity::Online => { - let client = reqwest_middleware::ClientBuilder::new(client.clone()); + // Build a base client + let mut builder = BaseClientBuilder::new(); - // Initialize the retry strategy. - let retry_policy = - ExponentialBackoff::builder().build_with_max_retries(self.retries); - let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy); - let client = client.with(retry_strategy); + if let Some(client) = self.client { + builder = builder.client(client) + } - // Initialize the authentication middleware to set headers. - let client = client.with(AuthMiddleware::new(self.keyring_provider)); + let client = builder + .retries(self.retries) + .connectivity(self.connectivity) + .native_tls(self.native_tls) + .keyring_provider(self.keyring_provider) + .build(); - client.build() - } - Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone()) - .with(OfflineMiddleware) - .build(), - }; + let timeout = client.timeout(); + let connectivity = client.connectivity(); // Wrap in the cache middleware. let client = CachedClient::new(client); @@ -183,7 +122,7 @@ impl RegistryClientBuilder { RegistryClient { index_urls: self.index_urls, cache: self.cache, - connectivity: self.connectivity, + connectivity, client, timeout, } @@ -211,6 +150,11 @@ impl RegistryClient { &self.client } + /// Return the [`BaseClient`] used by this client. + pub fn uncached_client(&self) -> BaseClient { + self.client.uncached() + } + /// Return the [`Connectivity`] mode used by this client. pub fn connectivity(&self) -> Connectivity { self.connectivity @@ -306,8 +250,7 @@ impl RegistryClient { }; let simple_request = self - .client - .uncached() + .uncached_client() .get(url.clone()) .header("Accept-Encoding", "gzip") .header("Accept", MediaType::accepts()) @@ -356,7 +299,7 @@ impl RegistryClient { .instrument(info_span!("parse_simple_api", package = %package_name)) }; let result = self - .client + .cached_client() .get_cacheable( simple_request, &cache_entry, @@ -469,13 +412,12 @@ impl RegistryClient { }) }; let req = self - .client - .uncached() + .uncached_client() .get(url.clone()) .build() .map_err(ErrorKind::from)?; Ok(self - .client + .cached_client() .get_serde(req, &cache_entry, cache_control, response_callback) .await?) } else { @@ -509,8 +451,7 @@ impl RegistryClient { }; let req = self - .client - .uncached() + .uncached_client() .head(url.clone()) .header( "accept-encoding", @@ -530,7 +471,7 @@ impl RegistryClient { let read_metadata_range_request = |response: Response| { async { let mut reader = AsyncHttpRangeReader::from_head_response( - self.client.uncached(), + self.uncached_client().client(), response, headers, ) @@ -552,7 +493,7 @@ impl RegistryClient { }; let result = self - .client + .cached_client() .get_serde( req, &cache_entry, @@ -577,8 +518,7 @@ impl RegistryClient { // Create a request to stream the file. let req = self - .client - .uncached() + .uncached_client() .get(url.clone()) .header( // `reqwest` defaults to accepting compressed responses. @@ -603,7 +543,7 @@ impl RegistryClient { .instrument(info_span!("read_metadata_stream", wheel = %filename)) }; - self.client + self.cached_client() .get_serde(req, &cache_entry, cache_control, read_metadata_stream) .await .map_err(crate::Error::from) diff --git a/crates/uv-client/tests/netrc_auth.rs b/crates/uv-client/tests/netrc_auth.rs index d23103e40f0c..4678f176bc38 100644 --- a/crates/uv-client/tests/netrc_auth.rs +++ b/crates/uv-client/tests/netrc_auth.rs @@ -52,8 +52,7 @@ async fn test_client_with_netrc_credentials() -> Result<()> { // Send request to our dummy server let res = client - .cached_client() - .uncached() + .uncached_client() .get(format!("http://{addr}")) .send() .await?; diff --git a/crates/uv-client/tests/user_agent_version.rs b/crates/uv-client/tests/user_agent_version.rs index a366de68d959..68d7eefe6adf 100644 --- a/crates/uv-client/tests/user_agent_version.rs +++ b/crates/uv-client/tests/user_agent_version.rs @@ -44,8 +44,7 @@ async fn test_user_agent_has_version() -> Result<()> { // Send request to our dummy server let res = client - .cached_client() - .uncached() + .uncached_client() .get(format!("http://{addr}")) .send() .await?; diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs index c9882b574c91..16d273280174 100644 --- a/crates/uv-distribution/src/distribution_database.rs +++ b/crates/uv-distribution/src/distribution_database.rs @@ -460,8 +460,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context> let req = self .client - .cached_client() - .uncached() + .uncached_client() .get(url) .header( // `reqwest` defaults to accepting compressed responses. @@ -542,8 +541,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context> let req = self .client - .cached_client() - .uncached() + .uncached_client() .get(url) .header( // `reqwest` defaults to accepting compressed responses. diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs index 95d22457708c..8e9e9d6652a7 100644 --- a/crates/uv-distribution/src/source/mod.rs +++ b/crates/uv-distribution/src/source/mod.rs @@ -304,8 +304,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> { }; let req = self .client - .cached_client() - .uncached() + .uncached_client() .get(url.clone()) .header( // `reqwest` defaults to accepting compressed responses. @@ -414,8 +413,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> { }; let req = self .client - .cached_client() - .uncached() + .uncached_client() .get(url.clone()) .header( // `reqwest` defaults to accepting compressed responses.