Skip to content

Commit

Permalink
Support multiple proxy servers in Forwarded header parsing (#782)
Browse files Browse the repository at this point in the history
* Support multiple proxy servers in Forwarded header parsing

* Update CHANGELOG

* use regex and simplify

* use integer regex

* Use compiled regexes

---------

Co-authored-by: vincentsarago <[email protected]>
  • Loading branch information
lukasbindreiter and vincentsarago authored Jan 17, 2025
1 parent 48c4dea commit 2a72400
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
6 changes: 5 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

## [Unreleased]

## Changed
### Changed

* use `string` type instead of Python `datetime.datetime` for datetime parameter in `BaseSearchGetRequest`, `ItemCollectionUri` and `BaseCollectionSearchGetRequest` GET models
* rename `filter` to `filter_expr` for `FilterExtensionGetRequest` and `FilterExtensionPostRequest` attributes to avoid conflict with Python's built-in `filter` function

### Fixed

* Support multiple proxy servers in the `forwarded` header in `ProxyHeaderMiddleware` ([#782](https://github.com/stac-utils/stac-fastapi/pull/782))

## [3.0.5] - 2025-01-10

### Removed
Expand Down
41 changes: 20 additions & 21 deletions stac_fastapi/api/stac_fastapi/api/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Api middleware."""

import contextlib
import re
import typing
from http.client import HTTP_PORT, HTTPS_PORT
Expand Down Expand Up @@ -44,6 +45,10 @@ def __init__(
)


_PROTO_HEADER_REGEX = re.compile(r"proto=(?P<proto>http(s)?)")
_HOST_HEADER_REGEX = re.compile(r"host=(?P<host>[\w.-]+)(:(?P<port>\d{1,5}))?")


class ProxyHeaderMiddleware:
"""Account for forwarding headers when deriving base URL.
Expand All @@ -68,11 +73,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
proto == "https" and port != HTTPS_PORT
):
port_suffix = f":{port}"

scope["headers"] = self._replace_header_value_by_name(
scope,
"host",
f"{domain}{port_suffix}",
)

await self.app(scope, receive, send)

def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
Expand All @@ -87,31 +94,23 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
else:
domain = header_host_parts[0]
port = None
forwarded = self._get_header_value_by_name(scope, "forwarded")
if forwarded is not None:
parts = forwarded.split(";")
for part in parts:
if len(part) > 0 and re.search("=", part):
key, value = part.split("=")
if key == "proto":
proto = value
elif key == "host":
host_parts = value.split(":")
domain = host_parts[0]
try:
port = int(host_parts[1]) if len(host_parts) == 2 else None
except ValueError:
# ignore ports that are not valid integers
pass

if forwarded := self._get_header_value_by_name(scope, "forwarded"):
for proxy in forwarded.split(","):
if (proto_expr := _PROTO_HEADER_REGEX.search(proxy)) and (
host_expr := _HOST_HEADER_REGEX.search(proxy)
):
proto = proto_expr.group("proto")
domain = host_expr.group("host")
port_str = host_expr.group("port") # None if not present in the match

else:
domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain)
proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
try:
port = int(port_str) if port_str is not None else None
except ValueError:
# ignore ports that are not valid integers
pass

with contextlib.suppress(ValueError): # ignore ports that are not valid integers
port = int(port_str) if port_str is not None else port

return (proto, domain, port)

Expand Down
28 changes: 28 additions & 0 deletions stac_fastapi/api/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,34 @@ def test_replace_header_value_by_name(
},
("https", "test", 1234),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [
(
b"forwarded",
# two proxy servers added an entry, we want to use the last one
b"proto=https;host=test:1234,proto=https;host=second-server:1111",
)
],
},
("https", "second-server", 1111),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [
(
b"forwarded",
# check when host and proto are given in inverted order
b"host=test:1234;proto=https",
)
],
},
("https", "test", 1234),
),
],
)
def test_get_forwarded_url_parts(
Expand Down

0 comments on commit 2a72400

Please sign in to comment.