-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Updated amqp server payload support!
http needs some additonal work and also the saving of payloads on db.
- Loading branch information
1 parent
13bfb0f
commit 8132947
Showing
8 changed files
with
254 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .payload import FiltersModel, InputModel, OutputModel, PayloadModel | ||
from .payload import PayloadModel, ResponseModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,28 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class InputModel(BaseModel): | ||
message: str | None = None | ||
community_id: str | None = None | ||
class DestinationModel(BaseModel): | ||
queue: str | ||
event: str | ||
|
||
|
||
class OutputModel(BaseModel): | ||
destination: str | None = None | ||
class RouteModel(BaseModel): | ||
source: str | ||
destination: DestinationModel | None | ||
|
||
|
||
class FiltersModel(BaseModel): | ||
username: list[str] | None = None | ||
resource: str | None = None | ||
dataSourceA: dict[str, list[str] | None] | None = None | ||
class QuestionModel(BaseModel): | ||
message: str | ||
filters: dict | None | ||
|
||
|
||
class ResponseModel(BaseModel): | ||
message: str | ||
|
||
|
||
class PayloadModel(BaseModel): | ||
input: InputModel | ||
output: OutputModel | ||
metadata: dict | ||
session_id: str | None = None | ||
filters: FiltersModel | None = None | ||
communityId: str | ||
route: RouteModel | ||
question: QuestionModel | ||
response: ResponseModel | ||
metadata: dict | None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
import mongomock | ||
|
||
from schema import PayloadModel | ||
from utils.persist_payload import PersistPayload | ||
|
||
|
||
class TestPersistPayloadIntegration(unittest.TestCase): | ||
"""Integration tests for the PersistPayload class.""" | ||
|
||
# Sample PayloadModel data | ||
sample_payload_data = { | ||
"communityId": "650be9f4e2c1234abcd12345", | ||
"route": { | ||
"source": "api-gateway", | ||
"destination": {"queue": "data-processing", "event": "new-data"}, | ||
}, | ||
"question": { | ||
"message": "What is the meaning of life?", | ||
"filters": {"category": "philosophy"}, | ||
}, | ||
"response": { | ||
"message": "The meaning of life is subjective and varies from person to person." | ||
}, | ||
"metadata": {"timestamp": "2023-10-08T12:00:00"}, | ||
} | ||
|
||
@patch("utils.mongo.MongoSingleton.get_instance") | ||
def setUp(self, mock_mongo_instance): | ||
"""Setup a mocked MongoDB client for testing.""" | ||
# Create a mock MongoDB client using `mongomock` | ||
self.mock_client = mongomock.MongoClient() | ||
|
||
# Mock the `get_client` method to return the mocked client | ||
mock_instance = mock_mongo_instance.return_value | ||
mock_instance.get_client.return_value = self.mock_client | ||
|
||
# Initialize the class under test with the mocked MongoDB client | ||
self.persist_payload = PersistPayload() | ||
|
||
def test_persist_valid_payload(self): | ||
"""Test persisting a valid PayloadModel into the database.""" | ||
# Create a PayloadModel instance from the sample data | ||
payload = PayloadModel(**self.sample_payload_data) | ||
|
||
# Call the `persist` method to store the payload in the mock database | ||
self.persist_payload.persist(payload) | ||
|
||
# Retrieve the persisted document from the mock database | ||
persisted_data = self.mock_client["hivemind"]["messages"].find_one( | ||
{"communityId": self.sample_payload_data["communityId"]} | ||
) | ||
|
||
# Check that the persisted document matches the original payload | ||
self.assertIsNotNone(persisted_data) | ||
self.assertEqual( | ||
persisted_data["communityId"], self.sample_payload_data["communityId"] | ||
) | ||
self.assertEqual( | ||
persisted_data["route"]["source"], | ||
self.sample_payload_data["route"]["source"], | ||
) | ||
self.assertEqual( | ||
persisted_data["question"]["message"], | ||
self.sample_payload_data["question"]["message"], | ||
) | ||
|
||
def test_persist_with_invalid_payload(self): | ||
"""Test that attempting to persist an invalid payload raises an exception.""" | ||
# Create an invalid PayloadModel by omitting required fields | ||
invalid_payload_data = { | ||
"communityId": self.sample_payload_data["communityId"], | ||
"route": {}, # Invalid as required fields are missing | ||
"question": {"message": ""}, | ||
"response": {"message": ""}, | ||
"metadata": None, | ||
} | ||
|
||
# Construct the PayloadModel (this will raise a validation error) | ||
with self.assertRaises(ValueError): | ||
PayloadModel(**invalid_payload_data) | ||
|
||
def test_persist_handles_mongo_exception(self): | ||
"""Test that MongoDB exceptions are properly handled and logged.""" | ||
# Create a valid PayloadModel instance | ||
payload = PayloadModel(**self.sample_payload_data) | ||
|
||
# Simulate a MongoDB exception during the insert operation | ||
with patch.object( | ||
self.mock_client["hivemind"]["messages"], | ||
"insert_one", | ||
side_effect=Exception("Database error"), | ||
): | ||
with self.assertLogs(level="ERROR") as log: | ||
self.persist_payload.persist(payload) | ||
self.assertIn( | ||
"Failed to persist payload to database for community", log.output[0] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from unittest import TestCase | ||
from pydantic import ValidationError | ||
|
||
from schema.payload import PayloadModel | ||
|
||
|
||
class TestPayloadModel(TestCase): | ||
"""Test suite for PayloadModel and its nested models.""" | ||
|
||
valid_community_id = "650be9f4e2c1234abcd12345" | ||
|
||
# Helper function to create a valid payload dictionary | ||
def get_valid_payload(self): | ||
return { | ||
"communityId": self.valid_community_id, | ||
"route": { | ||
"source": "some-source", | ||
"destination": {"queue": "some-queue", "event": "some-event"}, | ||
}, | ||
"question": { | ||
"message": "What is the best approach?", | ||
"filters": {"category": "science"}, | ||
}, | ||
"response": {"message": "The best approach is using scientific methods."}, | ||
"metadata": {"timestamp": "2023-10-08T12:00:00"}, | ||
} | ||
|
||
def test_valid_payload(self): | ||
"""Test if a valid payload is correctly validated.""" | ||
payload = self.get_valid_payload() | ||
validated_model = PayloadModel(**payload) | ||
self.assertEqual(validated_model.communityId, payload["communityId"]) | ||
self.assertEqual(validated_model.route.source, payload["route"]["source"]) | ||
self.assertEqual( | ||
validated_model.route.destination.queue, | ||
payload["route"]["destination"]["queue"], | ||
) | ||
|
||
def test_missing_required_field(self): | ||
"""Test if missing a required field raises a ValidationError.""" | ||
payload = self.get_valid_payload() | ||
del payload["route"] # Remove a required field | ||
with self.assertRaises(ValidationError): | ||
PayloadModel(**payload) | ||
|
||
def test_none_as_optional_fields(self): | ||
"""Test if setting optional fields as None is valid.""" | ||
payload = self.get_valid_payload() | ||
payload["route"]["destination"] = None # Set optional destination to None | ||
payload["question"]["filters"] = None # Set optional filters to None | ||
payload["metadata"] = None # Set optional metadata to None | ||
validated_model = PayloadModel(**payload) | ||
self.assertIsNone(validated_model.route.destination) | ||
self.assertIsNone(validated_model.question.filters) | ||
self.assertIsNone(validated_model.metadata) | ||
|
||
def test_invalid_route(self): | ||
"""Test if an invalid RouteModel within PayloadModel raises a ValidationError.""" | ||
payload = self.get_valid_payload() | ||
payload["route"]["source"] = None # Invalid value for a required field | ||
with self.assertRaises(ValidationError): | ||
PayloadModel(**payload) | ||
|
||
def test_empty_string_fields(self): | ||
"""Test if fields with empty strings are allowed.""" | ||
payload = self.get_valid_payload() | ||
payload["route"]["source"] = "" # Set an empty string | ||
payload["question"]["message"] = "" # Set an empty string | ||
validated_model = PayloadModel(**payload) | ||
self.assertEqual(validated_model.route.source, "") | ||
self.assertEqual(validated_model.question.message, "") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import logging | ||
|
||
from utils.mongo import MongoSingleton | ||
from schema import PayloadModel | ||
|
||
|
||
class PersistPayload: | ||
def __init__(self) -> None: | ||
# the place we would save data in mongo | ||
self.db = "hivemind" | ||
self.collection = "internal_messages" | ||
self.client = MongoSingleton.get_instance().get_client() | ||
|
||
def persist(self, payload: PayloadModel) -> None: | ||
""" | ||
persist the payload within the database | ||
Parameters | ||
----------- | ||
payload : schema.PayloadModel | ||
the data payload to save on database | ||
""" | ||
community_id = payload.communityId | ||
try: | ||
self.client[self.db][self.collection].insert_one(payload.model_dump()) | ||
logging.info( | ||
f"Payload for community id: {community_id} persisted successfully!" | ||
) | ||
except Exception as exp: | ||
logging.error( | ||
f"Failed to persist payload to database for community: {community_id}!" | ||
f"Exception: {exp}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters