Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add authentication to /query route #104

Merged
merged 18 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
"""Router for query path operations."""

from fastapi import APIRouter, Depends, Response, status
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi.security import OAuth2

from .. import crud
from .. import crud, security
from ..models import CombinedQueryResponse, QueryModel
from ..security import verify_token

# from fastapi.security import open_id_connect_url


router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
# Security scheme injected into the query route via Depends(); it declares Google's
# OAuth2 implicit flow (authorization URL below).
oauth2_scheme = OAuth2(
    flows={
        "implicit": {
            "authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
        }
    },
    # Don't automatically error out when request is not authenticated, to support optional authentication
    # (the route receives the token as `str | None` and decides itself how to respond when it is missing).
    auto_error=False,
)
# NOTE: Can also explicitly use OpenID Connect because Google supports it - results in the same behavior as the OAuth2 scheme above.
# openid_connect_scheme = open_id_connect_url.OpenIdConnect(
# openIdConnectUrl="https://accounts.google.com/.well-known/openid-configuration"
# )


# We use the Response parameter below to change the status code of the response while still being able to validate the returned data using the response model.
# (see https://fastapi.tiangolo.com/advanced/response-change-status-code/ for more info).
Expand All @@ -16,9 +36,33 @@
# example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model).
@router.get("/", response_model=CombinedQueryResponse)
async def get_query(
response: Response, query: QueryModel = Depends(QueryModel)
response: Response,
query: QueryModel = Depends(QueryModel),
token: str | None = Depends(oauth2_scheme),
):
"""When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
# NOTE: Currently, when the request is unauthenticated (missing or malformed authorization header -> missing token),
# the default response is a 403 Forbidden error.
# This doesn't fully align with HTTP status code conventions:
# - 401 Unauthorized should be used when the client lacks authentication credentials
# - 403 Forbidden should be used when the client has been authenticated but lacks the required permissions
# If we really care about returning a 401 Unauthorized error, we can use auto_error=False
# when creating the OAuth2 object and raise a custom HTTPException.
# See also https://github.com/tiangolo/fastapi/discussions/9130
# if not token:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="Not authenticated",
# headers={"WWW-Authenticate": "Bearer"},
# )
if security.AUTH_ENABLED:
if token is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
verify_token(token)

response_dict = await crud.get(
query.min_age,
query.max_age,
Expand Down
43 changes: 43 additions & 0 deletions app/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token

# Authentication is on by default: it is only off when NB_ENABLE_AUTH is set to
# something other than "true" (compared case-insensitively).
AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
# Client ID passed to verify_oauth2_token as the expected audience of ID tokens;
# None when NB_QUERY_CLIENT_ID is not set in the environment.
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)


def check_client_id():
    """Check if the CLIENT_ID environment variable is set.

    Intended to run once at application startup so that a misconfigured
    deployment fails immediately instead of on the first authenticated request.

    Raises:
        ValueError: If authentication is enabled (AUTH_ENABLED) but CLIENT_ID
            is unset or empty.
    """
    # By default, if CLIENT_ID is not provided to verify_oauth2_token,
    # Google will simply skip verifying the audience claim of ID tokens.
    # This however can be a security risk, so we mandate that CLIENT_ID is set.
    # An empty string (e.g. a blank entry in an env file) is not a usable client ID,
    # so it is rejected the same way as an unset variable.
    if AUTH_ENABLED and not CLIENT_ID:
        raise ValueError(
            "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
            "Please set NB_QUERY_CLIENT_ID to the Neurobagel query tool client ID, to verify the audience claim of ID tokens."
        )


def verify_token(token: str):
    """Verify the Google ID token. Raise an HTTPException if the token is invalid.

    Args:
        token: Value of the Authorization header, expected in the form
            "Bearer <Google ID token>".

    Raises:
        HTTPException: 401 Unauthorized when the authorization scheme is not
            "Bearer" or when the ID token fails verification against CLIENT_ID.
    """
    # Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
    # Extract the token from the "Bearer" scheme
    # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
    scheme, param = get_authorization_scheme_param(token)

    # Only Bearer credentials are meaningful here; reject any other scheme
    # before spending a verification call on it. Scheme names are compared
    # case-insensitively.
    if scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: authorization scheme must be Bearer",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        # The decoded claims are intentionally not printed or logged: they
        # describe the authenticated user and should not end up in stdout.
        id_token.verify_oauth2_token(param, requests.Request(), CLIENT_ID)
    except (GoogleAuthError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
2 changes: 2 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .api import utility as util
from .api.routers import attributes, nodes, query
from .api.security import check_client_id

logger = logging.getLogger("nb-f-API")
stdout_handler = logging.StreamHandler()
Expand All @@ -26,6 +27,7 @@ async def lifespan(app: FastAPI):
"""
Collect and store locally defined and public node details for federation upon startup and clears the index upon shutdown.
"""
check_client_id()
await util.create_federation_node_index()
yield
util.FEDERATION_NODES.clear()
Expand Down
13 changes: 11 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
anyio==3.6.2
attrs==23.1.0
cachetools==5.3.3
certifi==2023.7.22
cfgv==3.3.1
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.95.2
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -22,23 +26,28 @@ orjson==3.8.6
packaging==21.3
pandas==1.5.2
platformdirs==2.5.4
pluggy==1.5.0
pluggy==1.0.0
pre-commit==2.20.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pydantic==1.10.2
pyparsing==3.0.9
pytest==8.2.1
pytest==7.2.0
pytest-asyncio==0.23.7
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0
referencing==0.31.1
requests==2.31.0
rfc3986==1.5.0
rpds-py==0.13.2
rsa==4.9
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
toml==0.10.2
tomli==2.0.1
typing_extensions==4.4.0
urllib3==2.2.0
uvicorn==0.20.0
virtualenv==20.16.7
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ def test_app():
yield client


@pytest.fixture
def disable_auth(monkeypatch):
    """
    Switch off the API's authentication requirement for one test so that the startup checks are skipped
    (useful when the route under test does not need authentication).
    """
    auth_flag_path = "app.api.security.AUTH_ENABLED"
    monkeypatch.setattr(auth_flag_path, False)


@pytest.fixture(scope="function")
def set_valid_test_federation_nodes(monkeypatch):
"""Set two correctly formatted federation nodes for a test function (mocks the result of reading/parsing available public and local nodes on startup)."""
Expand Down
12 changes: 8 additions & 4 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
},
],
)
def test_nodes_discovery_endpoint(test_app, monkeypatch, local_nodes):
def test_nodes_discovery_endpoint(
test_app, monkeypatch, local_nodes, disable_auth
):
"""Test that a federation node index is correctly created from locally set and remote node lists."""

def mock_parse_nodes_as_dict(path):
Expand Down Expand Up @@ -59,7 +61,7 @@ def mock_httpx_get(**kwargs):


def test_failed_public_nodes_fetching_raises_warning(
test_app, monkeypatch, caplog
test_app, monkeypatch, disable_auth, caplog
):
"""Test that when request for remote list of public nodes fails, an informative warning is raised and the federation node index only includes local nodes."""

Expand Down Expand Up @@ -95,7 +97,7 @@ def mock_httpx_get(**kwargs):
assert warn_substr in caplog.text


def test_unset_local_nodes_raises_warning(test_app, monkeypatch):
def test_unset_local_nodes_raises_warning(test_app, monkeypatch, disable_auth):
"""Test that when no local nodes are set, an informative warning is raised and the federation node index only includes remote nodes."""

def mock_parse_nodes_as_dict(path):
Expand Down Expand Up @@ -166,7 +168,9 @@ def test_missing_local_nodes_file_does_not_raise_error(tmp_path):
assert util.parse_nodes_as_dict(expected_file_path) == {}


def test_no_available_nodes_raises_error(monkeypatch, test_app, caplog):
def test_no_available_nodes_raises_error(
monkeypatch, test_app, disable_auth, caplog
):
"""Test that when no local or remote nodes are available, an informative error is raised."""

def mock_parse_nodes_as_dict(path):
Expand Down
75 changes: 71 additions & 4 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@
from fastapi import status


@pytest.fixture()
def mock_token():
    """Provide a well-formed placeholder token for use in test requests."""
    well_formed_token = "Bearer foo"
    return well_formed_token


@pytest.fixture()
def mock_verify_token():
    """Provide a stand-in for token verification that always succeeds (returns None, never raises)."""
    # Accepts any token and does nothing with it.
    return lambda token: None


@pytest.fixture()
def set_mock_verify_token(monkeypatch, mock_verify_token):
    """Replace the query router's verify_token with the always-succeeding mock."""
    # Patch the name as imported by the router module, since that is the reference the route calls.
    patched_name = "app.api.routers.query.verify_token"
    monkeypatch.setattr(patched_name, mock_verify_token)


