From a11ab9b554b866ac48f345542ea70528a46b6d23 Mon Sep 17 00:00:00 2001 From: Raphael Cohen Date: Mon, 26 Feb 2024 11:34:13 +0100 Subject: [PATCH] feat(AWS): Async --- AWS/connectors/__init__.py | 73 ++++++++++++++++++++++------------- AWS/connectors/s3/__init__.py | 5 +-- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/AWS/connectors/__init__.py b/AWS/connectors/__init__.py index 98e735824..5caa1a9f9 100644 --- a/AWS/connectors/__init__.py +++ b/AWS/connectors/__init__.py @@ -4,6 +4,7 @@ import time from abc import ABCMeta from functools import cached_property +from typing import Any, Callable from pydantic import BaseModel, Field from sekoia_automation.aio.connector import AsyncConnector @@ -74,34 +75,52 @@ async def next_batch(self) -> tuple[list[str], list[int]]: def run(self) -> None: # pragma: no cover """Run the connector.""" + loop = asyncio.get_event_loop() + loop.run_until_complete(self.async_run()) + + async def async_run(self) -> None: + """Run the connector.""" + background_tasks = set() while self.running: try: - loop = asyncio.get_event_loop() - - while self.running: - processing_start = time.time() - - batch_result: tuple[list[str], list[int]] = loop.run_until_complete(self.next_batch()) - message_ids, messages_timestamp = batch_result - - # Identify delay between message timestamp ( when it was pushed to sqs ) - # and current timestamp ( when it was processed ) - processing_end = time.time() - for message_timestamp in messages_timestamp: - EVENTS_LAG.labels(intake_key=self.configuration.intake_key).set( - processing_end - (message_timestamp / 1000) - ) - - OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(message_ids)) - FORWARD_EVENTS_DURATION.labels(intake_key=self.configuration.intake_key).observe( - processing_end - processing_start - ) - - if len(message_ids) > 0: - self.log(message="Pushed {0} records".format(len(message_ids)), level="info") - else: - self.log(message="No records to forward", level="info") - time.sleep(self.configuration.frequency) - + processing_start = time.time() + result = await self.next_batch() + records, messages_timestamp = result + if records: + task = asyncio.create_task(self.push_data_to_intakes(events=records)) + background_tasks.add(task) + task.add_done_callback( + background_tasks.discard + ) # Remove the task from the one that must be awaited when exiting + task.add_done_callback(self.push_data_to_intakes_callback(processing_start, messages_timestamp)) + else: + self.log(message="No records to forward", level="info") + await asyncio.sleep(self.configuration.frequency) except Exception as e: self.log_exception(e) + + # Wait for all logs to be pushed before exiting + await asyncio.gather(*background_tasks, return_exceptions=True) + + def push_data_to_intakes_callback( + self, processing_start: float, messages_timestamp: list[int] + ) -> Callable[[asyncio.Task[Any]], None]: + """Callback to remove the task from the background tasks set.""" + + def callback(task: asyncio.Task[Any]) -> None: + """Callback to remove the task from the background tasks set.""" + message_ids = task.result() + processing_end = time.time() + for message_timestamp in messages_timestamp: + EVENTS_LAG.labels(intake_key=self.configuration.intake_key).set( + processing_end - (message_timestamp / 1000) + ) + + OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(message_ids)) + FORWARD_EVENTS_DURATION.labels(intake_key=self.configuration.intake_key).observe( + processing_end - processing_start + ) + if len(message_ids) > 0: + self.log(message="Pushed {0} records".format(len(message_ids)), level="info") + + return callback diff --git a/AWS/connectors/s3/__init__.py b/AWS/connectors/s3/__init__.py index 6532cde74..03d3a0335 100644 --- a/AWS/connectors/s3/__init__.py +++ b/AWS/connectors/s3/__init__.py @@ -3,6 +3,7 @@ from abc import ABCMeta from functools import cached_property from gzip import decompress +from typing import Any import orjson @@ -147,6 +148,4 @@ async def next_batch(self, previous_processing_end: float | None = None) -> tupl if len(records) >= self.configuration.records_in_queue_per_batch or not records: continue_receiving = False - result = await self.push_data_to_intakes(events=records) - - return result, timestamps_to_log + return records, timestamps_to_log