diff --git a/faster_sam/dependencies/events.py b/faster_sam/dependencies/events.py index 256f6bec..f78bbade 100644 --- a/faster_sam/dependencies/events.py +++ b/faster_sam/dependencies/events.py @@ -1,8 +1,13 @@ +import hashlib from datetime import datetime, timezone -from typing import Any, Dict +from typing import Any, Callable, Dict, Type from uuid import uuid4 +import uuid from fastapi import Request +from pydantic import BaseModel + +from faster_sam.protocols import IntoSQSInfo async def apigateway_proxy(request: Request) -> Dict[str, Any]: @@ -31,3 +36,34 @@ async def apigateway_proxy(request: Request) -> Dict[str, Any]: }, } return event + + +def sqs(schema: Type[BaseModel]) -> Callable[[BaseModel], Dict[str, Any]]: + def dep(message: schema) -> Dict[str, Any]: + assert isinstance(message, IntoSQSInfo) + + info = message.into() + event = { + "Records": [ + { + "messageId": info.id, + "receiptHandle": str(uuid.uuid4()), + "body": info.body, + "attributes": { + "ApproximateReceiveCount": info.receive_count, + "SentTimestamp": info.sent_timestamp, + "SenderId": str(uuid.uuid4()), + "ApproximateFirstReceiveTimestamp": info.sent_timestamp, + }, + "messageAttributes": info.message_attributes, + "md5OfBody": hashlib.md5(info.body.encode()).hexdigest(), + "eventSource": "aws:sqs", + "eventSourceARN": info.source_arn, + "awsRegion": None, + }, + ] + } + + return event + + return dep diff --git a/faster_sam/protocols.py b/faster_sam/protocols.py new file mode 100644 index 00000000..32b5da21 --- /dev/null +++ b/faster_sam/protocols.py @@ -0,0 +1,8 @@ +from typing import Protocol, runtime_checkable + +from faster_sam.schemas import SQSInfo + + +@runtime_checkable +class IntoSQSInfo(Protocol): + def into(self) -> SQSInfo: ... diff --git a/faster_sam/schemas.py b/faster_sam/schemas.py new file mode 100644 index 00000000..6bba0137 --- /dev/null +++ b/faster_sam/schemas.py @@ -0,0 +1,41 @@ +from datetime import datetime +from typing import Dict, Optional +from pydantic import BaseModel, Base64UrlStr, Field + + +class SQSInfo(BaseModel): + id: str + body: str + receive_count: int + sent_timestamp: int + source_arn: str + message_attributes: Optional[Dict[str, str]] = Field(default=None) + + +class PubSubMessage(BaseModel): + data: Base64UrlStr + messageId: str + publishTime: datetime + attributes: Optional[Dict[str, str]] = Field(default=None) + + +class PubSubEnvelope(BaseModel): + message: PubSubMessage + subscription: str + deliveryAttempt: int + + def into(self) -> SQSInfo: + milliseconds = 1000 + + publish_time = int(self.message.publishTime.timestamp() * milliseconds) + + topic_name = self.subscription.rsplit("/", maxsplit=1)[-1] + source_arn = f"arn:aws:sqs:::{topic_name}" + return SQSInfo( + id=self.message.messageId, + body=self.message.data, + receive_count=self.deliveryAttempt, + sent_timestamp=publish_time, + message_attributes=self.message.attributes, + source_arn=source_arn, + ) diff --git a/tests/test_dependencies_events.py b/tests/test_dependencies_events.py index f1cd4962..14221274 100644 --- a/tests/test_dependencies_events.py +++ b/tests/test_dependencies_events.py @@ -1,8 +1,13 @@ +import base64 import unittest +from datetime import datetime, timezone +from unittest.mock import patch +import uuid from fastapi import FastAPI, Request from faster_sam.dependencies import events +from faster_sam.schemas import PubSubEnvelope def build_request(): @@ -50,3 +55,55 @@ async def test_event(self): self.assertEqual(event["requestContext"]["path"], "/ping/pong") self.assertEqual(event["requestContext"]["httpMethod"], "GET") self.assertEqual(event["requestContext"]["protocol"], "HTTP/1.1") + + +class TestSQS(unittest.TestCase): + async def test_event(self): + data = { + "message": { + "data": "aGVsbG8=", + "attributes": {"foo": "bar"}, + "messageId": "10519041647717348", + "publishTime": "2024-02-22T15:45:31.346Z", + }, + "subscription": "projects/foo/subscriptions/bar", + "deliveryAttempt": 1, + } + + pubsub_envelope = PubSubEnvelope(**data) + + sender_id = uuid.uuid4() + + with patch("uuid.uuid4", return_value=sender_id): + SQSEvent = events.sqs(PubSubEnvelope) + + event = SQSEvent(pubsub_envelope) + + parsed_datetime = datetime.strptime(data["message"]["publishTime"], "%Y-%m-%dT%H:%M:%S.%fZ") + parsed_datetime_utc = parsed_datetime.replace(tzinfo=timezone.utc) + timestamp_milliseconds = int(parsed_datetime_utc.timestamp() * 1000) + + self.assertIsInstance(event, dict) + record = event["Records"][0] + self.assertEqual(record["messageId"], data["message"]["messageId"]) + self.assertEqual(record["body"], base64.b64decode(data["message"]["data"]).decode("utf-8")) + self.assertEqual(record["attributes"]["ApproximateReceiveCount"], data["deliveryAttempt"]) + self.assertEqual( + record["attributes"]["SentTimestamp"], + timestamp_milliseconds, + ) + self.assertEqual(record["attributes"]["SenderId"], str(sender_id)) + self.assertEqual( + record["attributes"]["ApproximateFirstReceiveTimestamp"], + timestamp_milliseconds, + ) + self.assertEqual(record["messageAttributes"], data["message"]["attributes"]) + self.assertEqual( + record["md5OfBody"], + data["message"]["data"], + ) + self.assertEqual(record["eventSource"], "aws:sqs") + self.assertEqual( + record["eventSourceARN"], + f"arn:aws:sqs:::{data['subscription'].rsplit('/', maxsplit=1)[-1]}", + )