diff --git a/routers/http.py b/routers/http.py index e0e69bd..f9afa7e 100644 --- a/routers/http.py +++ b/routers/http.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from schema import HTTPPayload, QuestionModel, ResponseModel -from services.api_key import api_key_header +from services.api_key import get_api_key from utils.persist_payload import PersistPayload from worker.tasks import ask_question_auto_search @@ -17,7 +17,7 @@ class RequestPayload(BaseModel): router = APIRouter() -@router.post("/ask", dependencies=[Depends(api_key_header)]) +@router.post("/ask", dependencies=[Depends(get_api_key)]) async def ask(payload: RequestPayload): query = payload.question.message community_id = payload.communityId @@ -37,7 +37,7 @@ async def ask(payload: RequestPayload): return {"id": task.id} -@router.get("/status", dependencies=[Depends(api_key_header)]) +@router.get("/status", dependencies=[Depends(get_api_key)]) async def status(task_id: str): task = AsyncResult(task_id) if task.status == "SUCCESS": diff --git a/services/api_key.py b/services/api_key.py index a7c60c2..9e522f2 100644 --- a/services/api_key.py +++ b/services/api_key.py @@ -10,6 +10,24 @@ async def get_api_key(api_key_header: str = Security(api_key_header)): + """ + Dependency function to validate API key + + Parameters + ------------- + api_key_header : str + the api key passed to the header + + Raises + ------ + HTTPException + If API key is missing or invalid + + Returns + ------- + api_key_header : str + The validated API key + """ validator = ValidateAPIKey() if not api_key_header: @@ -17,7 +35,8 @@ async def get_api_key(api_key_header: str = Security(api_key_header)): status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" ) - if not validator.validate(api_key_header): + valid = await validator.validate(api_key_header) + if not valid: raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") return api_key_header @@ -29,7 +48,7 @@ def __init__(self) -> None: self.db = "hivemind" self.tokens_collection = "tokens" - def validate(self, api_key: str) -> bool: + async def validate(self, api_key: str) -> bool: """ check if the api key is available in mongodb or not diff --git a/tests/integration/test_validate_token.py b/tests/integration/test_validate_token.py index b4771c1..ccb121a 100644 --- a/tests/integration/test_validate_token.py +++ b/tests/integration/test_validate_token.py @@ -1,30 +1,50 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from services.api_key import ValidateAPIKey from utils.mongo import MongoSingleton -class TestValidateToken(TestCase): - def setUp(self) -> None: +class TestValidateToken(IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + """ + Set up test case with a test database + """ self.client = MongoSingleton.get_instance().get_client() self.validator = ValidateAPIKey() - # changing the db so not to overlap with the right ones + # Using test database to avoid affecting production data self.validator.db = "hivemind_test" self.validator.tokens_collection = "tokens_test" - self.client.drop_database(self.validator.db) + # Clean start for each test + self.clean_database() + + async def asyncTearDown(self) -> None: + """ + Clean up test database after each test + """ + self.clean_database() - def tearDown(self) -> None: + def clean_database(self) -> None: + """ + Helper method to clean the test database + """ self.client.drop_database(self.validator.db) - def test_no_token_available(self): + async def test_no_token_available(self): + """ + Test validation when no tokens exist in database + """ api_key = "1234" - valid = self.validator.validate(api_key) + valid = await self.validator.validate(api_key) self.assertEqual(valid, False) - def test_no_matching_token_available(self): + async def test_no_matching_token_available(self): + """ + Test validation when tokens exist but none match + """ + # Insert test tokens - no await needed as this is synchronous self.client[self.validator.db][self.validator.tokens_collection].insert_many( [ { @@ -44,13 +64,18 @@ def test_no_matching_token_available(self): }, ] ) + api_key = "1234" - valid = self.validator.validate(api_key) + valid = await self.validator.validate(api_key) self.assertEqual(valid, False) - def test_single_token_available(self): + async def test_single_token_available(self): + """ + Test validation when matching token exists + """ api_key = "1234" + self.client[self.validator.db][self.validator.tokens_collection].insert_many( [ { @@ -70,6 +95,15 @@ def test_single_token_available(self): }, ] ) - valid = self.validator.validate(api_key) + + valid = await self.validator.validate(api_key) self.assertEqual(valid, True) + + async def test_validation_with_empty_api_key(self): + """ + Test validation with empty API key + """ + valid = await self.validator.validate("") + + self.assertEqual(valid, False)