Skip to content

Commit

Permalink
Merge pull request #31 from vtuber-plan/dev
Browse files Browse the repository at this point in the history
V0.3.2
  • Loading branch information
jstzwj authored Oct 19, 2024
2 parents 72ea99c + c780ff1 commit 4be5ed1
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 73 deletions.
16 changes: 8 additions & 8 deletions olah/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from typing import List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
import toml
import re
import fnmatch
Expand All @@ -24,14 +24,14 @@


class OlahRule(object):
def __init__(self) -> None:
self.repo = ""
self.type = "*"
self.allow = False
self.use_re = False
def __init__(self, repo: str = "", type: str = "*", allow: bool = False, use_re: bool = False) -> None:
self.repo = repo
self.type = type
self.allow = allow
self.use_re = use_re

@staticmethod
def from_dict(data) -> "OlahRule":
def from_dict(data: Dict[str, Any]) -> "OlahRule":
out = OlahRule()
if "repo" in data:
out.repo = data["repo"]
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self) -> None:
self.rules: List[OlahRule] = []

@staticmethod
def from_list(data) -> "OlahRuleList":
def from_list(data: List[Dict[str, Any]]) -> "OlahRuleList":
out = OlahRuleList()
for item in data:
out.rules.append(OlahRule.from_dict(item))
Expand Down
11 changes: 8 additions & 3 deletions olah/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,25 @@ def error_repo_not_found() -> JSONResponse:

def error_page_not_found() -> JSONResponse:
    """Return a 404 JSON response for a page that does not exist.

    The diff scrape left both the old and new ``content=`` lines in place,
    which is a duplicate-keyword syntax error; only the post-merge line is kept.

    NOTE(review): x-error-code is "RepoNotFound" even though this is a
    page-level 404 — presumably intentional for HF client compatibility; confirm.
    """
    return JSONResponse(
        content={"error": "Sorry, we can't find the page you are looking for."},
        headers={
            "x-error-code": "RepoNotFound",
            "x-error-message": "Sorry, we can't find the page you are looking for.",
        },
        status_code=404,
    )


def error_entry_not_found_branch(branch: str, path: str) -> Response:
    """Return a 404 response indicating *path* does not exist on *branch*.

    Error details travel in x-error-* headers with an empty body, matching the
    convention used by the other error helpers in this module. The diff scrape
    kept both old and new "x-error-message" lines; only the post-merge
    single-quoted f-string variant is kept.

    Args:
        branch: Revision/branch name the lookup was made against.
        path: Repository-relative path that was not found.
    """
    return Response(
        headers={
            "x-error-code": "EntryNotFound",
            "x-error-message": f'{path} does not exist on "{branch}"',
        },
        status_code=404,
    )


