diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index ec6b2c01..0b119231 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,5 +1,6 @@ """Api middleware.""" +import contextlib import re import typing from http.client import HTTP_PORT, HTTPS_PORT @@ -44,6 +45,10 @@ def __init__( ) +_PROTO_HEADER_REGEX = re.compile(r"proto=(?Phttp(s)?)") +_HOST_HEADER_REGEX = re.compile(r"host=(?P[\w.-]+)(:(?P\d{1,5}))?") + + class ProxyHeaderMiddleware: """Account for forwarding headers when deriving base URL. @@ -92,25 +97,20 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: if forwarded := self._get_header_value_by_name(scope, "forwarded"): for proxy in forwarded.split(","): - if (proto_expr := re.search(r"proto=(?Phttp(s)?)", proxy)) and ( - host_expr := re.search( - r"host=(?P[\w.-]+)(:(?P\d{1,5}))?", proxy - ) + if (proto_expr := _PROTO_HEADER_REGEX.search(proxy)) and ( + host_expr := _HOST_HEADER_REGEX.search(proxy) ): - proto = proto_expr.groupdict()["proto"] - domain = host_expr.groupdict()["host"] - port_str = host_expr.groupdict().get("port", None) + proto = proto_expr.group("proto") + domain = host_expr.group("host") + port_str = host_expr.group("port") # None if not present in the match else: domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain) proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto) port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port) - try: + with contextlib.suppress(ValueError): # ignore ports that are not valid integers port = int(port_str) if port_str is not None else port - except ValueError: - # ignore ports that are not valid integers - pass return (proto, domain, port)