Skip to content

Commit

Permalink
[FIX] Disable redirect slashes and remove trailing slashes from routes (
Browse files Browse the repository at this point in the history
#109)

* update fastapi for global disabling of redirect_slashes
- see fastapi/fastapi#3432

* remove trailing slash from route definitions

* update trailing slashes for routes in tests

* minor fix - update context of test requiring auth to be enabled

* test handling of trailing slashes in routes

* trust proxy headers from all remote IPs

* test root with and without trailing slash
  • Loading branch information
alyssadai authored Aug 2, 2024
1 parent 57132d9 commit 188acc1
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ COPY ./app /usr/src/app
# NB_API_PORT, representing the port on which the API will be exposed,
# is an environment variable that will always have a default value of 8000 when building the image
# but can be overridden when running the container.
ENTRYPOINT uvicorn app.main:app --proxy-headers --host 0.0.0.0 --port ${NB_API_PORT:-8000}
ENTRYPOINT uvicorn app.main:app --proxy-headers --forwarded-allow-ips=* --host 0.0.0.0 --port ${NB_API_PORT:-8000}
2 changes: 1 addition & 1 deletion app/api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def get(
params["image_modal"] = image_modal

tasks = [
util.send_get_request(node_url + "query/", params)
util.send_get_request(node_url + "query", params)
for node_url in node_urls
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
Expand Down
2 changes: 1 addition & 1 deletion app/api/routers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
router = APIRouter(prefix="/nodes", tags=["nodes"])


@router.get("/")
@router.get("")
async def get_nodes():
"""Returns a dict of all available nodes apis where key is node URL and value is node name."""
return [
Expand Down
2 changes: 1 addition & 1 deletion app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# TODO: if our response model for fully successful vs. not fully successful responses grows more complex in the future,
# consider additionally using https://fastapi.tiangolo.com/advanced/additional-responses/#additional-response-with-model to document
# example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model).
@router.get("/", response_model=CombinedQueryResponse)
@router.get("", response_model=CombinedQueryResponse)
async def get_query(
response: Response,
query: QueryModel = Depends(QueryModel),
Expand Down
1 change: 1 addition & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def lifespan(app: FastAPI):
docs_url=None,
redoc_url=None,
lifespan=lifespan,
redirect_slashes=False,
)

app.add_middleware(
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ colorama==0.4.6
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.95.2
fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
Expand Down Expand Up @@ -44,10 +44,10 @@ rpds-py==0.13.2
rsa==4.9
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
starlette==0.37.2
toml==0.10.2
tomli==2.0.1
typing_extensions==4.4.0
typing_extensions==4.12.2
urllib3==2.2.0
uvicorn==0.20.0
virtualenv==20.16.7
26 changes: 25 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ def test_app():
yield client


@pytest.fixture
@pytest.fixture()
def enable_auth(monkeypatch):
"""Enable the authentication requirement for the API."""
monkeypatch.setattr("app.api.security.AUTH_ENABLED", True)


@pytest.fixture()
def disable_auth(monkeypatch):
"""
Disable the authentication requirement for the API to skip startup checks
Expand Down Expand Up @@ -62,3 +68,21 @@ async def _mock_httpx_get_with_connect_error(self, **kwargs):
raise httpx.ConnectError("Some connection error")

return _mock_httpx_get_with_connect_error


@pytest.fixture()
def mocked_single_matching_dataset_result():
"""Valid aggregate query result for a single matching dataset."""
return {
"dataset_uuid": "http://neurobagel.org/vocab/12345",
"dataset_name": "QPN",
"dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/",
"dataset_total_subjects": 200,
"num_matching_subjects": 5,
"records_protected": True,
"subject_data": "protected",
"image_modals": [
"http://purl.org/nidash/nidm#T1Weighted",
"http://purl.org/nidash/nidm#T2Weighted",
],
}
16 changes: 0 additions & 16 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,6 @@
from fastapi import status


def test_root(test_app, set_valid_test_federation_nodes):
"""Given a GET request to the root endpoint, Check for 200 status and expected content."""

response = test_app.get("/")

assert response.status_code == status.HTTP_200_OK
assert all(
substring in response.text
for substring in [
"Welcome to",
"Neurobagel",
'<a href="/docs">documentation</a>',
]
)


def test_partially_failed_terms_fetching_handled_gracefully(
test_app, monkeypatch, set_valid_test_federation_nodes, caplog
):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)

with test_app:
response = test_app.get("/nodes/")
response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://firstpublicnode.org/": "First Public Node",
"https://secondpublicnode.org/": "Second Public Node",
Expand Down Expand Up @@ -77,7 +77,7 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)

with test_app:
response = test_app.get("/nodes/")
response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://mylocalnode.org/": "Local Node"
}
Expand Down Expand Up @@ -123,7 +123,7 @@ def mock_httpx_get(**kwargs):

with pytest.warns(UserWarning) as w:
with test_app:
response = test_app.get("/nodes/")
response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://firstpublicnode.org/": "First Public Node",
"https://secondpublicnode.org/": "Second Public Node",
Expand Down
32 changes: 7 additions & 25 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,6 @@ def mock_token():
return "Bearer foo"


@pytest.fixture()
def mocked_single_matching_dataset_result():
"""Valid aggregate query result for a single matching dataset."""
return {
"dataset_uuid": "http://neurobagel.org/vocab/12345",
"dataset_name": "QPN",
"dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/",
"dataset_total_subjects": 200,
"num_matching_subjects": 5,
"records_protected": True,
"subject_data": "protected",
"image_modals": [
"http://purl.org/nidash/nidm#T1Weighted",
"http://purl.org/nidash/nidm#T2Weighted",
],
}


def test_partial_node_failure_responses_handled_gracefully(
monkeypatch,
test_app,
Expand All @@ -47,7 +29,7 @@ def test_partial_node_failure_responses_handled_gracefully(
async def mock_httpx_get(self, **kwargs):
# The self parameter is necessary to match the signature of the method being mocked,
# which is a class method of the httpx.AsyncClient class (see https://www.python-httpx.org/api/#asyncclient).
if kwargs["url"] == "https://firstpublicnode.org/query/":
if kwargs["url"] == "https://firstpublicnode.org/query":
return httpx.Response(
status_code=200, json=[mocked_single_matching_dataset_result]
)
Expand All @@ -59,7 +41,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get(
"/query/",
"/query",
headers={"Authorization": mock_token},
)

Expand Down Expand Up @@ -127,7 +109,7 @@ def test_partial_node_request_failures_handled_gracefully(
"""

async def mock_httpx_get(self, **kwargs):
if kwargs["url"] == "https://firstpublicnode.org/query/":
if kwargs["url"] == "https://firstpublicnode.org/query":
return httpx.Response(
status_code=200, json=[mocked_single_matching_dataset_result]
)
Expand All @@ -137,7 +119,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get(
"/query/",
"/query",
headers={"Authorization": mock_token},
)

Expand Down Expand Up @@ -183,7 +165,7 @@ def test_all_nodes_failure_handled_gracefully(
)

response = test_app.get(
"/query/",
"/query",
headers={"Authorization": mock_token},
)

Expand Down Expand Up @@ -225,7 +207,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get(
"/query/",
"/query",
headers={"Authorization": mock_token},
)

Expand Down Expand Up @@ -256,6 +238,6 @@ async def mock_httpx_get(self, **kwargs):

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get("/query/")
response = test_app.get("/query")

assert response.status_code == status.HTTP_200_OK
63 changes: 63 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import httpx
import pytest
from fastapi import status


@pytest.mark.parametrize(
"root_path",
["/", ""],
)
def test_root(test_app, set_valid_test_federation_nodes, root_path):
"""Given a GET request to the root endpoint, Check for 200 status and expected content."""

response = test_app.get(root_path, follow_redirects=False)

assert response.status_code == status.HTTP_200_OK
assert all(
substring in response.text
for substring in [
"Welcome to",
"Neurobagel",
'<a href="/docs">documentation</a>',
]
)


@pytest.mark.parametrize(
"valid_route",
["/query", "/query?min_age=20", "/nodes"],
)
def test_request_without_trailing_slash_not_redirected(
test_app,
monkeypatch,
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
disable_auth,
valid_route,
):
"""Test that a request to a route without a / is not redirected to have a trailing slash."""

async def mock_httpx_get(self, **kwargs):
return httpx.Response(
status_code=200, json=[mocked_single_matching_dataset_result]
)

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get(valid_route, follow_redirects=False)
assert response.status_code == status.HTTP_200_OK


@pytest.mark.parametrize(
"invalid_route",
["/query/", "/query/?min_age=20", "/nodes/", "/attributes/nb:SomeClass/"],
)
def test_request_including_trailing_slash_fails(
test_app, disable_auth, invalid_route
):
"""
Test that a request to routes including a trailing slash, where none is expected,
is *not* redirected to exclude the slash, and returns a 404.
"""
response = test_app.get(invalid_route)
assert response.status_code == status.HTTP_404_NOT_FOUND
17 changes: 12 additions & 5 deletions tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@


def test_missing_client_id_raises_error_when_auth_enabled(
monkeypatch, test_app
monkeypatch, test_app, enable_auth
):
"""Test that a missing client ID raises an error on startup when authentication is enabled."""
# We're using what should be default values of CLIENT_ID and AUTH_ENABLED here
# (if the corresponding environment variables are unset),
# but we set the values explicitly here for clarity
monkeypatch.setattr("app.api.security.CLIENT_ID", None)
monkeypatch.setattr("app.api.security.AUTH_ENABLED", True)

with pytest.raises(ValueError) as exc_info:
with test_app:
Expand Down Expand Up @@ -52,12 +51,20 @@ def test_invalid_token_raises_error(invalid_token):
[{}, {"Authorization": ""}, {"badheader": "badvalue"}],
)
def test_query_with_malformed_auth_header_fails(
test_app, set_mock_verify_token, invalid_auth_header
test_app,
set_mock_verify_token,
enable_auth,
invalid_auth_header,
monkeypatch,
):
"""Test that a request to the /query route with a missing or malformed authorization header, fails ."""
"""
Test that when authentication is enabled, a request to the /query route with a
missing or malformed authorization header fails.
"""
monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id")

response = test_app.get(
"/query/",
"/query",
headers=invalid_auth_header,
)

Expand Down

0 comments on commit 188acc1

Please sign in to comment.