Skip to content

Commit

Permalink
header auth field is None - bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 6, 2024
1 parent 15a8466 commit e9777fa
Showing 1 changed file with 43 additions and 141 deletions.
184 changes: 43 additions & 141 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,125 +85,6 @@ def get_contiguous_ranges(
range_start_pos = end_pos
return ranges_and_cache_list

async def _file_full_header(
app,
save_path: str,
head_path: str,
client: httpx.AsyncClient,
method: str,
url: str,
headers: Dict[str, str],
allow_cache: bool,
) -> Tuple[int, Dict[str, str], bytes]:
assert method.lower() == "head"
if not app.app_settings.config.offline:
if os.path.exists(head_path):
cache_rq = await read_cache_request(head_path)
response_headers_dict = {
k.lower(): v for k, v in cache_rq["headers"].items()
}
if "location" in response_headers_dict:
parsed_url = urlparse(response_headers_dict["location"])
if len(parsed_url.netloc) != 0:
new_loc = urljoin(
app.app_settings.config.mirror_lfs_url_base(),
get_url_tail(response_headers_dict["location"]),
)
response_headers_dict["location"] = new_loc
return cache_rq["status_code"], response_headers_dict, cache_rq["content"]
else:
if "range" in headers:
headers.pop("range")
response = await client.request(
method=method,
url=url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
)
response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
if allow_cache and method.lower() == "head":
if response.status_code == 200:
await write_cache_request(
head_path,
response.status_code,
response_headers_dict,
response.content,
)
elif response.status_code >= 300 and response.status_code <= 399:
from_url = urlparse(url)
parsed_url = urlparse(response.headers["location"])
if len(parsed_url.netloc) != 0:
new_loc = urljoin(
app.app_settings.config.mirror_lfs_url_base(),
get_url_tail(response.headers["location"]),
)
response_headers_dict["location"] = new_loc
# Redirect, add original location info
if check_url_has_param_name(
response_headers_dict["location"], ORIGINAL_LOC
):
raise Exception(f"Invalid field {ORIGINAL_LOC} in the url.")
else:
response_headers_dict["location"] = add_query_param(
response_headers_dict["location"],
ORIGINAL_LOC,
response.headers["location"],
)
await write_cache_request(
head_path,
response.status_code,
response_headers_dict,
response.content,
)
elif response.status_code == 403:
pass
elif response.status_code == 404:
pass
else:
raise Exception(
f"Unexpected HTTP status code {response.status_code}"
)
return response.status_code, response_headers_dict, response.content
else:
if os.path.exists(head_path):
cache_rq = await read_cache_request(head_path)
response_headers_dict = {
k.lower(): v for k, v in cache_rq["headers"].items()
}
else:
response_headers_dict = {}
cache_rq = {
"status_code": 200,
"headers": response_headers_dict,
"content": b"",
}

new_headers = {}
if "content-type" in response_headers_dict:
new_headers["content-type"] = response_headers_dict["content-type"]
if "content-length" in response_headers_dict:
new_headers["content-length"] = response_headers_dict["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = (
response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
)
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = (
response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
)
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = (
response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
)
if "etag" in response_headers_dict:
new_headers["etag"] = response_headers_dict["etag"]
if "location" in response_headers_dict:
new_headers["location"] = urljoin(
app.app_settings.config.mirror_lfs_url_base(),
get_url_tail(response_headers_dict["location"]),
)
return cache_rq["status_code"], new_headers, cache_rq["content"]


async def _get_file_range_from_cache(
cache_file: OlahCache, start_pos: int, end_pos: int
Expand Down Expand Up @@ -238,7 +119,8 @@ async def _get_file_range_from_remote(
end_pos: int,
):
headers = {}
headers["authorization"] = remote_info.headers.get("authorization", None)
if remote_info.headers.get("authorization", None) is not None:
headers["authorization"] = remote_info.headers.get("authorization", None)
headers["range"] = f"bytes={start_pos}-{end_pos - 1}"

chunk_bytes = 0
Expand Down Expand Up @@ -433,6 +315,32 @@ async def _file_chunk_head(
yield b""


async def _resource_etag(hf_url: str, authorization: Optional[str]=None, offline: bool = False) -> Optional[str]:
ret_etag = None
sha256_hash = hashlib.sha256()
sha256_hash.update(hf_url.encode("utf-8"))
content_hash = sha256_hash.hexdigest()
if offline:
ret_etag = f'"{content_hash[:32]}-10"'
else:
etag_headers = {}
if authorization is not None:
etag_headers["authorization"] = authorization
try:
async with httpx.AsyncClient() as client:
response = await client.request(
method="head",
url=hf_url,
headers=etag_headers,
timeout=WORKER_API_TIMEOUT,
)
if "etag" in response.headers:
ret_etag = response.headers["etag"]
else:
ret_etag = f'"{content_hash[:32]}-10"'
except httpx.TimeoutException:
ret_etag = None
return ret_etag
async def _file_realtime_stream(
app,
repo_type: Literal["models", "datasets", "spaces"],
Expand Down Expand Up @@ -531,28 +439,22 @@ async def _file_realtime_stream(
if commit is not None:
response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
# Create fake headers when offline mode
sha256_hash = hashlib.sha256()
sha256_hash.update(hf_url.encode("utf-8"))
content_hash = sha256_hash.hexdigest()
if app.app_settings.config.offline:
response_headers["etag"] = f'"{content_hash[:32]}-10"'
etag = await _resource_etag(
hf_url=hf_url,
authorization=request.headers.get("authorization", None),
offline=app.app_settings.config.offline,
)
response_headers["etag"] = etag

if etag is None:
error_response = error_proxy_timeout()
yield error_response.status_code
yield error_response.headers
yield error_response.body
return
else:
if method.lower() == "head":
async with httpx.AsyncClient() as client:
response = await client.request(
method="head",
url=hf_url,
headers={
"authorization": request.headers.get("authorization", None)
},
timeout=WORKER_API_TIMEOUT,
)
if "etag" in response.headers:
response_headers["etag"] = response.headers["etag"]
else:
response_headers["etag"] = f'"{content_hash[:32]}-10"'
yield 200
yield response_headers
yield 200
yield response_headers

async with httpx.AsyncClient() as client:
if method.lower() == "get":
Expand Down

0 comments on commit e9777fa

Please sign in to comment.