Skip to content

Commit

Permalink
Merge pull request #96 from TogetherCrew/feat/93-add-api-key
Browse files Browse the repository at this point in the history
fix: made API token validation async!
  • Loading branch information
amindadgar authored Nov 12, 2024
2 parents 7260171 + 42604d5 commit 4cd4702
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
6 changes: 3 additions & 3 deletions routers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from schema import HTTPPayload, QuestionModel, ResponseModel
from services.api_key import api_key_header
from services.api_key import get_api_key
from utils.persist_payload import PersistPayload
from worker.tasks import ask_question_auto_search

Expand All @@ -17,7 +17,7 @@ class RequestPayload(BaseModel):
router = APIRouter()


@router.post("/ask", dependencies=[Depends(api_key_header)])
@router.post("/ask", dependencies=[Depends(get_api_key)])
async def ask(payload: RequestPayload):
query = payload.question.message
community_id = payload.communityId
Expand All @@ -37,7 +37,7 @@ async def ask(payload: RequestPayload):
return {"id": task.id}


@router.get("/status", dependencies=[Depends(api_key_header)])
@router.get("/status", dependencies=[Depends(get_api_key)])
async def status(task_id: str):
task = AsyncResult(task_id)
if task.status == "SUCCESS":
Expand Down
23 changes: 21 additions & 2 deletions services/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,33 @@


async def get_api_key(api_key_header: str = Security(api_key_header)):
"""
Dependency function to validate API key
Parameters
-------------
api_key_header : str
the api key passed to the header
Raises
------
HTTPException
If API key is missing or invalid
Returns
-------
api_key_header : str
The validated API key
"""
validator = ValidateAPIKey()

if not api_key_header:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided"
)

if not validator.validate(api_key_header):
valid = await validator.validate(api_key_header)
if not valid:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return api_key_header
Expand All @@ -29,7 +48,7 @@ def __init__(self) -> None:
self.db = "hivemind"
self.tokens_collection = "tokens"

def validate(self, api_key: str) -> bool:
async def validate(self, api_key: str) -> bool:
"""
check if the api key is available in mongodb or not
Expand Down
58 changes: 46 additions & 12 deletions tests/integration/test_validate_token.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,50 @@
from unittest import TestCase
from unittest import IsolatedAsyncioTestCase

from services.api_key import ValidateAPIKey
from utils.mongo import MongoSingleton


class TestValidateToken(TestCase):
def setUp(self) -> None:
class TestValidateToken(IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
"""
Set up test case with a test database
"""
self.client = MongoSingleton.get_instance().get_client()
self.validator = ValidateAPIKey()

# changing the db so not to overlap with the right ones
# Using test database to avoid affecting production data
self.validator.db = "hivemind_test"
self.validator.tokens_collection = "tokens_test"

self.client.drop_database(self.validator.db)
# Clean start for each test
self.clean_database()

async def asyncTearDown(self) -> None:
"""
Clean up test database after each test
"""
self.clean_database()

def tearDown(self) -> None:
def clean_database(self) -> None:
"""
Helper method to clean the test database
"""
self.client.drop_database(self.validator.db)

def test_no_token_available(self):
async def test_no_token_available(self):
"""
Test validation when no tokens exist in database
"""
api_key = "1234"
valid = self.validator.validate(api_key)
valid = await self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_no_matching_token_available(self):
async def test_no_matching_token_available(self):
"""
Test validation when tokens exist but none match
"""
# Insert test tokens - no await needed as this is synchronous
self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
Expand All @@ -44,13 +64,18 @@ def test_no_matching_token_available(self):
},
]
)

api_key = "1234"
valid = self.validator.validate(api_key)
valid = await self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_single_token_available(self):
async def test_single_token_available(self):
"""
Test validation when matching token exists
"""
api_key = "1234"

self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
Expand All @@ -70,6 +95,15 @@ def test_single_token_available(self):
},
]
)
valid = self.validator.validate(api_key)

valid = await self.validator.validate(api_key)

self.assertEqual(valid, True)

async def test_validation_with_empty_api_key(self):
"""
Test validation with empty API key
"""
valid = await self.validator.validate("")

self.assertEqual(valid, False)

0 comments on commit 4cd4702

Please sign in to comment.