Skip to content

Commit

Permalink
Make Client optional in requirements-txt (astral-sh#2229)
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh authored Mar 6, 2024
1 parent 2b2de0d commit 3ca7776
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 109 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions crates/requirements-txt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ workspace = true
[dependencies]
pep440_rs = { path = "../pep440-rs", features = ["rkyv", "serde"] }
pep508_rs = { path = "../pep508-rs", features = ["rkyv", "serde", "non-pep508-extensions"] }
uv-cache = { path = "../uv-cache" }
uv-client = { path = "../uv-client" }
uv-fs = { path = "../uv-fs" }
uv-normalize = { path = "../uv-normalize" }
Expand All @@ -25,11 +24,10 @@ async-recursion = { workspace = true }
fs-err = { workspace = true }
once_cell = { workspace = true }
regex = { workspace = true }
reqwest = { workspace = true }
reqwest-middleware = { workspace = true }
reqwest = { workspace = true, optional = true }
reqwest-middleware = { workspace = true, optional = true }
serde = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
unscanny = { workspace = true }
url = { workspace = true }
Expand All @@ -43,3 +41,8 @@ itertools = { version = "0.12.1" }
serde_json = { version = "1.0.114" }
tempfile = { version = "3.9.0" }
test-case = { version = "3.3.1" }
tokio = { version = "1.35.1", features = ["macros"] }

[features]
default = []
reqwest = ["dep:reqwest", "dep:reqwest-middleware"]
150 changes: 82 additions & 68 deletions crates/requirements-txt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ use std::fmt::{Display, Formatter};
use std::io;
use std::path::{Path, PathBuf};

use async_recursion::async_recursion;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use unscanny::{Pattern, Scanner};
use url::Url;
use uv_client::RegistryClient;
use uv_warnings::warn_user;

use async_recursion::async_recursion;
use pep508_rs::{
expand_path_vars, split_scheme, Extras, Pep508Error, Pep508ErrorSource, Requirement, Scheme,
VerbatimUrl,
};
use uv_client::RegistryClient;
use uv_fs::{normalize_url_path, Simplified};
use uv_normalize::ExtraName;
use uv_warnings::warn_user;

/// We emit one of those for each requirements.txt entry
enum RequirementsTxtStatement {
Expand Down Expand Up @@ -326,14 +326,38 @@ impl RequirementsTxt {
pub async fn parse(
requirements_txt: impl AsRef<Path>,
working_dir: impl AsRef<Path>,
client: &RegistryClient,
client: Option<&RegistryClient>,
) -> Result<Self, RequirementsTxtFileError> {
let requirements_txt = requirements_txt.as_ref();
let working_dir = working_dir.as_ref();

let content =
if requirements_txt.starts_with("http://") | requirements_txt.starts_with("https://") {
read_url_to_string(&requirements_txt, client).await
#[cfg(not(feature = "reqwest"))]
{
return Err(RequirementsTxtFileError {
file: requirements_txt.to_path_buf(),
error: RequirementsTxtParserError::IO(io::Error::new(
io::ErrorKind::InvalidInput,
"Remote file not supported without `reqwest` feature",
)),
});
}

#[cfg(feature = "reqwest")]
{
let Some(client) = client else {
return Err(RequirementsTxtFileError {
file: requirements_txt.to_path_buf(),
error: RequirementsTxtParserError::IO(io::Error::new(
io::ErrorKind::InvalidInput,
"No client provided for remote file",
)),
});
};

read_url_to_string(&requirements_txt, client).await
}
} else {
uv_fs::read_to_string(&requirements_txt)
.await
Expand Down Expand Up @@ -372,7 +396,7 @@ impl RequirementsTxt {
content: &str,
working_dir: &Path,
requirements_dir: &Path,
client: &RegistryClient,
client: Option<&'async_recursion RegistryClient>,
) -> Result<Self, RequirementsTxtParserError> {
let mut s = Scanner::new(content);

Expand Down Expand Up @@ -794,6 +818,7 @@ fn parse_value<'a, T>(
}

/// Fetch the contents of a URL and return them as a string.
#[cfg(feature = "reqwest")]
async fn read_url_to_string(
path: impl AsRef<Path>,
client: &RegistryClient,
Expand Down Expand Up @@ -859,10 +884,11 @@ pub enum RequirementsTxtParserError {
start: usize,
end: usize,
},
Reqwest(reqwest_middleware::Error),
NonUnicodeUrl {
url: PathBuf,
},
#[cfg(feature = "reqwest")]
Reqwest(reqwest_middleware::Error),
}

impl RequirementsTxtParserError {
Expand Down Expand Up @@ -910,8 +936,9 @@ impl RequirementsTxtParserError {
start: start + offset,
end: end + offset,
},
Self::Reqwest(err) => Self::Reqwest(err),
Self::NonUnicodeUrl { url } => Self::NonUnicodeUrl { url },
#[cfg(feature = "reqwest")]
Self::Reqwest(err) => Self::Reqwest(err),
}
}
}
Expand Down Expand Up @@ -954,16 +981,17 @@ impl Display for RequirementsTxtParserError {
Self::Subfile { start, .. } => {
write!(f, "Error parsing included file at position {start}")
}
Self::Reqwest(err) => {
write!(f, "Error while accessing remote requirements file {err}")
}
Self::NonUnicodeUrl { url } => {
write!(
f,
"Remote requirements URL contains non-unicode characters: {}",
url.display(),
)
}
#[cfg(feature = "reqwest")]
Self::Reqwest(err) => {
write!(f, "Error while accessing remote requirements file {err}")
}
}
}
}
Expand All @@ -981,8 +1009,9 @@ impl std::error::Error for RequirementsTxtParserError {
Self::Pep508 { source, .. } => Some(source),
Self::Subfile { source, .. } => Some(source.as_ref()),
Self::Parser { .. } => None,
Self::Reqwest(err) => err.source(),
Self::NonUnicodeUrl { .. } => None,
#[cfg(feature = "reqwest")]
Self::Reqwest(err) => err.source(),
}
}
}
Expand Down Expand Up @@ -1058,19 +1087,19 @@ impl Display for RequirementsTxtFileError {
self.file.simplified_display(),
)
}
RequirementsTxtParserError::Reqwest(err) => {
RequirementsTxtParserError::NonUnicodeUrl { url } => {
write!(
f,
"Error while accessing remote requirements file {}: {err}",
self.file.simplified_display(),
"Remote requirements URL contains non-unicode characters: {}",
url.display(),
)
}

RequirementsTxtParserError::NonUnicodeUrl { url } => {
#[cfg(feature = "reqwest")]
RequirementsTxtParserError::Reqwest(err) => {
write!(
f,
"Remote requirements URL contains non-unicode characters: {}",
url.display(),
"Error while accessing remote requirements file {}: {err}",
self.file.simplified_display(),
)
}
}
Expand All @@ -1089,6 +1118,7 @@ impl From<io::Error> for RequirementsTxtParserError {
}
}

