Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: made API token validation async! #96

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions routers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from schema import HTTPPayload, QuestionModel, ResponseModel
from services.api_key import api_key_header
from services.api_key import get_api_key
from utils.persist_payload import PersistPayload
from worker.tasks import ask_question_auto_search

Expand All @@ -17,7 +17,7 @@ class RequestPayload(BaseModel):
router = APIRouter()


@router.post("/ask", dependencies=[Depends(api_key_header)])
@router.post("/ask", dependencies=[Depends(get_api_key)])
async def ask(payload: RequestPayload):
query = payload.question.message
community_id = payload.communityId
Expand All @@ -37,7 +37,7 @@ async def ask(payload: RequestPayload):
return {"id": task.id}


@router.get("/status", dependencies=[Depends(api_key_header)])
@router.get("/status", dependencies=[Depends(get_api_key)])
async def status(task_id: str):
task = AsyncResult(task_id)
if task.status == "SUCCESS":
Expand Down
23 changes: 21 additions & 2 deletions services/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,33 @@


async def get_api_key(api_key_header: str = Security(api_key_header)):
"""
Dependency function to validate API key

Parameters
-------------
api_key_header : str
the api key passed to the header

Raises
------
HTTPException
If API key is missing or invalid

Returns
-------
api_key_header : str
The validated API key
"""
validator = ValidateAPIKey()

if not api_key_header:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided"
)

if not validator.validate(api_key_header):
valid = await validator.validate(api_key_header)
if not valid:
amindadgar marked this conversation as resolved.
Show resolved Hide resolved
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return api_key_header
Expand All @@ -29,7 +48,7 @@ def __init__(self) -> None:
self.db = "hivemind"
self.tokens_collection = "tokens"

def validate(self, api_key: str) -> bool:
async def validate(self, api_key: str) -> bool:
"""
check if the api key is available in mongodb or not

Expand Down
58 changes: 46 additions & 12 deletions tests/integration/test_validate_token.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,50 @@
from unittest import TestCase
from unittest import IsolatedAsyncioTestCase

from services.api_key import ValidateAPIKey
from utils.mongo import MongoSingleton


class TestValidateToken(TestCase):
def setUp(self) -> None:
class TestValidateToken(IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
"""
Set up test case with a test database
"""
self.client = MongoSingleton.get_instance().get_client()
self.validator = ValidateAPIKey()

# changing the db so not to overlap with the right ones
# Using test database to avoid affecting production data
self.validator.db = "hivemind_test"
self.validator.tokens_collection = "tokens_test"

self.client.drop_database(self.validator.db)
# Clean start for each test
self.clean_database()

async def asyncTearDown(self) -> None:
"""
Clean up test database after each test
"""
self.clean_database()

def tearDown(self) -> None:
def clean_database(self) -> None:
"""
Helper method to clean the test database
"""
self.client.drop_database(self.validator.db)

amindadgar marked this conversation as resolved.
Show resolved Hide resolved
def test_no_token_available(self):
async def test_no_token_available(self):
"""
Test validation when no tokens exist in database
"""
api_key = "1234"
valid = self.validator.validate(api_key)
valid = await self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_no_matching_token_available(self):
async def test_no_matching_token_available(self):
"""
Test validation when tokens exist but none match
"""
# Insert test tokens - no await needed as this is synchronous
self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
Expand All @@ -44,13 +64,18 @@ def test_no_matching_token_available(self):
},
]
)

api_key = "1234"
valid = self.validator.validate(api_key)
valid = await self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_single_token_available(self):
async def test_single_token_available(self):
"""
Test validation when matching token exists
"""
api_key = "1234"

self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
Expand All @@ -70,6 +95,15 @@ def test_single_token_available(self):
},
]
)
valid = self.validator.validate(api_key)

valid = await self.validator.validate(api_key)

self.assertEqual(valid, True)

async def test_validation_with_empty_api_key(self):
"""
Test validation with empty API key
"""
valid = await self.validator.validate("")

self.assertEqual(valid, False)
Loading