Skip to content

Commit

Permalink
Merge pull request #95 from TogetherCrew/feat/93-add-api-key
Browse files Browse the repository at this point in the history
Feat/93 add api key
  • Loading branch information
amindadgar authored Nov 12, 2024
2 parents 80afc66 + b0dd8bb commit 7260171
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 3 deletions.
7 changes: 4 additions & 3 deletions routers/http.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging

from celery.result import AsyncResult
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from schema import HTTPPayload, QuestionModel, ResponseModel
from services.api_key import api_key_header
from utils.persist_payload import PersistPayload
from worker.tasks import ask_question_auto_search

Expand All @@ -16,7 +17,7 @@ class RequestPayload(BaseModel):
router = APIRouter()


@router.post("/ask")
@router.post("/ask", dependencies=[Depends(api_key_header)])
async def ask(payload: RequestPayload):
query = payload.question.message
community_id = payload.communityId
Expand All @@ -36,7 +37,7 @@ async def ask(payload: RequestPayload):
return {"id": task.id}


@router.get("/status")
@router.get("/status", dependencies=[Depends(api_key_header)])
async def status(task_id: str):
task = AsyncResult(task_id)
if task.status == "SUCCESS":
Expand Down
Empty file added services/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions services/api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from starlette.status import HTTP_401_UNAUTHORIZED
from utils.mongo import MongoSingleton

# List of valid API keys - in production, this should be stored securely
API_KEY_NAME = "X-API-Key"

api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


async def get_api_key(api_key_header: str = Security(api_key_header)):
validator = ValidateAPIKey()

if not api_key_header:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided"
)

if not validator.validate(api_key_header):
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return api_key_header


class ValidateAPIKey:
def __init__(self) -> None:
self.client = MongoSingleton.get_instance().get_client()
self.db = "hivemind"
self.tokens_collection = "tokens"

def validate(self, api_key: str) -> bool:
"""
check if the api key is available in mongodb or not
Parameters
------------
api_key : str
the provided key to check in db
Returns
---------
valid : bool
if the key was available in mongo collection, then return True
else, the token is not valid and return False
"""
document = self.client[self.db][self.tokens_collection].find_one(
{"token": api_key}
)

return True if document else False
75 changes: 75 additions & 0 deletions tests/integration/test_validate_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from unittest import TestCase

from services.api_key import ValidateAPIKey
from utils.mongo import MongoSingleton


class TestValidateToken(TestCase):
def setUp(self) -> None:
self.client = MongoSingleton.get_instance().get_client()
self.validator = ValidateAPIKey()

# changing the db so not to overlap with the right ones
self.validator.db = "hivemind_test"
self.validator.tokens_collection = "tokens_test"

self.client.drop_database(self.validator.db)

def tearDown(self) -> None:
self.client.drop_database(self.validator.db)

def test_no_token_available(self):
api_key = "1234"
valid = self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_no_matching_token_available(self):
self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
"id": 1,
"token": "1111",
"options": {},
},
{
"id": 2,
"token": "2222",
"options": {},
},
{
"id": 3,
"token": "3333",
"options": {},
},
]
)
api_key = "1234"
valid = self.validator.validate(api_key)

self.assertEqual(valid, False)

def test_single_token_available(self):
api_key = "1234"
self.client[self.validator.db][self.validator.tokens_collection].insert_many(
[
{
"id": 1,
"token": api_key,
"options": {},
},
{
"id": 2,
"token": "2222",
"options": {},
},
{
"id": 3,
"token": "3333",
"options": {},
},
]
)
valid = self.validator.validate(api_key)

self.assertEqual(valid, True)

0 comments on commit 7260171

Please sign in to comment.