diff --git a/AWS/connectors/__init__.py b/AWS/connectors/__init__.py
index 98e735824..10e30b5dd 100644
--- a/AWS/connectors/__init__.py
+++ b/AWS/connectors/__init__.py
@@ -4,6 +4,7 @@
 import time
 from abc import ABCMeta
 from functools import cached_property
+from typing import Callable
 
 from pydantic import BaseModel, Field
 from sekoia_automation.aio.connector import AsyncConnector
@@ -74,34 +75,48 @@ async def next_batch(self) -> tuple[list[str], list[int]]:
 
     def run(self) -> None:  # pragma: no cover
         """Run the connector."""
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.async_run())
+
+    async def async_run(self) -> None:
+        """Run the connector."""
+        background_tasks = set()
         while self.running:
             try:
-                loop = asyncio.get_event_loop()
-
-                while self.running:
-                    processing_start = time.time()
-
-                    batch_result: tuple[list[str], list[int]] = loop.run_until_complete(self.next_batch())
-                    message_ids, messages_timestamp = batch_result
-
-                    # Identify delay between message timestamp ( when it was pushed to sqs )
-                    # and current timestamp ( when it was processed )
-                    processing_end = time.time()
-                    for message_timestamp in messages_timestamp:
-                        EVENTS_LAG.labels(intake_key=self.configuration.intake_key).set(
-                            processing_end - (message_timestamp / 1000)
-                        )
-
-                    OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(message_ids))
-                    FORWARD_EVENTS_DURATION.labels(intake_key=self.configuration.intake_key).observe(
-                        processing_end - processing_start
-                    )
-
-                    if len(message_ids) > 0:
-                        self.log(message="Pushed {0} records".format(len(message_ids)), level="info")
-                    else:
-                        self.log(message="No records to forward", level="info")
-                        time.sleep(self.configuration.frequency)
-
+                processing_start = time.time()
+                result: tuple[list[dict], list[int]] = await self.next_batch()
+                records, messages_timestamp = result
+                if records:
+                    task = asyncio.create_task(self.push_data_to_intakes(events=records))
+                    background_tasks.add(task)
+                    task.add_done_callback(background_tasks.discard)
+                    task.add_done_callback(self.push_data_to_intakes_callback(processing_start, messages_timestamp))
+                else:
+                    self.log(message="No records to forward", level="info")
+                    await asyncio.sleep(self.configuration.frequency)
             except Exception as e:
                 self.log_exception(e)
+
+        # Wait for all logs to be pushed before exiting
+        await asyncio.gather(*background_tasks)
+
+    def push_data_to_intakes_callback(self, processing_start: float, messages_timestamp: list[int]) -> Callable[[asyncio.Task], None]:
+        """Build a done-callback that records metrics once a push task completes."""
+
+        def callback(task: asyncio.Task) -> None:
+            """Record lag and throughput metrics for the finished push task."""
+            message_ids = task.result()
+            processing_end = time.time()
+            for message_timestamp in messages_timestamp:
+                EVENTS_LAG.labels(intake_key=self.configuration.intake_key).set(
+                    processing_end - (message_timestamp / 1000)
+                )
+
+            OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(message_ids))
+            FORWARD_EVENTS_DURATION.labels(intake_key=self.configuration.intake_key).observe(
+                processing_end - processing_start
+            )
+            if len(message_ids) > 0:
+                self.log(message="Pushed {0} records".format(len(message_ids)), level="info")
+
+        return callback
diff --git a/AWS/connectors/s3/__init__.py b/AWS/connectors/s3/__init__.py
index 6532cde74..069c3c4a0 100644
--- a/AWS/connectors/s3/__init__.py
+++ b/AWS/connectors/s3/__init__.py
@@ -89,7 +89,7 @@ def decompress_content(data: bytes) -> bytes:
 
         return data
 
-    async def next_batch(self, previous_processing_end: float | None = None) -> tuple[list[str], list[int]]:
+    async def next_batch(self, previous_processing_end: float | None = None) -> tuple[list[dict], list[int]]:
         """
         Get next batch of messages.
 
@@ -147,6 +147,4 @@ async def next_batch(self, previous_processing_end: float | None = None) -> tupl
             if len(records) >= self.configuration.records_in_queue_per_batch or not records:
                 continue_receiving = False
 
-        result = await self.push_data_to_intakes(events=records)
-
-        return result, timestamps_to_log
+        return records, timestamps_to_log
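
Note: the hunks above replace the old "push inside next_batch, then sleep" flow with a fire-and-forget pattern: each batch is pushed by a background asyncio task, a first done-callback drops the finished task from the tracking set, and a second one records the Prometheus metrics. The standalone sketch below shows that pattern in miniature; it is illustrative only, and push(), make_metrics_callback(), and the sample batches are hypothetical stand-ins for push_data_to_intakes() and push_data_to_intakes_callback(), not code from this repository.

import asyncio
import time


async def push(records: list[dict]) -> list[str]:
    """Hypothetical stand-in for push_data_to_intakes(): returns one id per record."""
    await asyncio.sleep(0.1)
    return [str(index) for index, _ in enumerate(records)]


def make_metrics_callback(processing_start: float):
    """Close over the batch start time, as push_data_to_intakes_callback() does."""

    def callback(task: asyncio.Task) -> None:
        message_ids = task.result()  # re-raises if the push task failed
        duration = time.time() - processing_start
        print(f"pushed {len(message_ids)} records in {duration:.2f}s")

    return callback


async def main() -> None:
    background_tasks: set[asyncio.Task] = set()
    for batch in ([{"event": 1}], [{"event": 2}, {"event": 3}]):
        processing_start = time.time()
        task = asyncio.create_task(push(batch))
        background_tasks.add(task)  # keep a strong reference while the task runs
        task.add_done_callback(background_tasks.discard)
        task.add_done_callback(make_metrics_callback(processing_start))
    await asyncio.gather(*background_tasks)  # drain pending pushes before exiting


asyncio.run(main())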