From c382c33eb690bcca21b232dbef8d7ff615e725e4 Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Fri, 19 Jul 2024 04:01:50 +0800
Subject: [PATCH] offline mode bug fix

---
 README.md                |   8 +-
 README_zh.md             |  16 ++-
 olah/files.py            |  38 ++++--
 olah/server.py           | 269 ++++++++++++++++++++++++---------------
 olah/utils/logging.py    |  17 ++-
 olah/utils/olah_cache.py |  61 ++++++---
 olah/utils/url_utils.py  |  11 +-
 pyproject.toml           |   4 +-
 requirements.txt         |   4 +-
 9 files changed, 286 insertions(+), 142 deletions(-)

diff --git a/README.md b/README.md
index 3322abe..556705a 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,9 @@
-# olah
-Olah is self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
+Olah
+
+Self-hosted Lightweight Huggingface Mirror Service
+
+Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
 Olah implemented the `mirroring` feature for huggingface resources, rather than just a simple `reverse proxy`.
 Olah does not immediately mirror the entire huggingface website but mirrors the resources at the file block level when users download them (or we can say cache them).
diff --git a/README_zh.md b/README_zh.md
index 552033e..95e06bd 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,4 +1,9 @@
-# olah
+Olah
+
+自托管的轻量级HuggingFace镜像服务
+
 Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语，在丘丘人语中意味着`你好`。
 Olah真正地实现了huggingface资源的`镜像`功能，而不仅仅是一个简单的`反向代理`。
 Olah并不会立刻对huggingface全站进行镜像，而是在用户下载的同时在文件块级别对资源进行镜像（或者我们可以说是缓存）。
@@ -106,6 +111,15 @@ python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
 
 **注意，不同版本之间的缓存数据不能迁移，请删除缓存文件夹后再进行olah的升级**
 
+## 更多配置
+
+更多配置可以通过配置文件进行控制，通过命令参数传入`configs.toml`以设置配置文件路径：
+```bash
+python -m olah.server -c configs.toml
+```
+
+完整的配置文件内容见[assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml)
+
 ## 许可证
 
 olah采用MIT许可证发布。
diff --git a/olah/files.py b/olah/files.py
index b85c486..63ddb28 100644
--- a/olah/files.py
+++ b/olah/files.py
@@ -176,16 +176,19 @@ async def _file_chunk_head(
     allow_cache: bool,
     file_size: int,
 ):
-    async with client.stream(
-        method=method,
-        url=url,
-        headers=headers,
-        timeout=WORKER_API_TIMEOUT,
-    ) as response:
-        async for raw_chunk in response.aiter_raw():
-            if not raw_chunk:
-                continue
-            yield raw_chunk
+    if not app.app_settings.config.offline:
+        async with client.stream(
+            method=method,
+            url=url,
+            headers=headers,
+            timeout=WORKER_API_TIMEOUT,
+        ) as response:
+            async for raw_chunk in response.aiter_raw():
+                if not raw_chunk:
+                    continue
+                yield raw_chunk
+    else:
+        yield b""
 
 
 async def _file_realtime_stream(
@@ -206,6 +209,7 @@ async def _file_realtime_stream(
     else:
         hf_url = url
 
+    # Handle Redirection
    if not app.app_settings.config.offline:
        async with httpx.AsyncClient() as client:
            response = await client.request(
@@ -222,11 +226,25 @@
                if len(parsed_url.netloc) != 0:
                    new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
                    new_headers["location"] = new_loc
+
+                if allow_cache:
+                    with open(head_path, "w", encoding="utf-8") as f:
+                        f.write(json.dumps(new_headers, ensure_ascii=False))
                yield response.status_code
                yield new_headers
                yield response.content
                return
+    else:
+        if os.path.exists(head_path):
+            with open(head_path, "r", encoding="utf-8") as f:
+                head_content = json.loads(f.read())
+
+            if "location" in head_content:
+                yield 302
+                yield head_content
+                yield b""
+                return
 
    async with httpx.AsyncClient() as client:
        # redirect_loc = await _get_redirected_url(client, method, url, request_headers)
diff --git a/olah/server.py b/olah/server.py
index 813b1a9..929b0ae 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -1,12 +1,13 @@
 # coding=utf-8
 # Copyright 2024 XiaHan
-# 
+#
 # Use of this source code is governed by an MIT-style
 # license that can be found in the LICENSE file or at
 # https://opensource.org/licenses/MIT.
 
 import os
 import argparse
+import traceback
 from typing import Annotated, Optional, Union
 from urllib.parse import urljoin
 from fastapi import FastAPI, Header, Request
@@ -28,35 +29,27 @@ class AppSettings(BaseSettings):
     config: OlahConfig = OlahConfig()
     repos_path: str = "./repos"
 
+
 # ======================
 # API Hooks
 # ======================
-@app.get("/api/{repo_type}/{org_repo}")
-async def meta_proxy(repo_type: str, org_repo: str, request: Request):
-    org, repo = parse_org_repo(org_repo)
-    if org is None and repo is None:
-        return Response(content="This repository is not accessible.", status_code=404)
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror.", status_code=403)
-
-    try:
-        if not await check_commit_hf(app, repo_type, org, repo, None):
-            return Response(content="This repository is not accessible.", status_code=404)
-        new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
-        generator = meta_generator(app, repo_type, org, repo, new_commit, request)
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
-
-@app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}")
-async def meta_proxy_commit2(repo_type: str, org: str, repo: str, commit: str, request: Request):
+async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response:
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
+        return Response(
+            content="This repository is forbidden by the mirror. ", status_code=403
+        )
     try:
-        if not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
+        if not app.app_settings.config.offline and not await check_commit_hf(
+            app, repo_type, org, repo, commit
+        ):
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
         commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
+        if commit_sha is None:
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
 
         # if branch name and online mode, refresh branch info
         if commit_sha != commit and not app.app_settings.config.offline:
@@ -66,73 +59,122 @@ async def meta_proxy_commit2(repo_type: str, org: str, repo: str, commit: str, r
         headers = await generator.__anext__()
         return StreamingResponse(generator, headers=headers)
     except httpx.ConnectTimeout:
+        traceback.print_exc()
         return Response(status_code=504)
 
-@app.get("/api/{repo_type}/{org_repo}/revision/{commit}")
-async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
+
+@app.get("/api/{repo_type}/{org_repo}")
+async def meta_proxy(repo_type: str, org_repo: str, request: Request):
     org, repo = parse_org_repo(org_repo)
     if org is None and repo is None:
         return Response(content="This repository is not accessible.", status_code=404)
+    if not app.app_settings.config.offline:
+        new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
+    else:
+        new_commit = "main"
+    return await meta_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request
+    )
 
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
-    try:
-        if not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
-        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        # if branch name and online mode, refresh branch info
-        if commit_sha != commit and not app.app_settings.config.offline:
-            await meta_proxy_cache(app, repo_type, org, repo, commit, request)
+@app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}")
+async def meta_proxy_commit2(
+    repo_type: str, org: str, repo: str, commit: str, request: Request
+):
+    return await meta_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
+    )
 
-        generator = meta_generator(app, repo_type, org, repo, commit_sha, request)
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
+
+@app.get("/api/{repo_type}/{org_repo}/revision/{commit}")
+async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
+    org, repo = parse_org_repo(org_repo)
+    if org is None and repo is None:
+        return Response(content="This repository is not accessible.", status_code=404)
+
+    return await meta_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
+    )
 
 
 # ======================
 # File Head Hooks
 # ======================
-@app.head("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
-async def file_head3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
+async def file_head_common(
+    repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request
+) -> Response:
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
+        return Response(
+            content="This repository is forbidden by the mirror. ", status_code=403
+        )
     try:
-        if not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
+        if not app.app_settings.config.offline and not await check_commit_hf(
+            app, repo_type, org, repo, commit
+        ):
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
         commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
+        if commit_sha is None:
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
+        generator = await file_get_generator(
+            app,
+            repo_type,
+            org,
+            repo,
+            commit_sha,
+            file_path=file_path,
+            method="HEAD",
+            request=request,
+        )
         status_code = await generator.__anext__()
         headers = await generator.__anext__()
         return StreamingResponse(generator, headers=headers, status_code=status_code)
     except httpx.ConnectTimeout:
-        return Response(status_code=504)
+        traceback.print_exc()
+        return Response(status_code=504)
+
+
+@app.head("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
+async def file_head3(
+    repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request
+):
+    return await file_head_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
+
 
 @app.head("/{org_or_repo_type}/{repo_name}/resolve/{commit}/{file_path:path}")
-async def file_head2(org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request):
+async def file_head2(
+    org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request
+):
     if org_or_repo_type in ["models", "datasets", "spaces"]:
         repo_type: str = org_or_repo_type
         org, repo = parse_org_repo(repo_name)
         if org is None and repo is None:
-            return Response(content="This repository is not accessible.", status_code=404)
+            return Response(
+                content="This repository is not accessible.", status_code=404
+            )
     else:
         repo_type: str = "models"
         org, repo = org_or_repo_type, repo_name
 
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
-    try:
-        if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
-        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
-        status_code = await generator.__anext__()
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers, status_code=status_code)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
+    return await file_head_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
+
 
 @app.head("/{org_repo}/resolve/{commit}/{file_path:path}")
 async def file_head(org_repo: str, commit: str, file_path: str, request: Request):
@@ -140,20 +182,15 @@ async def file_head(org_repo: str, commit: str, file_path: str, request: Request
     org, repo = parse_org_repo(org_repo)
     if org is None and repo is None:
         return Response(content="This repository is not accessible.", status_code=404)
+    return await file_head_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
 
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
-    try:
-        if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
-
-        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="HEAD", request=request)
-        status_code = await generator.__anext__()
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers, status_code=status_code)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
 
 
 @app.head("/{org_repo}/{hash_file}")
 @app.head("/{repo_type}/{org_repo}/{hash_file}")
@@ -173,24 +210,56 @@ async def cdn_file_head(org_repo: str, hash_file: str, request: Request, repo_ty
     except httpx.ConnectTimeout:
         return Response(status_code=504)
 
+
 # ======================
 # File Hooks
 # ======================
-@app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
-async def file_get3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
+async def file_get_common(
+    repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request
+) -> Response:
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
+        return Response(
+            content="This repository is forbidden by the mirror. ", status_code=403
+        )
     try:
-        if not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
+        if not app.app_settings.config.offline and not await check_commit_hf(app, repo_type, org, repo, commit):
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
         commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request)
+        if commit_sha is None:
+            return Response(
+                content="This repository is not accessible. ", status_code=404
+            )
+        generator = await file_get_generator(
+            app,
+            repo_type,
+            org,
+            repo,
+            commit_sha,
+            file_path=file_path,
+            method="GET",
+            request=request,
+        )
         status_code = await generator.__anext__()
         headers = await generator.__anext__()
         return StreamingResponse(generator, headers=headers, status_code=status_code)
     except httpx.ConnectTimeout:
+        traceback.print_exc()
         return Response(status_code=504)
 
+
+@app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}")
+async def file_get3(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str):
+    return await file_get_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
+
 @app.get("/{org_or_repo_type}/{repo_name}/resolve/{commit}/{file_path:path}")
 async def file_get2(org_or_repo_type: str, repo_name: str, commit: str, file_path: str, request: Request):
     if org_or_repo_type in ["models", "datasets", "spaces"]:
@@ -202,18 +271,14 @@ async def file_get2(org_or_repo_type: str, repo_name: str, commit: str, file_pat
         repo_type: str = "models"
         org, repo = org_or_repo_type, repo_name
 
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
-    try:
-        if org is not None and not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
-        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request)
-        status_code = await generator.__anext__()
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers, status_code=status_code)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
+    return await file_get_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
 
 @app.get("/{org_repo}/resolve/{commit}/{file_path:path}")
 async def file_get(org_repo: str, commit: str, file_path: str, request: Request):
@@ -222,18 +287,14 @@ async def file_get(org_repo: str, commit: str, file_path: str, request: Request)
     if org is None and repo is None:
         return Response(content="This repository is not accessible.", status_code=404)
 
-    if not await check_proxy_rules_hf(app, repo_type, org, repo):
-        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
-    try:
-        if not await check_commit_hf(app, repo_type, org, repo, commit):
-            return Response(content="This repository is not accessible. ", status_code=404)
-        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit)
-        generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path=file_path, method="GET", request=request)
-        status_code = await generator.__anext__()
-        headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers, status_code=status_code)
-    except httpx.ConnectTimeout:
-        return Response(status_code=504)
+    return await file_get_common(
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        file_path=file_path,
+        request=request,
+    )
 
 @app.get("/{org_repo}/{hash_file}")
 @app.get("/{repo_type}/{org_repo}/{hash_file}")
diff --git a/olah/utils/logging.py b/olah/utils/logging.py
index d1904ed..8742b0e 100644
--- a/olah/utils/logging.py
+++ b/olah/utils/logging.py
@@ -11,6 +11,7 @@
 import logging.handlers
 import os
 import platform
+import re
 import sys
 from typing import AsyncGenerator, Generator
 import warnings
@@ -29,6 +30,15 @@
 
 handler = None
 
+# Define a custom formatter without color codes
+class NoColorFormatter(logging.Formatter):
+    color_pattern = re.compile(r'\x1b[^m]*m') # Regex pattern to match color codes
+
+    def format(self, record):
+        message = super().format(record)
+        # Remove color codes from the log message
+        message = self.color_pattern.sub('', message)
+        return message
 
 def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) -> logging.Logger:
     global handler
@@ -38,6 +48,11 @@ def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) ->
         datefmt="%Y-%m-%d %H:%M:%S",
     )
 
+    nocolor_formatter = NoColorFormatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
     # Set the format of root handlers
     if logging.getLogger().handlers is None or len(logging.getLogger().handlers) == 0:
         if sys.version_info[1] >= 9:
@@ -74,7 +89,7 @@ def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) ->
         handler = logging.handlers.TimedRotatingFileHandler(
             filename, when="H", utc=True, encoding="utf-8"
         )
-        handler.setFormatter(formatter)
+        handler.setFormatter(nocolor_formatter)
         handler.namer = lambda name: name.replace(".log", "") + ".log"
 
         for name, item in logging.root.manager.loggerDict.items():
diff --git a/olah/utils/olah_cache.py b/olah/utils/olah_cache.py
index 0924506..39e0d3b 100644
--- a/olah/utils/olah_cache.py
+++ b/olah/utils/olah_cache.py
@@ -8,7 +8,7 @@
 import os
 import struct
 import threading
-from typing import BinaryIO, Optional
+from typing import BinaryIO, Dict, Optional
 from .bitset import Bitset
 
 CURRENT_OLAH_CACHE_VERSION = 8
@@ -109,8 +109,13 @@ def __init__(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE) -> None:
         self.path: Optional[str] = path
         self.header: Optional[OlahCacheHeader] = None
         self.is_open: bool = False
+
+        # Lock
+        self._header_lock = threading.Lock()
 
-        self.header_lock = threading.Lock()
+        # Cache
+        self._blocks_read_cache: Dict[int, bytes] = {}
+        self._prefech_blocks: int = 16
 
         self.open(path, block_size=block_size)
 
@@ -123,12 +128,12 @@ def open(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE):
         if self.is_open:
             raise Exception("This file has been open.")
         if os.path.exists(path):
-            with self.header_lock:
+            with self._header_lock:
                 with open(path, "rb") as f:
                     f.seek(0)
                     self.header = OlahCacheHeader.read(f)
         else:
-            with self.header_lock:
+            with self._header_lock:
                 # Create new file
                 with open(path, "wb") as f:
                     f.seek(0)
@@ -149,48 +154,57 @@ def close(self):
 
         self.path = None
         self.header = None
+        self._blocks_read_cache.clear()
+
         self.is_open = False
 
     def _flush_header(self):
-        with self.header_lock:
+        with self._header_lock:
             with open(self.path, "rb+") as f:
                 f.seek(0)
                 self.header.write(f)
 
     def _get_file_size(self) -> int:
-        with self.header_lock:
+        with self._header_lock:
             file_size = self.header.file_size
         return file_size
 
     def _get_block_number(self) -> int:
-        with self.header_lock:
+        with self._header_lock:
             block_number = self.header.block_number
         return block_number
 
     def _get_block_size(self) -> int:
-        with self.header_lock:
+        with self._header_lock:
             block_size = self.header.block_size
         return block_size
 
     def _get_header_size(self) -> int:
-        with self.header_lock:
+        with self._header_lock:
             header_size = self.header.get_header_size()
         return header_size
 
     def _resize_header(self, block_num: int, file_size: int):
-        with self.header_lock:
+        with self._header_lock:
             self.header._block_number = block_num
             self.header._file_size = file_size
             self.header._valid_header()
 
     def _set_header_block(self, block_index: int):
-        with self.header_lock:
+        with self._header_lock:
             self.header.block_mask.set(block_index)
 
     def _test_header_block(self, block_index: int):
-        with self.header_lock:
+        with self._header_lock:
             result = self.header.block_mask.test(block_index)
         return result
+
+    def _pad_block(self, raw_block: bytes):
+        if len(raw_block) < self._get_block_size():
+            block = raw_block + b"\x00" * (self._get_block_size() - len(raw_block))
+        else:
+            block = raw_block
+        return block
 
     def flush(self):
         if not self.is_open:
@@ -206,6 +220,10 @@ def read_block(self, block_index: int) -> Optional[bytes]:
 
         if block_index >= self._get_block_number():
             raise Exception("Invalid block index.")
+
+        # Check Cache
+        if block_index in self._blocks_read_cache:
+            return self._blocks_read_cache[block_index]
 
         if not self.has_block(block_index=block_index):
             return None
@@ -214,10 +232,17 @@ def read_block(self, block_index: int) -> Optional[bytes]:
         with open(self.path, "rb") as f:
             f.seek(offset)
             raw_block = f.read(self._get_block_size())
-        if len(raw_block) < self._get_block_size():
-            block = raw_block + b"\x00" * (self._get_block_size() - len(raw_block))
-        else:
-            block = raw_block
+            # Prefetch blocks
+            for block_offset in range(1, self._prefech_blocks + 1):
+                if block_index + block_offset >= self._get_block_number():
+                    break
+                if not self.has_block(block_index=block_index):
+                    self._blocks_read_cache[block_index + block_offset] = None
+                else:
+                    prefetch_raw_block = f.read(self._get_block_size())
+                    self._blocks_read_cache[block_index + block_offset] = self._pad_block(prefetch_raw_block)
+
+        block = self._pad_block(raw_block)
         return block
 
     def write_block(self, block_index: int, block_bytes: bytes) -> None:
@@ -241,6 +266,10 @@ def write_block(self, block_index: int, block_bytes: bytes) -> None:
 
         self._set_header_block(block_index)
         self._flush_header()
+
+        # Clear Cache
+        if block_index in self._blocks_read_cache:
+            del self._blocks_read_cache[block_index]
 
     def _resize_file_size(self, file_size: int):
         if not self.is_open:
diff --git a/olah/utils/url_utils.py b/olah/utils/url_utils.py
index 2227122..85d8452 100644
--- a/olah/utils/url_utils.py
+++ b/olah/utils/url_utils.py
@@ -73,11 +73,12 @@ async def get_newest_commit_hf(app, repo_type: Optional[Literal["models", "datas
 async def get_commit_hf_offline(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str:
     repos_path = app.app_settings.repos_path
     save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit)
-
-    with open(save_path, "r", encoding="utf-8") as f:
-        obj = json.loads(f.read())
-
-    return obj["sha"]
+    if os.path.exists(save_path):
+        with open(save_path, "r", encoding="utf-8") as f:
+            obj = json.loads(f.read())
+        return obj["sha"]
+    else:
+        return None
 
 async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str:
     org_repo = get_org_repo(org, repo)
diff --git a/pyproject.toml b/pyproject.toml
index 2f3e246..0f8342b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,8 +13,8 @@ classifiers = [
     "License :: OSI Approved :: MIT License",
 ]
 dependencies = [
-    "fastapi", "httpx", "numpy", "pydantic<=2.8.2", "requests", "toml",
-    "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz"
+    "fastapi", "fastapi-utils", "httpx", "numpy", "pydantic<=2.8.2", "requests", "toml",
+    "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz", "cachetools"
 ]
 
 [project.optional-dependencies]
diff --git a/requirements.txt b/requirements.txt
index 177b67a..dc90e65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
 fastapi==0.111.0
+fastapi-utils==0.7.0
 httpx==0.27.0
 pydantic==2.8.2
 toml==0.10.2
 huggingface_hub==0.23.4
-pytest==8.2.2
\ No newline at end of file
+pytest==8.2.2
+cachetools==5.4.0
\ No newline at end of file
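
The offline-mode fix in `olah/files.py` works by persisting the rewritten redirect headers next to the cached file while the mirror is online, and replaying them as a `302` when it is offline. The sketch below illustrates that save/replay pattern in isolation; the helper names (`save_headers`, `replay_redirect`) and the standalone `demo_head.json` file are illustrative assumptions, not olah's actual API.

```python
import json
import os
from typing import Dict, Optional, Tuple


def save_headers(head_path: str, headers: Dict[str, str]) -> None:
    """Persist (rewritten) response headers so they can be replayed offline."""
    with open(head_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(headers, ensure_ascii=False))


def replay_redirect(head_path: str) -> Optional[Tuple[int, Dict[str, str], bytes]]:
    """Return (status, headers, body) from the cached head file, or None if unusable."""
    if not os.path.exists(head_path):
        return None
    with open(head_path, "r", encoding="utf-8") as f:
        headers = json.loads(f.read())
    if "location" in headers:
        return 302, headers, b""
    return None


if __name__ == "__main__":
    save_headers("demo_head.json", {"location": "http://localhost:8090/lfs/demo"})
    print(replay_redirect("demo_head.json"))  # (302, {'location': ...}, b'')
    os.remove("demo_head.json")
```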
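
The `NoColorFormatter` added in `olah/utils/logging.py` strips ANSI color escape sequences before log records reach the rotating file handler, so colored console output does not leave escape codes in the log files. A minimal, self-contained demonstration of the same idea, using an in-memory stream instead of a `TimedRotatingFileHandler`:

```python
import io
import logging
import re


class NoColorFormatter(logging.Formatter):
    color_pattern = re.compile(r"\x1b[^m]*m")  # matches ANSI escape sequences

    def format(self, record: logging.LogRecord) -> str:
        message = super().format(record)
        # Strip color codes from the rendered message before it is written.
        return self.color_pattern.sub("", message)


buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.setFormatter(NoColorFormatter("%(asctime)s | %(levelname)s | %(message)s"))

logger = logging.getLogger("nocolor-demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("\x1b[31mred text\x1b[0m stays readable in the log file")

print(buffer.getvalue())  # the \x1b[...]m sequences are gone
```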
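
The `OlahCache` changes add a read-side block cache: `read_block` first checks an in-memory dict, reads and zero-pads the block from disk on a miss, prefetches a few following blocks in the same pass, and `write_block` invalidates the cached entry. The class below is a rough sketch of that caching scheme over a plain binary file, not a reimplementation of `OlahCache` (which also tracks a header, a block bitmask, and a lock).

```python
import os
from typing import Dict, Optional


class BlockReader:
    """Illustrative block reader with an in-memory read cache and sequential prefetch."""

    def __init__(self, path: str, block_size: int = 8, prefetch_blocks: int = 4) -> None:
        self.path = path
        self.block_size = block_size
        self.prefetch_blocks = prefetch_blocks
        self._read_cache: Dict[int, bytes] = {}

    def _pad(self, raw: bytes) -> bytes:
        # Short trailing blocks are zero-padded to the fixed block size.
        if len(raw) < self.block_size:
            return raw + b"\x00" * (self.block_size - len(raw))
        return raw

    def read_block(self, index: int) -> Optional[bytes]:
        if index in self._read_cache:  # served from memory on repeated reads
            return self._read_cache[index]
        with open(self.path, "rb") as f:
            f.seek(index * self.block_size)
            block = self._pad(f.read(self.block_size))
            self._read_cache[index] = block
            # Prefetch the next few blocks while the file is already open.
            for offset in range(1, self.prefetch_blocks + 1):
                raw = f.read(self.block_size)
                if not raw:
                    break
                self._read_cache[index + offset] = self._pad(raw)
        return block

    def invalidate(self, index: int) -> None:
        # Mirrors the write path: drop the stale entry when a block is rewritten.
        self._read_cache.pop(index, None)


if __name__ == "__main__":
    with open("demo.bin", "wb") as f:
        f.write(b"abcdefgh" * 4)
    reader = BlockReader("demo.bin")
    print(reader.read_block(0))  # disk read, prefetches blocks 1-3
    print(reader.read_block(1))  # answered from the in-memory cache
    os.remove("demo.bin")
```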