Skip to content

Commit

Permalink
check hf timeout
Browse files (browse the repository at this point in the history)
  • Loading branch information
jstzwj committed Jul 18, 2024
1 parent c382c33 commit 7933abe
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 46 deletions.
70 changes: 26 additions & 44 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hashlib
import json
import os
from typing import Dict, Literal, Optional
from typing import Dict, Literal, Optional, Tuple
from fastapi import Request

from requests.structures import CaseInsensitiveDict
Expand Down Expand Up @@ -37,7 +37,7 @@ async def _file_full_header(
url: str,
headers: Dict[str, str],
allow_cache: bool,
):
) -> Tuple[int, Dict[str, str], bytes]:
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
Expand All @@ -53,9 +53,21 @@ async def _file_full_header(
timeout=WORKER_API_TIMEOUT,
)
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head" and response.status_code == 200:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
if allow_cache and method.lower() == "head":
if response.status_code == 200:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
elif response.status_code >= 300 and response.status_code <= 399:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
from_url = urlparse(url)
parsed_url = urlparse(response.headers["location"])
if len(parsed_url.netloc) != 0:
new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
response_headers_dict["location"] = new_loc
else:
raise Exception(f"Unexpected HTTP status code {response.status_code}")
return response.status_code, response_headers_dict, response.content
else:
response_headers_dict = {}

Expand All @@ -72,7 +84,9 @@ async def _file_full_header(
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
if "etag" in response_headers_dict:
new_headers["etag"] = response_headers_dict["etag"]
return new_headers
if "location" in response_headers_dict:
new_headers["location"] = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response_headers_dict["location"]))
return 200, new_headers, b""

async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int):
raw_block = cache_file.read_block(block_index)
Expand Down Expand Up @@ -208,47 +222,10 @@ async def _file_realtime_stream(
hf_url = urljoin(app.app_settings.config.hf_lfs_url_base(), get_url_tail(url))
else:
hf_url = url

# Handle Redirection
if not app.app_settings.config.offline:
async with httpx.AsyncClient() as client:
response = await client.request(
method="HEAD",
url=hf_url,
headers=request_headers,
timeout=WORKER_API_TIMEOUT,
)

if response.status_code >= 300 and response.status_code <= 399:
from_url = urlparse(url)
parsed_url = urlparse(response.headers["location"])
new_headers = {k.lower():v for k, v in response.headers.items()}
if len(parsed_url.netloc) != 0:
new_loc = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(response.headers["location"]))
new_headers["location"] = new_loc

if allow_cache:
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(new_headers, ensure_ascii=False))

yield response.status_code
yield new_headers
yield response.content
return
else:
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
head_content = json.loads(f.read())

if "location" in head_content:
yield 302
yield head_content
yield b""
return

async with httpx.AsyncClient() as client:
# redirect_loc = await _get_redirected_url(client, method, url, request_headers)
head_info = await _file_full_header(
status_code, head_info, content = await _file_full_header(
app=app,
save_path=save_path,
head_path=head_path,
Expand All @@ -258,6 +235,11 @@ async def _file_realtime_stream(
headers=request_headers,
allow_cache=allow_cache,
)
if status_code != 200:
yield status_code
yield head_info
yield content
return
file_size = int(head_info["content-length"])
response_headers = {k: v for k,v in head_info.items()}
if "range" in request_headers:
Expand Down
51 changes: 49 additions & 2 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from contextlib import asynccontextmanager
import os
import argparse
import traceback
from typing import Annotated, Optional, Union
from urllib.parse import urljoin
from fastapi import FastAPI, Header, Request
from fastapi.responses import HTMLResponse, StreamingResponse, Response
from fastapi_utils.tasks import repeat_every
import httpx
from pydantic import BaseSettings
from olah.configs import OlahConfig
Expand All @@ -22,14 +24,59 @@

from olah.utils.logging import build_logger

app = FastAPI(debug=False)
# ======================
# Utilities
# ======================
async def check_connection(url: str) -> bool:
    """Return True when a HEAD request to *url* answers with HTTP 200.

    Args:
        url: Address to probe with a single HEAD request (10 second timeout).

    Returns:
        True if the response status code is exactly 200. False for any other
        status code, or when the request fails at the transport level
        (timeout, refused connection, DNS failure, ...).
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.request(
                method="HEAD",
                url=url,
                timeout=10,
            )
    except httpx.RequestError:
        # RequestError is the parent of TimeoutException as well as
        # ConnectError and the other transport failures. An unreachable host
        # must be reported as "not connected" instead of raising into the
        # caller, which expects a plain boolean.
        return False
    return response.status_code == 200


@repeat_every(seconds=60)
async def check_hf_connection() -> None:
    """Probe the Hugging Face site, then hf-mirror, and log what is unreachable.

    Does nothing when the application is configured as offline. Otherwise a
    known small file is probed on the official site; only if that fails is
    the same file probed on hf-mirror. Nothing is returned: the outcome is
    reported through the module logger only.

    The function is wrapped by ``repeat_every(seconds=60)``.
    """
    if app.app_settings.config.offline:
        return

    official_reachable = await check_connection(
        "https://huggingface.co/datasets/Salesforce/wikitext/resolve/main/.gitattributes"
    )
    if official_reachable:
        # Official site answered; the mirror is not consulted at all.
        return

    logger.info(
        "Cannot reach Huggingface Official Site. Trying to connect hf-mirror."
    )
    mirror_reachable = await check_connection(
        "https://hf-mirror.com/datasets/Salesforce/wikitext/resolve/main/.gitattributes"
    )
    if mirror_reachable:
        return

    # Both endpoints failed.
    logger.error("Failed to reach Huggingface Official Site.")
    logger.error("Failed to reach hf-mirror Site.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the FastAPI application.

    Awaits ``check_hf_connection()`` before yielding control; no teardown
    work is performed after the yield.
    """
    # Everything before the yield runs when the context is entered.
    await check_hf_connection()
    yield

# ======================
# Application
# ======================
# The lifespan handler awaits check_hf_connection() before the app starts serving.
app = FastAPI(lifespan=lifespan, debug=False)

class AppSettings(BaseSettings):
    """Settings object holding the Olah configuration and the repos path."""

    # Olah configuration; defaults to an OlahConfig built with no arguments.
    config: OlahConfig = OlahConfig()
    # Defaults to "./repos" -- presumably the directory where repositories
    # are stored locally (TODO confirm against the code that reads it).
    repos_path: str = "./repos"


# ======================
# API Hooks
# ======================
Expand Down

0 comments on commit 7933abe

Please sign in to comment.