Skip to content

Commit

Permalink
feat: Updated amqp server payload support!
Browse files Browse the repository at this point in the history
http needs some additonal work and also the saving of payloads on db.
  • Loading branch information
amindadgar committed Oct 8, 2024
1 parent 13bfb0f commit 8132947
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 29 deletions.
30 changes: 19 additions & 11 deletions routers/amqp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import datetime

from faststream.rabbit.fastapi import Logger, RabbitRouter # type: ignore
from faststream.rabbit import RabbitBroker
from faststream.rabbit.schemas.queue import RabbitQueue
from pydantic import BaseModel
from schema import PayloadModel, InputModel, OutputModel
from schema import PayloadModel, ResponseModel
from tc_messageBroker.rabbit_mq.event import Event
from tc_messageBroker.rabbit_mq.queue import Queue
from utils.credentials import load_rabbitmq_credentials
from utils.persist_payload import PersistPayload
from worker.tasks import query_data_sources

rabbitmq_creds = load_rabbitmq_credentials()
Expand All @@ -21,33 +23,39 @@ class Payload(BaseModel):


@router.subscriber(queue=RabbitQueue(name=Queue.HIVEMIND, durable=True))
@router.publisher(queue=RabbitQueue(Queue.DISCORD_BOT, durable=True))
async def ask(payload: Payload, logger: Logger):
if payload.event == Event.HIVEMIND.INTERACTION_CREATED:
try:
question = payload.content.input.message
community_id = payload.content.input.community_id
question = payload.content.question.message
community_id = payload.content.communityId

logger.info(f"COMMUNITY_ID: {community_id} Received job")
response = query_data_sources(community_id=community_id, query=question)
logger.info(f"COMMUNITY_ID: {community_id} Job finished")

response_payload = PayloadModel(
input=InputModel(message=response, community_id=community_id),
output=OutputModel(destination=payload.content.output.destination),
communityId=community_id,
route=payload.content.route,
question=payload.content.question,
response=ResponseModel(message=response),
metadata=payload.content.metadata,
session_id=payload.content.session_id,
)
# dumping the whole payload of question & answer to db
persister = PersistPayload()
persister.persist(response_payload)

result = Payload(
event=payload.content.output.destination,
event=payload.content.route.destination.event,
date=str(datetime.now()),
content=response_payload.model_dump(),
)
return result
async with RabbitBroker(url=rabbitmq_creds["url"]) as broker:
await broker.publish(
message=result, queue=payload.content.route.destination.queue
)
except Exception as e:
logger.error(f"Errors While processing job! {e}")
else:
logger.error(
f"No more event available for {Queue.HIVEMIND} queue! "
f"Received event: `{payload.event}`"
f"No such `{payload.event}` event available for {Queue.HIVEMIND} queue!"
)
10 changes: 8 additions & 2 deletions routers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@


class Payload(BaseModel):
query: str
question: str
response: str | None = None
community_id: str
# bot_given_info: dict[str, Any]

Expand All @@ -15,7 +16,12 @@ class Payload(BaseModel):

@router.post("/ask")
async def ask(payload: Payload):
task = ask_question_auto_search.delay(**payload.model_dump())
query = payload.question
community_id = payload.community_id
task = ask_question_auto_search.delay(
community_id=community_id,
query=query,
)
return {"id": task.id}


Expand Down
2 changes: 1 addition & 1 deletion schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .payload import FiltersModel, InputModel, OutputModel, PayloadModel
from .payload import PayloadModel, ResponseModel
32 changes: 18 additions & 14 deletions schema/payload.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
from pydantic import BaseModel


class InputModel(BaseModel):
message: str | None = None
community_id: str | None = None
class DestinationModel(BaseModel):
queue: str
event: str


class OutputModel(BaseModel):
destination: str | None = None
class RouteModel(BaseModel):
source: str
destination: DestinationModel | None


class FiltersModel(BaseModel):
username: list[str] | None = None
resource: str | None = None
dataSourceA: dict[str, list[str] | None] | None = None
class QuestionModel(BaseModel):
message: str
filters: dict | None


class ResponseModel(BaseModel):
message: str


class PayloadModel(BaseModel):
input: InputModel
output: OutputModel
metadata: dict
session_id: str | None = None
filters: FiltersModel | None = None
communityId: str
route: RouteModel
question: QuestionModel
response: ResponseModel
metadata: dict | None
99 changes: 99 additions & 0 deletions tests/integration/test_persist_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import unittest
from unittest.mock import patch
import mongomock

from schema import PayloadModel
from utils.persist_payload import PersistPayload


class TestPersistPayloadIntegration(unittest.TestCase):
"""Integration tests for the PersistPayload class."""

# Sample PayloadModel data
sample_payload_data = {
"communityId": "650be9f4e2c1234abcd12345",
"route": {
"source": "api-gateway",
"destination": {"queue": "data-processing", "event": "new-data"},
},
"question": {
"message": "What is the meaning of life?",
"filters": {"category": "philosophy"},
},
"response": {
"message": "The meaning of life is subjective and varies from person to person."
},
"metadata": {"timestamp": "2023-10-08T12:00:00"},
}

@patch("utils.mongo.MongoSingleton.get_instance")
def setUp(self, mock_mongo_instance):
"""Setup a mocked MongoDB client for testing."""
# Create a mock MongoDB client using `mongomock`
self.mock_client = mongomock.MongoClient()

# Mock the `get_client` method to return the mocked client
mock_instance = mock_mongo_instance.return_value
mock_instance.get_client.return_value = self.mock_client