#[cfg(feature = "reqwest")]
impl From<reqwest_middleware::Error> for RequirementsTxtParserError {
fn from(err: reqwest_middleware::Error) -> Self {
Self::Reqwest(err)
Expand Down Expand Up @@ -1147,7 +1177,7 @@ mod test {
use tempfile::tempdir;
use test_case::test_case;
use unscanny::Scanner;
use uv_client::{RegistryClient, RegistryClientBuilder};

use uv_fs::Simplified;

use crate::{calculate_row_column, EditableRequirement, RequirementsTxt};
Expand All @@ -1156,12 +1186,6 @@ mod test {
PathBuf::from("./test-data")
}

fn registry_client() -> RegistryClient {
RegistryClientBuilder::new(uv_cache::Cache::temp().unwrap())
.connectivity(uv_client::Connectivity::Online)
.build()
}

#[test_case(Path::new("basic.txt"))]
#[test_case(Path::new("constraints-a.txt"))]
#[test_case(Path::new("constraints-b.txt"))]
Expand All @@ -1177,7 +1201,7 @@ mod test {
let working_dir = workspace_test_data_dir().join("requirements-txt");
let requirements_txt = working_dir.join(path);

let actual = RequirementsTxt::parse(requirements_txt, &working_dir, &registry_client())
let actual = RequirementsTxt::parse(requirements_txt, &working_dir, None)
.await
.unwrap();

Expand Down Expand Up @@ -1221,7 +1245,7 @@ mod test {
let requirements_txt = temp_dir.path().join(path);
fs::write(&requirements_txt, contents).unwrap();

let actual = RequirementsTxt::parse(&requirements_txt, &working_dir, &registry_client())
let actual = RequirementsTxt::parse(&requirements_txt, &working_dir, None)
.await
.unwrap();

Expand All @@ -1238,10 +1262,9 @@ mod test {
-r missing.txt
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error)
.chain()
// The last error is operating-system specific.
Expand Down Expand Up @@ -1276,10 +1299,9 @@ mod test {
numpy[ö]==1.29
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down Expand Up @@ -1310,10 +1332,9 @@ mod test {
-e http://localhost:8080/
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand All @@ -1339,10 +1360,9 @@ mod test {
-e black[,abcdef]
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down Expand Up @@ -1370,10 +1390,9 @@ mod test {
--index-url 123
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down Expand Up @@ -1407,10 +1426,9 @@ mod test {
file.txt
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down Expand Up @@ -1451,10 +1469,9 @@ mod test {
-r subdir/child.txt
"})?;

let requirements =
RequirementsTxt::parse(parent_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap();
let requirements = RequirementsTxt::parse(parent_txt.path(), temp_dir.path(), None)
.await
.unwrap();
insta::assert_debug_snapshot!(requirements, @r###"
RequirementsTxt {
requirements: [
Expand Down Expand Up @@ -1504,10 +1521,9 @@ mod test {
--no-index
"})?;

let requirements =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap();
let requirements = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap();

insta::assert_debug_snapshot!(requirements, @r###"
RequirementsTxt {
Expand Down Expand Up @@ -1565,10 +1581,9 @@ mod test {
--index-url https://fake.pypi.org/simple
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down Expand Up @@ -1616,10 +1631,9 @@ mod test {
tqdm
"})?;

let error =
RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), &registry_client())
.await
.unwrap_err();
let error = RequirementsTxt::parse(requirements_txt.path(), temp_dir.path(), None)
.await
.unwrap_err();
let errors = anyhow::Error::new(error).chain().join("\n");

let requirement_txt =
Expand Down
Loading

0 comments on commit 3ca7776

Please sign in to comment.