diff --git a/olah/configs.py b/olah/configs.py index c979fae..77d2d96 100644 --- a/olah/configs.py +++ b/olah/configs.py @@ -5,7 +5,7 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. -from typing import List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import toml import re import fnmatch @@ -24,14 +24,14 @@ class OlahRule(object): - def __init__(self) -> None: - self.repo = "" - self.type = "*" - self.allow = False - self.use_re = False + def __init__(self, repo: str = "", type: str = "*", allow: bool = False, use_re: bool = False) -> None: + self.repo = repo + self.type = type + self.allow = allow + self.use_re = use_re @staticmethod - def from_dict(data) -> "OlahRule": + def from_dict(data: Dict[str, Any]) -> "OlahRule": out = OlahRule() if "repo" in data: out.repo = data["repo"] @@ -59,7 +59,7 @@ def __init__(self) -> None: self.rules: List[OlahRule] = [] @staticmethod - def from_list(data) -> "OlahRuleList": + def from_list(data: List[Dict[str, Any]]) -> "OlahRuleList": out = OlahRuleList() for item in data: out.rules.append(OlahRule.from_dict(item)) diff --git a/olah/errors.py b/olah/errors.py index 7bb8b24..2d741c3 100644 --- a/olah/errors.py +++ b/olah/errors.py @@ -22,7 +22,7 @@ def error_repo_not_found() -> JSONResponse: def error_page_not_found() -> JSONResponse: return JSONResponse( - content={"error":"Sorry, we can't find the page you are looking for."}, + content={"error": "Sorry, we can't find the page you are looking for."}, headers={ "x-error-code": "RepoNotFound", "x-error-message": "Sorry, we can't find the page you are looking for.", @@ -30,15 +30,17 @@ def error_page_not_found() -> JSONResponse: status_code=404, ) + def error_entry_not_found_branch(branch: str, path: str) -> Response: return Response( headers={ "x-error-code": "EntryNotFound", - "x-error-message": f"{path} does not exist on \"{branch}\"", + "x-error-message": f'{path} does not exist on "{branch}"', }, status_code=404, ) + def error_entry_not_found() -> Response: return Response( headers={ @@ -48,6 +50,7 @@ def error_entry_not_found() -> Response: status_code=404, ) + def error_revision_not_found(revision: str) -> Response: return JSONResponse( content={"error": f"Invalid rev id: {revision}"}, @@ -58,6 +61,7 @@ def error_revision_not_found(revision: str) -> Response: status_code=404, ) + # Olah Custom Messages def error_proxy_timeout() -> Response: return Response( @@ -68,6 +72,7 @@ def error_proxy_timeout() -> Response: status_code=504, ) + def error_proxy_invalid_data() -> Response: return Response( headers={ @@ -75,4 +80,4 @@ def error_proxy_invalid_data() -> Response: "x-error-message": "Proxy Invalid Data", }, status_code=504, - ) \ No newline at end of file + ) diff --git a/olah/mirror/meta.py b/olah/mirror/meta.py index 5c03ba8..c009222 100644 --- a/olah/mirror/meta.py +++ b/olah/mirror/meta.py @@ -6,6 +6,9 @@ # https://opensource.org/licenses/MIT. +from typing import Any, Dict + + class RepoMeta(object): def __init__(self) -> None: self._id = None @@ -25,7 +28,7 @@ def __init__(self) -> None: self.siblings = None self.createdAt = None - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "_id": self._id, "id": self.id, diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py index 5ada704..0122dfd 100644 --- a/olah/mirror/repos.py +++ b/olah/mirror/repos.py @@ -66,7 +66,7 @@ def _get_description(self, commit: Commit) -> str: readme = self._get_readme(commit) return self._remove_card(readme) - def _get_tree_filepaths_recursive(self, tree, include_dir=False) -> List[str]: + def _get_tree_filepaths_recursive(self, tree: Tree, include_dir: bool = False) -> List[str]: out_paths = [] for entry in tree: if entry.type == "tree": @@ -80,7 +80,7 @@ def _get_tree_filepaths_recursive(self, tree, include_dir=False) -> List[str]: def _get_commit_filepaths_recursive(self, commit: Commit) -> List[str]: return self._get_tree_filepaths_recursive(commit.tree) - def _get_path_info(self, entry: IndexObjUnion, expand: bool=False) -> Dict[str, Union[int, str]]: + def _get_path_info(self, entry: IndexObjUnion, expand: bool = False) -> Dict[str, Union[int, str]]: lfs = False if entry.type != "tree": t = "file" diff --git a/olah/proxy/files.py b/olah/proxy/files.py index 9da6d91..21caab2 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -10,11 +10,7 @@ import os from typing import Dict, List, Literal, Optional, Tuple from fastapi import Request - -from requests.structures import CaseInsensitiveDict import httpx -import zlib -from starlette.datastructures import URL from urllib.parse import urlparse, urljoin from olah.constants import ( @@ -43,6 +39,7 @@ from olah.utils.rule_utils import check_cache_rules_hf from olah.utils.file_utils import make_dirs from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT +from olah.utils.zip_utils import decompress_data def get_block_info(pos: int, block_size: int, file_size: int) -> Tuple[int, int, int]: @@ -141,40 +138,10 @@ async def _get_file_range_from_remote( yield raw_chunk chunk_bytes += len(raw_chunk) - # If result is compressed - if "content-encoding" in response.headers: - final_data = raw_data - algorithms = response.headers["content-encoding"].split(',') - for algo in algorithms: - algo = algo.strip().lower() - if algo == "gzip": - try: - final_data = zlib.decompress(raw_data, zlib.MAX_WBITS | 16) # 解压缩 - except Exception as e: - print(f"Error decompressing gzip data: {e}") - elif algo == "compress": - print(f"Unsupported decompression algorithm: {algo}") - elif algo == "deflate": - try: - final_data = zlib.decompress(raw_data) - except Exception as e: - print(f"Error decompressing deflate data: {e}") - elif algo == "br": - try: - import brotli - final_data = brotli.decompress(raw_data) - except Exception as e: - print(f"Error decompressing Brotli data: {e}") - elif algo == "zstd": - try: - import zstandard - final_data = zstandard.ZstdDecompressor().decompress(raw_data) - except Exception as e: - print(f"Error decompressing Zstandard data: {e}") - else: - print(f"Unsupported compression algorithm: {algo}") - chunk_bytes = len(final_data) - yield final_data + if "content-encoding" in response.headers: + final_data = decompress_data(raw_data, response.headers.get("content-encoding", None)) + chunk_bytes = len(final_data) + yield final_data if "content-length" in response.headers: if "content-encoding" in response.headers: response_content_length = len(final_data) diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index 6f0f904..bcb6a7d 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -8,7 +8,7 @@ import os import shutil import tempfile -from typing import Dict, Literal, Optional +from typing import Dict, Literal, Optional, AsyncGenerator, Union from urllib.parse import urljoin from fastapi import FastAPI, Request @@ -20,7 +20,7 @@ from olah.utils.repo_utils import get_org_repo from olah.utils.file_utils import make_dirs -async def _meta_cache_generator(save_path: str): +async def _meta_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: cache_rq = await read_cache_request(save_path) yield cache_rq["headers"] yield cache_rq["content"] @@ -33,7 +33,7 @@ async def _meta_proxy_generator( method: str, allow_cache: bool, save_path: str, -): +) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: async with httpx.AsyncClient(follow_redirects=True) as client: content_chunks = [] async with client.stream( @@ -61,7 +61,7 @@ async def _meta_proxy_generator( save_path, response_status_code, response_headers, bytes(content) ) -# TODO: remove param `request` + async def meta_generator( app: FastAPI, repo_type: Literal["models", "datasets", "spaces"], @@ -71,7 +71,7 @@ async def meta_generator( override_cache: bool, method: str, authorization: Optional[str], -): +) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: headers = {} if authorization is not None: headers["authorization"] = authorization diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py index 1e9115e..c0c7e23 100644 --- a/olah/proxy/pathsinfo.py +++ b/olah/proxy/pathsinfo.py @@ -7,7 +7,7 @@ import json import os -from typing import Dict, List, Literal, Optional +from typing import AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union from urllib.parse import quote, urljoin from fastapi import FastAPI, Request @@ -20,7 +20,7 @@ from olah.utils.file_utils import make_dirs -async def _pathsinfo_cache(save_path: str): +async def _pathsinfo_cache(save_path: str) -> Tuple[int, Dict[str, str], bytes]: cache_rq = await read_cache_request(save_path) return cache_rq["status_code"], cache_rq["headers"], cache_rq["content"] @@ -33,7 +33,7 @@ async def _pathsinfo_proxy( path: str, allow_cache: bool, save_path: str, -): +) -> Tuple[int, Dict[str, str], bytes]: headers = {k: v for k, v in headers.items()} if "content-length" in headers: headers.pop("content-length") @@ -67,7 +67,7 @@ async def pathsinfo_generator( override_cache: bool, method: str, authorization: Optional[str], -): +) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: headers = {} if authorization is not None: headers["authorization"] = authorization diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 22e4467..92eb338 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -6,7 +6,7 @@ # https://opensource.org/licenses/MIT. import os -from typing import Dict, Literal, Mapping, Optional +from typing import Dict, Literal, Mapping, Optional, AsyncGenerator, Union from urllib.parse import urljoin from fastapi import FastAPI, Request @@ -19,7 +19,7 @@ from olah.utils.file_utils import make_dirs -async def _tree_cache_generator(save_path: str): +async def _tree_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: cache_rq = await read_cache_request(save_path) yield cache_rq["status_code"] yield cache_rq["headers"] @@ -33,7 +33,7 @@ async def _tree_proxy_generator( params: Mapping[str, str], allow_cache: bool, save_path: str, -): +) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: async with httpx.AsyncClient(follow_redirects=True) as client: content_chunks = [] async with client.stream( @@ -77,7 +77,7 @@ async def tree_generator( override_cache: bool, method: str, authorization: Optional[str], -): +) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: headers = {} if authorization is not None: headers["authorization"] = authorization diff --git a/olah/server.py b/olah/server.py index 1a4f80a..70f9c23 100644 --- a/olah/server.py +++ b/olah/server.py @@ -32,6 +32,7 @@ from olah.proxy.tree import tree_generator from olah.utils.disk_utils import convert_bytes_to_human_readable, convert_to_bytes, get_folder_size, sort_files_by_access_time, sort_files_by_modify_time, sort_files_by_size from olah.utils.url_utils import clean_path +from olah.utils.zip_utils import decompress_data BASE_SETTINGS = False if not BASE_SETTINGS: @@ -709,10 +710,16 @@ async def whoami_v2(request: Request): headers=new_headers, timeout=10, ) + # final_content = decompress_data(response.headers.get("content-encoding", None)) + response_headers = {k.lower(): v for k, v in response.headers.items()} + if "content-encoding" in response_headers: + response_headers.pop("content-encoding") + if "content-length" in response_headers: + response_headers.pop("content-length") return Response( content=response.content, status_code=response.status_code, - headers=response.headers, + headers=response_headers, ) diff --git a/olah/utils/repo_utils.py b/olah/utils/repo_utils.py index 0f7f784..d055f15 100644 --- a/olah/utils/repo_utils.py +++ b/olah/utils/repo_utils.py @@ -189,12 +189,15 @@ async def get_newest_commit_hf( return await get_newest_commit_hf_offline(app, repo_type, org, repo) try: async with httpx.AsyncClient() as client: - response = await client.get(url, headers={"authorization": authorization}, timeout=WORKER_API_TIMEOUT) + headers = {} + if authorization is not None: + headers["authorization"] = authorization + response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT) if response.status_code != 200: return await get_newest_commit_hf_offline(app, repo_type, org, repo) obj = json.loads(response.text) return obj.get("sha", None) - except: + except httpx.TimeoutException as e: return await get_newest_commit_hf_offline(app, repo_type, org, repo) diff --git a/olah/utils/zip_utils.py b/olah/utils/zip_utils.py new file mode 100644 index 0000000..3476149 --- /dev/null +++ b/olah/utils/zip_utils.py @@ -0,0 +1,42 @@ + + +from typing import Optional +import zlib + + +def decompress_data(raw_data: bytes, content_encoding: Optional[str]): + # If result is compressed + if content_encoding is not None: + final_data = raw_data + algorithms = content_encoding.split(',') + for algo in algorithms: + algo = algo.strip().lower() + if algo == "gzip": + try: + final_data = zlib.decompress(raw_data, zlib.MAX_WBITS | 16) # 解压缩 + except Exception as e: + print(f"Error decompressing gzip data: {e}") + elif algo == "compress": + print(f"Unsupported decompression algorithm: {algo}") + elif algo == "deflate": + try: + final_data = zlib.decompress(raw_data) + except Exception as e: + print(f"Error decompressing deflate data: {e}") + elif algo == "br": + try: + import brotli + final_data = brotli.decompress(raw_data) + except Exception as e: + print(f"Error decompressing Brotli data: {e}") + elif algo == "zstd": + try: + import zstandard + final_data = zstandard.ZstdDecompressor().decompress(raw_data) + except Exception as e: + print(f"Error decompressing Zstandard data: {e}") + else: + print(f"Unsupported compression algorithm: {algo}") + return final_data + else: + return raw_data \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6ce8df3..f86d33a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "olah" -version = "0.3.1" +version = "0.3.2" description = "Self-hosted lightweight huggingface mirror." readme = "README.md" requires-python = ">=3.8" @@ -19,7 +19,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black==24.4.2", "pylint==3.2.5", "pytest==8.2.2"] +dev = ["black==24.10.0", "pylint==3.3.1", "pytest==8.3.3"] [project.urls] "Homepage" = "https://github.com/vtuber-plan/olah" diff --git a/requirements.txt b/requirements.txt index 52f3e72..a83e52e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -fastapi==0.111.0 +fastapi==0.115.2 fastapi-utils==0.7.0 GitPython==3.1.43 httpx==0.27.0 pydantic==2.8.2 pydantic-settings==2.4.0 toml==0.10.2 -huggingface_hub==0.23.4 -pytest==8.2.2 +huggingface_hub==0.26.0 +pytest==8.3.3 cachetools==5.4.0 PyYAML==6.0.1 tenacity==8.5.0