diff --git a/routers/amqp.py b/routers/amqp.py index 89aed78..22b4f45 100644 --- a/routers/amqp.py +++ b/routers/amqp.py @@ -1,12 +1,14 @@ from datetime import datetime from faststream.rabbit.fastapi import Logger, RabbitRouter # type: ignore +from faststream.rabbit import RabbitBroker from faststream.rabbit.schemas.queue import RabbitQueue from pydantic import BaseModel -from schema import PayloadModel, InputModel, OutputModel +from schema import PayloadModel, ResponseModel from tc_messageBroker.rabbit_mq.event import Event from tc_messageBroker.rabbit_mq.queue import Queue from utils.credentials import load_rabbitmq_credentials +from utils.persist_payload import PersistPayload from worker.tasks import query_data_sources rabbitmq_creds = load_rabbitmq_credentials() @@ -21,33 +23,39 @@ class Payload(BaseModel): @router.subscriber(queue=RabbitQueue(name=Queue.HIVEMIND, durable=True)) -@router.publisher(queue=RabbitQueue(Queue.DISCORD_BOT, durable=True)) async def ask(payload: Payload, logger: Logger): if payload.event == Event.HIVEMIND.INTERACTION_CREATED: try: - question = payload.content.input.message - community_id = payload.content.input.community_id + question = payload.content.question.message + community_id = payload.content.communityId logger.info(f"COMMUNITY_ID: {community_id} Received job") response = query_data_sources(community_id=community_id, query=question) logger.info(f"COMMUNITY_ID: {community_id} Job finished") response_payload = PayloadModel( - input=InputModel(message=response, community_id=community_id), - output=OutputModel(destination=payload.content.output.destination), + communityId=community_id, + route=payload.content.route, + question=payload.content.question, + response=ResponseModel(message=response), metadata=payload.content.metadata, - session_id=payload.content.session_id, ) + # dumping the whole payload of question & answer to db + persister = PersistPayload() + persister.persist(response_payload) + result = Payload( - event=payload.content.output.destination, + event=payload.content.route.destination.event, date=str(datetime.now()), content=response_payload.model_dump(), ) - return result + async with RabbitBroker(url=rabbitmq_creds["url"]) as broker: + await broker.publish( + message=result, queue=payload.content.route.destination.queue + ) except Exception as e: logger.error(f"Errors While processing job! {e}") else: logger.error( - f"No more event available for {Queue.HIVEMIND} queue! " - f"Received event: `{payload.event}`" + f"No such `{payload.event}` event available for {Queue.HIVEMIND} queue!" ) diff --git a/routers/http.py b/routers/http.py index 86cebda..fe39ea8 100644 --- a/routers/http.py +++ b/routers/http.py @@ -5,7 +5,8 @@ class Payload(BaseModel): - query: str + question: str + response: str | None = None community_id: str # bot_given_info: dict[str, Any] @@ -15,7 +16,12 @@ class Payload(BaseModel): @router.post("/ask") async def ask(payload: Payload): - task = ask_question_auto_search.delay(**payload.model_dump()) + query = payload.question + community_id = payload.community_id + task = ask_question_auto_search.delay( + community_id=community_id, + query=query, + ) return {"id": task.id} diff --git a/schema/__init__.py b/schema/__init__.py index f617898..e8b27a0 100644 --- a/schema/__init__.py +++ b/schema/__init__.py @@ -1 +1 @@ -from .payload import FiltersModel, InputModel, OutputModel, PayloadModel +from .payload import PayloadModel, ResponseModel diff --git a/schema/payload.py b/schema/payload.py index 7f88ce3..a63b612 100644 --- a/schema/payload.py +++ b/schema/payload.py @@ -1,24 +1,28 @@ from pydantic import BaseModel -class InputModel(BaseModel): - message: str | None = None - community_id: str | None = None +class DestinationModel(BaseModel): + queue: str + event: str -class OutputModel(BaseModel): - destination: str | None = None +class RouteModel(BaseModel): + source: str + destination: DestinationModel | None -class FiltersModel(BaseModel): - username: list[str] | None = None - resource: str | None = None - dataSourceA: dict[str, list[str] | None] | None = None +class QuestionModel(BaseModel): + message: str + filters: dict | None + + +class ResponseModel(BaseModel): + message: str class PayloadModel(BaseModel): - input: InputModel - output: OutputModel - metadata: dict - session_id: str | None = None - filters: FiltersModel | None = None + communityId: str + route: RouteModel + question: QuestionModel + response: ResponseModel + metadata: dict | None diff --git a/tests/integration/test_persist_payload.py b/tests/integration/test_persist_payload.py new file mode 100644 index 0000000..c5478a2 --- /dev/null +++ b/tests/integration/test_persist_payload.py @@ -0,0 +1,99 @@ +import unittest +from unittest.mock import patch +import mongomock + +from schema import PayloadModel +from utils.persist_payload import PersistPayload + + +class TestPersistPayloadIntegration(unittest.TestCase): + """Integration tests for the PersistPayload class.""" + + # Sample PayloadModel data + sample_payload_data = { + "communityId": "650be9f4e2c1234abcd12345", + "route": { + "source": "api-gateway", + "destination": {"queue": "data-processing", "event": "new-data"}, + }, + "question": { + "message": "What is the meaning of life?", + "filters": {"category": "philosophy"}, + }, + "response": { + "message": "The meaning of life is subjective and varies from person to person." + }, + "metadata": {"timestamp": "2023-10-08T12:00:00"}, + } + + @patch("utils.mongo.MongoSingleton.get_instance") + def setUp(self, mock_mongo_instance): + """Setup a mocked MongoDB client for testing.""" + # Create a mock MongoDB client using `mongomock` + self.mock_client = mongomock.MongoClient() + + # Mock the `get_client` method to return the mocked client + mock_instance = mock_mongo_instance.return_value + mock_instance.get_client.return_value = self.mock_client + + # Initialize the class under test with the mocked MongoDB client + self.persist_payload = PersistPayload() + + def test_persist_valid_payload(self): + """Test persisting a valid PayloadModel into the database.""" + # Create a PayloadModel instance from the sample data + payload = PayloadModel(**self.sample_payload_data) + + # Call the `persist` method to store the payload in the mock database + self.persist_payload.persist(payload) + + # Retrieve the persisted document from the mock database + persisted_data = self.mock_client["hivemind"]["messages"].find_one( + {"communityId": self.sample_payload_data["communityId"]} + ) + + # Check that the persisted document matches the original payload + self.assertIsNotNone(persisted_data) + self.assertEqual( + persisted_data["communityId"], self.sample_payload_data["communityId"] + ) + self.assertEqual( + persisted_data["route"]["source"], + self.sample_payload_data["route"]["source"], + ) + self.assertEqual( + persisted_data["question"]["message"], + self.sample_payload_data["question"]["message"], + ) + + def test_persist_with_invalid_payload(self): + """Test that attempting to persist an invalid payload raises an exception.""" + # Create an invalid PayloadModel by omitting required fields + invalid_payload_data = { + "communityId": self.sample_payload_data["communityId"], + "route": {}, # Invalid as required fields are missing + "question": {"message": ""}, + "response": {"message": ""}, + "metadata": None, + } + + # Construct the PayloadModel (this will raise a validation error) + with self.assertRaises(ValueError): + PayloadModel(**invalid_payload_data) + + def test_persist_handles_mongo_exception(self): + """Test that MongoDB exceptions are properly handled and logged.""" + # Create a valid PayloadModel instance + payload = PayloadModel(**self.sample_payload_data) + + # Simulate a MongoDB exception during the insert operation + with patch.object( + self.mock_client["hivemind"]["messages"], + "insert_one", + side_effect=Exception("Database error"), + ): + with self.assertLogs(level="ERROR") as log: + self.persist_payload.persist(payload) + self.assertIn( + "Failed to persist payload to database for community", log.output[0] + ) diff --git a/tests/unit/test_payload_schema.py b/tests/unit/test_payload_schema.py new file mode 100644 index 0000000..c714a8d --- /dev/null +++ b/tests/unit/test_payload_schema.py @@ -0,0 +1,71 @@ +from unittest import TestCase +from pydantic import ValidationError + +from schema.payload import PayloadModel + + +class TestPayloadModel(TestCase): + """Test suite for PayloadModel and its nested models.""" + + valid_community_id = "650be9f4e2c1234abcd12345" + + # Helper function to create a valid payload dictionary + def get_valid_payload(self): + return { + "communityId": self.valid_community_id, + "route": { + "source": "some-source", + "destination": {"queue": "some-queue", "event": "some-event"}, + }, + "question": { + "message": "What is the best approach?", + "filters": {"category": "science"}, + }, + "response": {"message": "The best approach is using scientific methods."}, + "metadata": {"timestamp": "2023-10-08T12:00:00"}, + } + + def test_valid_payload(self): + """Test if a valid payload is correctly validated.""" + payload = self.get_valid_payload() + validated_model = PayloadModel(**payload) + self.assertEqual(validated_model.communityId, payload["communityId"]) + self.assertEqual(validated_model.route.source, payload["route"]["source"]) + self.assertEqual( + validated_model.route.destination.queue, + payload["route"]["destination"]["queue"], + ) + + def test_missing_required_field(self): + """Test if missing a required field raises a ValidationError.""" + payload = self.get_valid_payload() + del payload["route"] # Remove a required field + with self.assertRaises(ValidationError): + PayloadModel(**payload) + + def test_none_as_optional_fields(self): + """Test if setting optional fields as None is valid.""" + payload = self.get_valid_payload() + payload["route"]["destination"] = None # Set optional destination to None + payload["question"]["filters"] = None # Set optional filters to None + payload["metadata"] = None # Set optional metadata to None + validated_model = PayloadModel(**payload) + self.assertIsNone(validated_model.route.destination) + self.assertIsNone(validated_model.question.filters) + self.assertIsNone(validated_model.metadata) + + def test_invalid_route(self): + """Test if an invalid RouteModel within PayloadModel raises a ValidationError.""" + payload = self.get_valid_payload() + payload["route"]["source"] = None # Invalid value for a required field + with self.assertRaises(ValidationError): + PayloadModel(**payload) + + def test_empty_string_fields(self): + """Test if fields with empty strings are allowed.""" + payload = self.get_valid_payload() + payload["route"]["source"] = "" # Set an empty string + payload["question"]["message"] = "" # Set an empty string + validated_model = PayloadModel(**payload) + self.assertEqual(validated_model.route.source, "") + self.assertEqual(validated_model.question.message, "") diff --git a/utils/persist_payload.py b/utils/persist_payload.py new file mode 100644 index 0000000..e4f666d --- /dev/null +++ b/utils/persist_payload.py @@ -0,0 +1,33 @@ +import logging + +from utils.mongo import MongoSingleton +from schema import PayloadModel + + +class PersistPayload: + def __init__(self) -> None: + # the place we would save data in mongo + self.db = "hivemind" + self.collection = "internal_messages" + self.client = MongoSingleton.get_instance().get_client() + + def persist(self, payload: PayloadModel) -> None: + """ + persist the payload within the database + + Parameters + ----------- + payload : schema.PayloadModel + the data payload to save on database + """ + community_id = payload.communityId + try: + self.client[self.db][self.collection].insert_one(payload.model_dump()) + logging.info( + f"Payload for community id: {community_id} persisted successfully!" + ) + except Exception as exp: + logging.error( + f"Failed to persist payload to database for community: {community_id}!" + f"Exception: {exp}" + ) diff --git a/worker/tasks.py b/worker/tasks.py index 55d1a5b..78c4693 100644 --- a/worker/tasks.py +++ b/worker/tasks.py @@ -120,7 +120,11 @@ def ask_question_auto_search( query: str, ) -> str: response = query_data_sources(community_id=community_id, query=query) - return response + return { + "community_id": community_id, + "question": query, + "response": response, + } @task_prerun.connect