# Initialize the class under test with the mocked MongoDB client
self.persist_payload = PersistPayload()

def test_persist_valid_payload(self):
"""Test persisting a valid PayloadModel into the database."""
# Create a PayloadModel instance from the sample data
payload = PayloadModel(**self.sample_payload_data)

# Call the `persist` method to store the payload in the mock database
self.persist_payload.persist(payload)

# Retrieve the persisted document from the mock database
persisted_data = self.mock_client["hivemind"]["messages"].find_one(
{"communityId": self.sample_payload_data["communityId"]}
)

# Check that the persisted document matches the original payload
self.assertIsNotNone(persisted_data)
self.assertEqual(
persisted_data["communityId"], self.sample_payload_data["communityId"]
)
self.assertEqual(
persisted_data["route"]["source"],
self.sample_payload_data["route"]["source"],
)
self.assertEqual(
persisted_data["question"]["message"],
self.sample_payload_data["question"]["message"],
)

def test_persist_with_invalid_payload(self):
"""Test that attempting to persist an invalid payload raises an exception."""
# Create an invalid PayloadModel by omitting required fields
invalid_payload_data = {
"communityId": self.sample_payload_data["communityId"],
"route": {}, # Invalid as required fields are missing
"question": {"message": ""},
"response": {"message": ""},
"metadata": None,
}

# Construct the PayloadModel (this will raise a validation error)
with self.assertRaises(ValueError):
PayloadModel(**invalid_payload_data)

def test_persist_handles_mongo_exception(self):
"""Test that MongoDB exceptions are properly handled and logged."""
# Create a valid PayloadModel instance
payload = PayloadModel(**self.sample_payload_data)

# Simulate a MongoDB exception during the insert operation
with patch.object(
self.mock_client["hivemind"]["messages"],
"insert_one",
side_effect=Exception("Database error"),
):
with self.assertLogs(level="ERROR") as log:
self.persist_payload.persist(payload)
self.assertIn(
"Failed to persist payload to database for community", log.output[0]
)
71 changes: 71 additions & 0 deletions tests/unit/test_payload_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest import TestCase
from pydantic import ValidationError

from schema.payload import PayloadModel


class TestPayloadModel(TestCase):
"""Test suite for PayloadModel and its nested models."""

valid_community_id = "650be9f4e2c1234abcd12345"

# Helper function to create a valid payload dictionary
def get_valid_payload(self):
return {
"communityId": self.valid_community_id,
"route": {
"source": "some-source",
"destination": {"queue": "some-queue", "event": "some-event"},
},
"question": {
"message": "What is the best approach?",
"filters": {"category": "science"},
},
"response": {"message": "The best approach is using scientific methods."},
"metadata": {"timestamp": "2023-10-08T12:00:00"},
}

def test_valid_payload(self):
"""Test if a valid payload is correctly validated."""
payload = self.get_valid_payload()
validated_model = PayloadModel(**payload)
self.assertEqual(validated_model.communityId, payload["communityId"])
self.assertEqual(validated_model.route.source, payload["route"]["source"])
self.assertEqual(
validated_model.route.destination.queue,
payload["route"]["destination"]["queue"],
)

def test_missing_required_field(self):
"""Test if missing a required field raises a ValidationError."""
payload = self.get_valid_payload()
del payload["route"] # Remove a required field
with self.assertRaises(ValidationError):
PayloadModel(**payload)

def test_none_as_optional_fields(self):
"""Test if setting optional fields as None is valid."""
payload = self.get_valid_payload()
payload["route"]["destination"] = None # Set optional destination to None
payload["question"]["filters"] = None # Set optional filters to None
payload["metadata"] = None # Set optional metadata to None
validated_model = PayloadModel(**payload)
self.assertIsNone(validated_model.route.destination)
self.assertIsNone(validated_model.question.filters)
self.assertIsNone(validated_model.metadata)

def test_invalid_route(self):
"""Test if an invalid RouteModel within PayloadModel raises a ValidationError."""
payload = self.get_valid_payload()
payload["route"]["source"] = None # Invalid value for a required field
with self.assertRaises(ValidationError):
PayloadModel(**payload)

def test_empty_string_fields(self):
"""Test if fields with empty strings are allowed."""
payload = self.get_valid_payload()
payload["route"]["source"] = "" # Set an empty string
payload["question"]["message"] = "" # Set an empty string
validated_model = PayloadModel(**payload)
self.assertEqual(validated_model.route.source, "")
self.assertEqual(validated_model.question.message, "")
33 changes: 33 additions & 0 deletions utils/persist_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging

from utils.mongo import MongoSingleton
from schema import PayloadModel


class PersistPayload:
def __init__(self) -> None:
# the place we would save data in mongo
self.db = "hivemind"
self.collection = "internal_messages"
self.client = MongoSingleton.get_instance().get_client()

def persist(self, payload: PayloadModel) -> None:
"""
persist the payload within the database
Parameters
-----------
payload : schema.PayloadModel
the data payload to save on database
"""
community_id = payload.communityId
try:
self.client[self.db][self.collection].insert_one(payload.model_dump())
logging.info(
f"Payload for community id: {community_id} persisted successfully!"
)
except Exception as exp:
logging.error(
f"Failed to persist payload to database for community: {community_id}!"
f"Exception: {exp}"
)
6 changes: 5 additions & 1 deletion worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def ask_question_auto_search(
query: str,
) -> str:
response = query_data_sources(community_id=community_id, query=query)
return response
return {
"community_id": community_id,
"question": query,
"response": response,
}


@task_prerun.connect
Expand Down

0 comments on commit 8132947

Please sign in to comment.