Skip to content

Commit

Permalink
Merge pull request #1124 from SEKOIA-IO/fix/AzureEventHubShutdown_exe…
Browse files Browse the repository at this point in the history
…cution

Fix: azure event hub shutdown execution
  • Loading branch information
squioc authored Oct 14, 2024
2 parents bd25c00 + c099b8e commit e7562ee
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 20 deletions.
6 changes: 6 additions & 0 deletions Azure/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## 2024-10-14 - 2.5.6

### Fixed

- Fix the way to handle graceful shutdown with async tasks

## 2024-09-30 - 2.5.5

### Changed
Expand Down
52 changes: 37 additions & 15 deletions Azure/connectors/azure_eventhub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import os
import signal
import time
from asyncio import Task
from datetime import datetime, timezone
from functools import cached_property
from threading import Event
from typing import Any, Optional, cast

import orjson
Expand Down Expand Up @@ -66,15 +67,38 @@ class AzureEventsHubTrigger(AsyncConnector):

configuration: AzureEventsHubConfiguration

# The maximum time to wait for new messages before closing the client
wait_timeout = 600

def __init__(self, *args: Any, **kwargs: Optional[Any]) -> None:
    """
    Initialise the connector state.

    All positional and keyword arguments are forwarded untouched to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    # Threading event kept on the instance
    self._stop_event = Event()

    # Read from the CONSUMER_MAX_WAIT_TIME environment variable, parsed as a base-10 integer
    # ("600" when the variable is not set)
    raw_max_wait_time = os.environ.get("CONSUMER_MAX_WAIT_TIME", "600")
    self._consumption_max_wait_time = int(raw_max_wait_time, 10)

    # Handle on the asyncio task currently receiving events; None while no task is tracked
    self._current_task: Task[Any] | None = None

@cached_property
def client(self) -> Client:
    """
    Return the client built from the connector configuration.

    The property is a ``cached_property``: the ``Client`` is constructed on first access
    and the same object is returned afterwards.
    """
    settings = self.configuration
    return Client(settings)

async def stop_current_task(self) -> None:
    """
    Stop the current async task and wait until it has actually finished.

    ``Task.cancel()`` only *requests* the cancellation: the ``CancelledError`` is thrown
    into the coroutine on a later iteration of the event loop. Returning right after
    ``cancel()`` would let the caller go on (e.g. close the client) while the receiving
    coroutine is still unwinding, so this method waits for the task to be done.

    The method is a no-op when no task is tracked or when the task is already finished.
    In every case ``self._current_task`` is reset to ``None``.
    """
    # Detach the task first so that a concurrent caller (the signal-triggered shutdown
    # and the run loop may both call this method) does not handle the same task twice
    task, self._current_task = self._current_task, None

    # Nothing to stop: no task defined, or the task already completed / was cancelled
    if task is None or task.done():
        return

    # Request the cancellation of the receiving task
    task.cancel()

    # Wait for the task to handle the cancellation. asyncio.wait() does not re-raise the
    # outcome of the awaited task (CancelledError or any other exception), so the caller
    # is only interrupted if it is itself cancelled.
    await asyncio.wait([task])

async def shutdown(self) -> None:
    """
    Shut the connector down.

    Runs, in order: ``self.stop()``, the stop of the task currently tracked by the
    connector, then the emission of the shutdown log message.
    """
    # 1. call the connector's stop() method
    self.stop()

    # 2. stop the task that is currently receiving events, if any
    await self.stop_current_task()

    # 3. report the shutdown
    self.log("Shutting down the trigger")

async def handle_messages(self, partition_context: PartitionContext, messages: list[EventData]) -> None:
"""
Handle new messages
Expand All @@ -83,7 +107,7 @@ async def handle_messages(self, partition_context: PartitionContext, messages: l
# got messages, we forward them
await self.forward_events(messages)

else:
else: # pragma: no cover
# We reached the max_wait_time, close the current client
self.log(
message=(
Expand Down Expand Up @@ -142,7 +166,7 @@ async def forward_events(self, messages: list[EventData]) -> None:
FORWARD_EVENTS_DURATION.labels(intake_key=self.configuration.intake_key).observe(time.time() - start)

enqueued_times = [message.enqueued_time for message in messages if message.enqueued_time is not None]
if len(enqueued_times) > 0:
if len(enqueued_times) > 0: # pragma: no cover
now = datetime.now(timezone.utc)
messages_age = [int((now - enqueued_time).total_seconds()) for enqueued_time in enqueued_times]

Expand Down Expand Up @@ -179,17 +203,14 @@ async def receive_events(self) -> None:

async def async_run(self) -> None:
while self.running:
task = asyncio.create_task(self.receive_events())
self._current_task = asyncio.create_task(self.receive_events())

try:
# Allow the task to run for the specified duration (10 minutes)
await asyncio.sleep(600)
# Allow the task to run for the specified duration (10 minutes default)
await asyncio.sleep(self.wait_timeout)

# Cancel the receiving task after the duration
task.cancel()

# Wait for the task to handle the cancellation
await task
# Stop the receiving task after the duration
await self.stop_current_task()

except asyncio.CancelledError:
self.log(message="Receiving task was cancelled", level="info")
Expand All @@ -198,11 +219,12 @@ async def async_run(self) -> None:
# Ensure the client is closed properly
await self.client.close()

# Sleep for a short period before starting the next cycle
await asyncio.sleep(5) # Adjust if necessary

def run(self) -> None:  # pragma: no cover
    """
    Run the connector until ``async_run`` returns.

    SIGTERM and SIGINT are bound to ``shutdown`` so that the connector can stop
    gracefully. When the handlers cannot be installed (the loop does not run in the main
    thread, or the event loop implementation has no signal support), a warning is logged
    and the connector still runs, without graceful signal handling.
    """
    self.log(message="Azure EventHub Trigger has started", level="info")

    loop = asyncio.get_event_loop()

    # The event loop only keeps weak references to tasks: hold a strong reference to each
    # shutdown task so that it cannot be garbage-collected before it completes
    shutdown_tasks: set[asyncio.Task] = set()

    def request_shutdown() -> None:
        # Signal callback: schedule the asynchronous shutdown on the loop
        task = loop.create_task(self.shutdown())
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for signum in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(signum, request_shutdown)
        except (NotImplementedError, RuntimeError, ValueError):
            # NotImplementedError: loop without signal support;
            # RuntimeError / ValueError: handler registration refused (e.g. outside the main thread)
            self.log(
                message=f"Unable to register a handler for signal {signum!r}",
                level="warning",
            )

    loop.run_until_complete(self.async_run())

    self.log(message="Azure EventHub Trigger has stopped", level="info")
4 changes: 2 additions & 2 deletions Azure/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"name": "Microsoft Azure",
"uuid": "525eecc0-9eee-484d-92bd-039117cf4dac",
"slug": "azure",
"version": "2.5.5",
"version": "2.5.6",
"categories": [
"Cloud Providers"
]
}
}
116 changes: 113 additions & 3 deletions Azure/tests/connector/test_azure_eventhub.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import asyncio
import os
import time
from multiprocessing import Process
from shutil import rmtree
from tempfile import mkdtemp
from threading import Thread
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from azure.eventhub import EventData
from sekoia_automation import constants
from sekoia_automation.module import Module

from connectors.azure_eventhub import AzureEventsHubConfiguration, AzureEventsHubTrigger, Client

Expand All @@ -28,11 +34,18 @@ def test_forward_next_batches_integration(symphony_storage):
trigger.push_events_to_intakes = Mock()
trigger.log_exception = Mock()
trigger.log = Mock()
thread = Thread(target=trigger.run)

loop = asyncio.new_event_loop()

def run_trigger(trigger, loop):
asyncio.set_event_loop(loop)

trigger.run()

thread = Thread(target=run_trigger, args=(trigger, loop))
thread.start()
time.sleep(30)
trigger.client.close()
trigger.stop()
trigger.shutdown(9, loop)
calls = [call.kwargs["events"] for call in trigger.push_events_to_intakes.call_args_list]

assert len(calls) > 0
Expand Down Expand Up @@ -127,3 +140,100 @@ async def test_client_close():

fake_client.close.assert_awaited_once()
assert client._client is None


@pytest.fixture
def data_storage():
    """
    Point ``constants.DATA_STORAGE`` at a fresh temporary directory for one test.

    Yields the path of the temporary directory. On teardown the directory is removed and
    the previous value of ``constants.DATA_STORAGE`` is restored.
    """
    # Setup: remember the current storage and swap in a temporary directory
    previous_storage = constants.DATA_STORAGE
    constants.DATA_STORAGE = mkdtemp()

    yield constants.DATA_STORAGE

    # Teardown: drop the temporary directory, then put the previous storage back
    rmtree(constants.DATA_STORAGE)
    constants.DATA_STORAGE = previous_storage


class AzureEventsHubTestTriggerQuick(AzureEventsHubTrigger):
    """
    Test trigger whose fake reception finishes before the wait timeout elapses.

    The wait timeout is lowered to speed up the test execution and ``receive_events`` is
    replaced by a plain sleep: the trigger waits for 2 seconds while the fake reception
    only lasts 1 second, so the receiving task is already finished when the wait ends
    and no task cancellation happens.
    """

    # Seconds to wait before stopping the receiving task
    wait_timeout = 2

    # Seconds the fake reception lasts
    execution_time = 1

    async def receive_events(self) -> None:
        """Simulate the reception of events by sleeping for ``execution_time`` seconds."""
        await asyncio.sleep(self.execution_time)


class AzureEventsHubTestTriggerSlow(AzureEventsHubTestTriggerQuick):
    """
    Test trigger whose fake reception outlasts the wait timeout.

    The trigger waits for 1 second while the fake reception lasts 2 seconds, so the
    receiving task is still running when the wait ends and is cancelled at each iteration.
    """

    # Seconds to wait before stopping the receiving task
    wait_timeout = 1

    # Seconds the fake reception lasts
    execution_time = 2


def create_and_run_connector(data_storage, is_quick: bool = True) -> None:
    """
    Build a test connector and run it until it stops.

    This function is used to test the AzureEventsHubTrigger run method and the reception
    of signals to stop the connector. It should be run in a separate process!

    Args:
        data_storage: Path handed to the connector as its ``data_path``.
        is_quick: When true, run ``AzureEventsHubTestTriggerQuick``; otherwise run
            ``AzureEventsHubTestTriggerSlow``.
    """
    # Select the class up front so that only the connector that will run is instantiated
    connector_class = AzureEventsHubTestTriggerQuick if is_quick else AzureEventsHubTestTriggerSlow
    connector = connector_class(module=Module(), data_path=data_storage)

    # Placeholder settings: both test connectors replace receive_events with a sleep
    connector.configuration = AzureEventsHubConfiguration.parse_obj(
        {
            "chunk_size": 1,
            "hub_connection_string": "hub_connection_string",
            "hub_name": "hub_name",
            "hub_consumer_group": "hub_consumer_group",
            "storage_connection_string": "storage_connection_string",
            "storage_container_name": "storage_container_name",
            "intake_key": "",
        }
    )

    connector.run()


def test_azure_eventhub_handling_stop_event_quick(data_storage):
    """
    The quick connector must stop shortly after receiving SIGTERM.

    For the quick connector, one iteration of the run loop lasts about 2 seconds
    (``wait_timeout``), during which the fake reception runs for 1 second.
    The connector is terminated after 5 seconds, i.e. in the middle of an iteration: it
    is expected to finish that iteration and not to start another one.
    So the total execution time should be less than or equal to 6 seconds, plus 1 second
    of margin for timing differences.
    """
    start_execution_time = time.time()
    process = Process(target=create_and_run_connector, args=(data_storage,))
    process.start()

    time.sleep(5)
    process.terminate()

    # Bounded wait: a connector that ignores the signal must fail the test instead of
    # blocking the whole test session on join()
    process.join(timeout=30)
    finish_execution_time = time.time()

    stopped_by_itself = not process.is_alive()
    if not stopped_by_itself:
        # Do not leave the child process behind
        process.kill()
        process.join()

    assert stopped_by_itself, "The connector was still running 30 seconds after SIGTERM"
    assert finish_execution_time - start_execution_time <= 7


def test_azure_eventhub_handling_stop_event_slow(data_storage):
    """
    The slow connector must stop shortly after receiving SIGTERM.

    For the slow connector, one iteration of the run loop lasts about 1 second
    (``wait_timeout``): the fake reception would need 2 seconds, so it is cancelled at
    the end of each iteration.
    The connector is terminated after 4 seconds and is expected not to start another
    iteration. So the total execution time should be less than or equal to 4 seconds,
    plus 1 second of margin for timing differences.
    """
    start_execution_time = time.time()
    process = Process(target=create_and_run_connector, args=(data_storage, False))
    process.start()

    time.sleep(4)
    process.terminate()

    # Bounded wait: a connector that ignores the signal must fail the test instead of
    # blocking the whole test session on join()
    process.join(timeout=30)
    finish_execution_time = time.time()

    stopped_by_itself = not process.is_alive()
    if not stopped_by_itself:
        # Do not leave the child process behind
        process.kill()
        process.join()

    assert stopped_by_itself, "The connector was still running 30 seconds after SIGTERM"
    assert finish_execution_time - start_execution_time <= 5

0 comments on commit e7562ee

Please sign in to comment.