def error_entry_not_found() -> Response:
return Response(
headers={
Expand All @@ -48,6 +50,7 @@ def error_entry_not_found() -> Response:
status_code=404,
)


def error_revision_not_found(revision: str) -> Response:
return JSONResponse(
content={"error": f"Invalid rev id: {revision}"},
Expand All @@ -58,6 +61,7 @@ def error_revision_not_found(revision: str) -> Response:
status_code=404,
)


# Olah Custom Messages
def error_proxy_timeout() -> Response:
return Response(
Expand All @@ -68,11 +72,12 @@ def error_proxy_timeout() -> Response:
status_code=504,
)


def error_proxy_invalid_data() -> Response:
    """Return an error response signaling the upstream returned invalid data.

    The diff scrape duplicated the closing parenthesis (old line without
    trailing newline + new line), a syntax error; a single close is kept.

    NOTE(review): status 504 is "Gateway Timeout" — 502 "Bad Gateway" may fit
    "invalid data" better, but clients may depend on 504; confirm before changing.
    """
    return Response(
        headers={
            "x-error-code": "ProxyInvalidData",
            "x-error-message": "Proxy Invalid Data",
        },
        status_code=504,
    )
5 changes: 4 additions & 1 deletion olah/mirror/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# https://opensource.org/licenses/MIT.


from typing import Any, Dict


class RepoMeta(object):
def __init__(self) -> None:
self._id = None
Expand All @@ -25,7 +28,7 @@ def __init__(self) -> None:
self.siblings = None
self.createdAt = None

def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"_id": self._id,
"id": self.id,
Expand Down
4 changes: 2 additions & 2 deletions olah/mirror/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _get_description(self, commit: Commit) -> str:
readme = self._get_readme(commit)
return self._remove_card(readme)

def _get_tree_filepaths_recursive(self, tree, include_dir=False) -> List[str]:
def _get_tree_filepaths_recursive(self, tree: Tree, include_dir: bool = False) -> List[str]:
out_paths = []
for entry in tree:
if entry.type == "tree":
Expand All @@ -80,7 +80,7 @@ def _get_tree_filepaths_recursive(self, tree, include_dir=False) -> List[str]:
def _get_commit_filepaths_recursive(self, commit: Commit) -> List[str]:
    # Convenience wrapper: list every file path reachable from the commit's
    # root tree (directories excluded — the helper's include_dir defaults off).
    return self._get_tree_filepaths_recursive(commit.tree)

def _get_path_info(self, entry: IndexObjUnion, expand: bool=False) -> Dict[str, Union[int, str]]:
def _get_path_info(self, entry: IndexObjUnion, expand: bool = False) -> Dict[str, Union[int, str]]:
lfs = False
if entry.type != "tree":
t = "file"
Expand Down
43 changes: 5 additions & 38 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
import os
from typing import Dict, List, Literal, Optional, Tuple
from fastapi import Request

from requests.structures import CaseInsensitiveDict
import httpx
import zlib
from starlette.datastructures import URL
from urllib.parse import urlparse, urljoin

from olah.constants import (
Expand Down Expand Up @@ -43,6 +39,7 @@
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.file_utils import make_dirs
from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT
from olah.utils.zip_utils import decompress_data


def get_block_info(pos: int, block_size: int, file_size: int) -> Tuple[int, int, int]:
Expand Down Expand Up @@ -141,40 +138,10 @@ async def _get_file_range_from_remote(
yield raw_chunk
chunk_bytes += len(raw_chunk)

# If result is compressed
if "content-encoding" in response.headers:
final_data = raw_data
algorithms = response.headers["content-encoding"].split(',')
for algo in algorithms:
algo = algo.strip().lower()
if algo == "gzip":
try:
final_data = zlib.decompress(raw_data, zlib.MAX_WBITS | 16) # 解压缩
except Exception as e:
print(f"Error decompressing gzip data: {e}")
elif algo == "compress":
print(f"Unsupported decompression algorithm: {algo}")
elif algo == "deflate":
try:
final_data = zlib.decompress(raw_data)
except Exception as e:
print(f"Error decompressing deflate data: {e}")
elif algo == "br":
try:
import brotli
final_data = brotli.decompress(raw_data)
except Exception as e:
print(f"Error decompressing Brotli data: {e}")
elif algo == "zstd":
try:
import zstandard
final_data = zstandard.ZstdDecompressor().decompress(raw_data)
except Exception as e:
print(f"Error decompressing Zstandard data: {e}")
else:
print(f"Unsupported compression algorithm: {algo}")
chunk_bytes = len(final_data)
yield final_data
if "content-encoding" in response.headers:
final_data = decompress_data(raw_data, response.headers.get("content-encoding", None))
chunk_bytes = len(final_data)
yield final_data
if "content-length" in response.headers:
if "content-encoding" in response.headers:
response_content_length = len(final_data)
Expand Down
10 changes: 5 additions & 5 deletions olah/proxy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import shutil
import tempfile
from typing import Dict, Literal, Optional
from typing import Dict, Literal, Optional, AsyncGenerator, Union
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand All @@ -20,7 +20,7 @@
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs

async def _meta_cache_generator(save_path: str):
async def _meta_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
cache_rq = await read_cache_request(save_path)
yield cache_rq["headers"]
yield cache_rq["content"]
Expand All @@ -33,7 +33,7 @@ async def _meta_proxy_generator(
method: str,
allow_cache: bool,
save_path: str,
):
) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
async with httpx.AsyncClient(follow_redirects=True) as client:
content_chunks = []
async with client.stream(
Expand Down Expand Up @@ -61,7 +61,7 @@ async def _meta_proxy_generator(
save_path, response_status_code, response_headers, bytes(content)
)

# TODO: remove param `request`

async def meta_generator(
app: FastAPI,
repo_type: Literal["models", "datasets", "spaces"],
Expand All @@ -71,7 +71,7 @@ async def meta_generator(
override_cache: bool,
method: str,
authorization: Optional[str],
):
) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
headers = {}
if authorization is not None:
headers["authorization"] = authorization
Expand Down
8 changes: 4 additions & 4 deletions olah/proxy/pathsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import json
import os
from typing import Dict, List, Literal, Optional
from typing import AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union
from urllib.parse import quote, urljoin
from fastapi import FastAPI, Request

Expand All @@ -20,7 +20,7 @@
from olah.utils.file_utils import make_dirs


async def _pathsinfo_cache(save_path: str) -> Tuple[int, Dict[str, str], bytes]:
    """Load a cached paths-info HTTP response from disk.

    The diff scrape kept both the old (unannotated) and new ``def`` lines,
    a duplicate definition; only the annotated post-merge signature is kept.

    Args:
        save_path: Path of the cache file previously written by the proxy.

    Returns:
        The cached ``(status_code, headers, content)`` triple.
    """
    cache_rq = await read_cache_request(save_path)
    return cache_rq["status_code"], cache_rq["headers"], cache_rq["content"]

Expand All @@ -33,7 +33,7 @@ async def _pathsinfo_proxy(
path: str,
allow_cache: bool,
save_path: str,
):
) -> Tuple[int, Dict[str, str], bytes]:
headers = {k: v for k, v in headers.items()}
if "content-length" in headers:
headers.pop("content-length")
Expand Down Expand Up @@ -67,7 +67,7 @@ async def pathsinfo_generator(
override_cache: bool,
method: str,
authorization: Optional[str],
):
) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
headers = {}
if authorization is not None:
headers["authorization"] = authorization
Expand Down
8 changes: 4 additions & 4 deletions olah/proxy/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# https://opensource.org/licenses/MIT.

import os
from typing import Dict, Literal, Mapping, Optional
from typing import Dict, Literal, Mapping, Optional, AsyncGenerator, Union
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand All @@ -19,7 +19,7 @@
from olah.utils.file_utils import make_dirs


async def _tree_cache_generator(save_path: str):
async def _tree_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
cache_rq = await read_cache_request(save_path)
yield cache_rq["status_code"]
yield cache_rq["headers"]
Expand All @@ -33,7 +33,7 @@ async def _tree_proxy_generator(
params: Mapping[str, str],
allow_cache: bool,
save_path: str,
):
) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
async with httpx.AsyncClient(follow_redirects=True) as client:
content_chunks = []
async with client.stream(
Expand Down Expand Up @@ -77,7 +77,7 @@ async def tree_generator(
override_cache: bool,
method: str,
authorization: Optional[str],
):
) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]:
headers = {}
if authorization is not None:
headers["authorization"] = authorization
Expand Down
9 changes: 8 additions & 1 deletion olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from olah.proxy.tree import tree_generator
from olah.utils.disk_utils import convert_bytes_to_human_readable, convert_to_bytes, get_folder_size, sort_files_by_access_time, sort_files_by_modify_time, sort_files_by_size
from olah.utils.url_utils import clean_path
from olah.utils.zip_utils import decompress_data

