
Commit

rm comments
Mon-ius committed May 5, 2024
1 parent 6bc627f commit 3010536
Showing 2 changed files with 6 additions and 211 deletions.
162 changes: 4 additions & 158 deletions hfd/src/api/tokio.rs
@@ -17,59 +17,44 @@ use thiserror::Error;
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tokio::sync::{AcquireError, Semaphore, TryAcquireError};

/// Current version (used in user-agent)
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Current name (used in user-agent)

const NAME: &str = env!("CARGO_PKG_NAME");

#[derive(Debug, Error)]
/// All errors the API can throw

pub enum ApiError {
/// Api expects certain headers to be present in the results to derive some information
#[error("Header {0} is missing")]
MissingHeader(HeaderName),

/// The header exists, but the value does not conform to what the Api expects.
#[error("Header {0} is invalid")]
InvalidHeader(HeaderName),

/// The value cannot be used as a header during request header construction
#[error("Invalid header value {0}")]
InvalidHeaderValue(#[from] InvalidHeaderValue),

/// The header value is not valid utf-8
#[error("header value is not a string")]
ToStr(#[from] ToStrError),

/// Error in the request
#[error("request error: {0}")]
RequestError(#[from] ReqwestError),

/// Error parsing some range value
#[error("Cannot parse int")]
ParseIntError(#[from] ParseIntError),

/// I/O Error
#[error("I/O error {0}")]
IoError(#[from] std::io::Error),

/// We tried to download a chunk too many times
#[error("Too many retries: {0}")]
TooManyRetries(Box<ApiError>),

/// Semaphore cannot be acquired
#[error("Try acquire: {0}")]
TryAcquireError(#[from] TryAcquireError),

/// Semaphore cannot be acquired
#[error("Acquire: {0}")]
AcquireError(#[from] AcquireError),
// /// Semaphore cannot be acquired
// #[error("Invalid Response: {0:?}")]
// InvalidResponse(Response),
}

/// Helper to create [`Api`] with all the options.
#[derive(Debug)]
pub struct ApiBuilder {
endpoint: String,
@@ -90,23 +75,11 @@ impl Default for ApiBuilder {
}

impl ApiBuilder {
/// Default api builder
/// ```
/// use hf_hub::api::tokio::ApiBuilder;
/// let api = ApiBuilder::new().build().unwrap();
/// ```
pub fn new() -> Self {
let cache = Cache::default();
Self::from_cache(cache)
}

/// From a given cache
/// ```
/// use hf_hub::{api::tokio::ApiBuilder, Cache};
/// let path = std::path::PathBuf::from("/tmp");
/// let cache = Cache::new(path);
/// let api = ApiBuilder::from_cache(cache).build().unwrap();
/// ```
pub fn from_cache(cache: Cache) -> Self {
let token = cache.token();

@@ -125,19 +98,16 @@ impl ApiBuilder {
}
}

/// Whether to show a progress bar
pub fn with_progress(mut self, progress: bool) -> Self {
self.progress = progress;
self
}

/// Changes the location of the cache directory. Default is `~/.cache/huggingface/`.
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
self.cache = Cache::new(cache_dir);
self
}

/// Sets the token to be used in the API
pub fn with_token(mut self, token: Option<String>) -> Self {
self.token = token;
self
@@ -156,27 +126,21 @@ impl ApiBuilder {
Ok(headers)
}

/// Consumes the builder and builds the final [`Api`]
pub fn build(self) -> Result<Api, ApiError> {
let headers = self.build_headers()?;
let client = Client::builder().default_headers(headers.clone()).build()?;

// Policy: only follow relative redirects
// See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
let relative_redirect_policy = Policy::custom(|attempt| {
// Follow redirects up to a maximum of 10.
if attempt.previous().len() > 10 {
return attempt.error("too many redirects");
}

if let Some(last) = attempt.previous().last() {
// If the url is not relative
if last.make_relative(attempt.url()).is_none() {
return attempt.stop();
}
}

// Follow redirect
attempt.follow()
});

@@ -206,9 +170,6 @@ struct Metadata {
size: usize,
}

/// The actual Api used to interact with the hub.
/// You can inspect repos with [`Api::info`]
/// or download files with [`Api::download`]
#[derive(Clone, Debug)]
pub struct Api {
endpoint: String,
@@ -239,9 +200,6 @@ fn make_relative(src: &Path, dst: &Path) -> PathBuf {
match (ita.next(), itb.next()) {
(Some(a), Some(b)) if a == b => (),
(some_a, _) => {
// Ignoring b, because 1 component is the filename
// for which we don't need to go back up for relative
// filename to work.
let mut new_path = PathBuf::new();
for _ in itb {
new_path.push(Component::ParentDir);
@@ -286,13 +244,10 @@ fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
}

impl Api {
/// Creates a default Api, for Api options See [`ApiBuilder`]
pub fn new() -> Result<Self, ApiError> {
ApiBuilder::new().build()
}

/// Get the underlying api client
/// Allows for lower level access
pub fn client(&self) -> &Client {
&self.client
}
@@ -316,16 +271,14 @@ impl Api {
.get(&header_etag)
.ok_or(ApiError::MissingHeader(header_etag))?,
};
// Cleaning extra quotes

let etag = etag.to_str()?.to_string().replace('"', "");
let commit_hash = headers
.get(&header_commit)
.ok_or(ApiError::MissingHeader(header_commit))?
.to_str()?
.to_string();

// The response was most likely redirected to S3, which will
// know the size of the file
let response = if response.status().is_redirection() {
self.client
.get(headers.get(LOCATION).unwrap().to_str()?.to_string())
@@ -353,47 +306,23 @@ impl Api {
})
}

/// Creates a new handle [`ApiRepo`] which contains operations
/// on a particular [`Repo`]
pub fn repo(&self, repo: Repo) -> ApiRepo {
ApiRepo::new(self.clone(), repo)
}

/// Simple wrapper over
/// ```
/// # use hf_hub::{api::tokio::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Model));
/// ```
pub fn model(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Model))
}

/// Simple wrapper over
/// ```
/// # use hf_hub::{api::tokio::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Dataset));
/// ```
pub fn dataset(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Dataset))
}

/// Simple wrapper over
/// ```
/// # use hf_hub::{api::tokio::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Space));
/// ```
pub fn space(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Space))
}
}

/// Shorthand for accessing things within a particular repo
#[derive(Debug)]
pub struct ApiRepo {
api: Api,
@@ -407,13 +336,6 @@ impl ApiRepo {
}

impl ApiRepo {
/// Get the fully qualified URL of the remote filename
/// ```
/// # use hf_hub::api::tokio::Api;
/// let api = Api::new().unwrap();
/// let url = api.model("gpt2".to_string()).url("model.safetensors");
/// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
/// ```
pub fn url(&self, filename: &str) -> String {
let endpoint = &self.api.endpoint;
let revision = &self.repo.url_revision();
@@ -436,7 +358,6 @@ impl ApiRepo {
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
let filename = self.api.cache.temp_path();

// Create the file and set everything properly
tokio::fs::File::create(&filename)
.await?
.set_len(length as u64)
@@ -482,7 +403,6 @@ impl ApiRepo {
}));
}

// Output the chained result
let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
futures::future::join_all(handles).await;
let results: Result<(), ApiError> = results.into_iter().flatten().collect();
@@ -500,7 +420,6 @@ impl ApiRepo {
start: usize,
stop: usize,
) -> Result<(), ApiError> {
// Process each socket concurrently.
let range = format!("bytes={start}-{stop}");
let mut file = tokio::fs::OpenOptions::new()
.write(true)
@@ -518,14 +437,6 @@ impl ApiRepo {
Ok(())
}

/// This will attempt to fetch the file locally first, then [`Api.download`]
/// if the file is not present.
/// ```no_run
/// # use hf_hub::api::tokio::Api;
/// # tokio_test::block_on(async {
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").await.unwrap();
/// # })
/// ```
pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
Ok(path)
@@ -534,17 +445,6 @@ impl ApiRepo {
}
}

/// Downloads a remote file (if not already present) into the cache directory
/// to be used locally.
/// This function requires internet access to verify if new versions of the file
/// exist, even if a file is already on disk at that location.
/// ```no_run
/// # use hf_hub::api::tokio::Api;
/// # tokio_test::block_on(async {
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").await.unwrap();
/// # })
/// ```
pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
let url = self.url(filename);
let metadata = self.api.metadata(&url).await?;
Expand All @@ -559,7 +459,7 @@ impl ApiRepo {
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
.unwrap(),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
@@ -589,67 +489,13 @@ impl ApiRepo {
Ok(pointer_path)
}

/// Get information about the Repo
/// ```
/// # use hf_hub::api::tokio::Api;
/// # tokio_test::block_on(async {
/// let api = Api::new().unwrap();
/// api.model("gpt2".to_string()).info();
/// # })
/// ```
pub async fn info(&self) -> Result<RepoInfo, ApiError> {
Ok(self.info_request().send().await?.json().await?)
}

/// Get the raw [`reqwest::RequestBuilder`] with the url and method already set
/// ```
/// # use hf_hub::api::tokio::Api;
/// # tokio_test::block_on(async {
/// let api = Api::new().unwrap();
/// api.model("gpt2".to_owned())
/// .info_request()
/// .query(&[("blobs", "true")])
/// .send()
/// .await;
/// # })
/// ```
pub fn info_request(&self) -> RequestBuilder {
let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
self.api.client.get(url)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::api::Siblings;
use hex_literal::hex;
use rand::distributions::Alphanumeric;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};

struct TempDir {
path: PathBuf,
}

impl TempDir {
pub fn new() -> Self {
let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect();
let mut path = std::env::temp_dir();
path.push(s);
std::fs::create_dir(&path).unwrap();
Self { path }
}
}

impl Drop for TempDir {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}

}
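
Note: along with the inline comments, this commit also drops the crate's runnable doc examples. For reference, here is a minimal sketch that strings together the examples deleted above (builder construction, `ApiRepo::url`, and `ApiRepo::get`). It assumes the upstream `hf_hub` crate paths used in those removed doc comments, which may not match this fork's crate name, plus a tokio runtime with the `macros` feature enabled.

// Minimal usage sketch reassembled from the doc examples removed in this commit.
use hf_hub::api::tokio::ApiBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Default builder, as in the removed `ApiBuilder::new` doc example.
    let api = ApiBuilder::new().build()?;

    // Fully qualified URL of a remote file, as in the removed `ApiRepo::url` example.
    let url = api.model("gpt2".to_string()).url("model.safetensors");
    assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");

    // Fetch from the local cache first, downloading only if the file is missing,
    // as in the removed `ApiRepo::get` example.
    let local_filename = api.model("gpt2".to_string()).get("model.safetensors").await?;
    println!("cached at {}", local_filename.display());
    Ok(())
}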
