diff --git a/routers/http.py b/routers/http.py index d53e790..e0e69bd 100644 --- a/routers/http.py +++ b/routers/http.py @@ -1,9 +1,10 @@ import logging from celery.result import AsyncResult -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from schema import HTTPPayload, QuestionModel, ResponseModel +from services.api_key import api_key_header from utils.persist_payload import PersistPayload from worker.tasks import ask_question_auto_search @@ -16,7 +17,7 @@ class RequestPayload(BaseModel): router = APIRouter() -@router.post("/ask") +@router.post("/ask", dependencies=[Depends(api_key_header)]) async def ask(payload: RequestPayload): query = payload.question.message community_id = payload.communityId @@ -36,7 +37,7 @@ async def ask(payload: RequestPayload): return {"id": task.id} -@router.get("/status") +@router.get("/status", dependencies=[Depends(api_key_header)]) async def status(task_id: str): task = AsyncResult(task_id) if task.status == "SUCCESS": diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/api_key.py b/services/api_key.py new file mode 100644 index 0000000..a7c60c2 --- /dev/null +++ b/services/api_key.py @@ -0,0 +1,51 @@ +from fastapi import HTTPException, Security +from fastapi.security.api_key import APIKeyHeader +from starlette.status import HTTP_401_UNAUTHORIZED +from utils.mongo import MongoSingleton + +# List of valid API keys - in production, this should be stored securely +API_KEY_NAME = "X-API-Key" + +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + + +async def get_api_key(api_key_header: str = Security(api_key_header)): + validator = ValidateAPIKey() + + if not api_key_header: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" + ) + + if not validator.validate(api_key_header): + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return api_key_header + + +class ValidateAPIKey: + def __init__(self) -> None: + self.client = MongoSingleton.get_instance().get_client() + self.db = "hivemind" + self.tokens_collection = "tokens" + + def validate(self, api_key: str) -> bool: + """ + check if the api key is available in mongodb or not + + Parameters + ------------ + api_key : str + the provided key to check in db + + Returns + --------- + valid : bool + if the key was available in mongo collection, then return True + else, the token is not valid and return False + """ + document = self.client[self.db][self.tokens_collection].find_one( + {"token": api_key} + ) + + return True if document else False diff --git a/tests/integration/test_validate_token.py b/tests/integration/test_validate_token.py new file mode 100644 index 0000000..b4771c1 --- /dev/null +++ b/tests/integration/test_validate_token.py @@ -0,0 +1,75 @@ +from unittest import TestCase + +from services.api_key import ValidateAPIKey +from utils.mongo import MongoSingleton + + +class TestValidateToken(TestCase): + def setUp(self) -> None: + self.client = MongoSingleton.get_instance().get_client() + self.validator = ValidateAPIKey() + + # changing the db so not to overlap with the right ones + self.validator.db = "hivemind_test" + self.validator.tokens_collection = "tokens_test" + + self.client.drop_database(self.validator.db) + + def tearDown(self) -> None: + self.client.drop_database(self.validator.db) + + def test_no_token_available(self): + api_key = "1234" + valid = self.validator.validate(api_key) + + self.assertEqual(valid, False) + + def test_no_matching_token_available(self): + self.client[self.validator.db][self.validator.tokens_collection].insert_many( + [ + { + "id": 1, + "token": "1111", + "options": {}, + }, + { + "id": 2, + "token": "2222", + "options": {}, + }, + { + "id": 3, + "token": "3333", + "options": {}, + }, + ] + ) + api_key = "1234" + valid = self.validator.validate(api_key) + + self.assertEqual(valid, False) + + def test_single_token_available(self): + api_key = "1234" + self.client[self.validator.db][self.validator.tokens_collection].insert_many( + [ + { + "id": 1, + "token": api_key, + "options": {}, + }, + { + "id": 2, + "token": "2222", + "options": {}, + }, + { + "id": 3, + "token": "3333", + "options": {}, + }, + ] + ) + valid = self.validator.validate(api_key) + + self.assertEqual(valid, True)