BASE_SETTINGS = False
if not BASE_SETTINGS:
Expand Down Expand Up @@ -709,10 +710,16 @@ async def whoami_v2(request: Request):
headers=new_headers,
timeout=10,
)
# final_content = decompress_data(response.headers.get("content-encoding", None))
response_headers = {k.lower(): v for k, v in response.headers.items()}
if "content-encoding" in response_headers:
response_headers.pop("content-encoding")
if "content-length" in response_headers:
response_headers.pop("content-length")
return Response(
content=response.content,
status_code=response.status_code,
headers=response.headers,
headers=response_headers,
)


Expand Down
7 changes: 5 additions & 2 deletions olah/utils/repo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,15 @@ async def get_newest_commit_hf(
return await get_newest_commit_hf_offline(app, repo_type, org, repo)
try:
async with httpx.AsyncClient() as client:
response = await client.get(url, headers={"authorization": authorization}, timeout=WORKER_API_TIMEOUT)
headers = {}
if authorization is not None:
headers["authorization"] = authorization
response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT)
if response.status_code != 200:
return await get_newest_commit_hf_offline(app, repo_type, org, repo)
obj = json.loads(response.text)
return obj.get("sha", None)
except:
except httpx.TimeoutException as e:
return await get_newest_commit_hf_offline(app, repo_type, org, repo)


Expand Down
Loading

0 comments on commit 4be5ed1

Please sign in to comment.