alyssadai marked this conversation as resolved.
Show resolved Hide resolved
@pytest.fixture()
def mocked_single_matching_dataset_result():
"""Valid aggregate query result for a single matching dataset."""
Expand All @@ -29,6 +53,8 @@ def test_partial_node_failure_responses_handled_gracefully(
test_app,
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_token,
caplog,
):
"""
Expand All @@ -50,7 +76,10 @@ async def mock_httpx_get(self, **kwargs):

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get("/query/")
response = test_app.get(
"/query/",
headers={"Authorization": mock_token},
)
alyssadai marked this conversation as resolved.
Show resolved Hide resolved

assert response.status_code == status.HTTP_207_MULTI_STATUS
assert response.json() == {
Expand Down Expand Up @@ -104,6 +133,8 @@ def test_partial_node_request_failures_handled_gracefully(
test_app,
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_token,
error_to_raise,
expected_node_message,
caplog,
Expand All @@ -123,7 +154,10 @@ async def mock_httpx_get(self, **kwargs):

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get("/query/")
response = test_app.get(
"/query/",
headers={"Authorization": mock_token},
)

assert response.status_code == status.HTTP_207_MULTI_STATUS

Expand Down Expand Up @@ -153,6 +187,8 @@ def test_all_nodes_failure_handled_gracefully(
monkeypatch,
test_app,
mock_failed_connection_httpx_get,
mock_token,
set_mock_verify_token,
set_valid_test_federation_nodes,
caplog,
):
Expand All @@ -164,7 +200,10 @@ def test_all_nodes_failure_handled_gracefully(
httpx.AsyncClient, "get", mock_failed_connection_httpx_get
)

response = test_app.get("/query/")
response = test_app.get(
"/query/",
headers={"Authorization": mock_token},
)

# We expect 3 logs here: one warning for each failed node, and one error for the overall failure
assert len(caplog.records) == 3
Expand All @@ -186,6 +225,8 @@ def test_all_nodes_success_handled_gracefully(
caplog,
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_token,
):
"""
Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors.
Expand All @@ -201,7 +242,10 @@ async def mock_httpx_get(self, **kwargs):

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

response = test_app.get("/query/")
response = test_app.get(
"/query/",
headers={"Authorization": mock_token},
)

assert response.status_code == status.HTTP_200_OK

Expand All @@ -210,3 +254,26 @@ async def mock_httpx_get(self, **kwargs):
assert response["errors"] == []
assert len(response["responses"]) == 2
assert "Requests to all nodes succeeded (2/2)" in caplog.text


def test_query_without_token_succeeds_when_auth_disabled(
    monkeypatch,
    test_app,
    set_valid_test_federation_nodes,
    mocked_single_matching_dataset_result,
    disable_auth,
):
    """
    With authentication switched off, a federated query sent without any token should still succeed.
    """

    async def fake_node_get(self, **kwargs):
        # Every federation node replies successfully with the same single-dataset result.
        node_payload = [mocked_single_matching_dataset_result]
        return httpx.Response(status_code=200, json=node_payload)

    monkeypatch.setattr(httpx.AsyncClient, "get", fake_node_get)

    # No Authorization header is sent on purpose.
    response = test_app.get("/query/")

    assert response.status_code == status.HTTP_200_OK
Loading