diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c45d6a..114e006 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.15.1] - 2024-10-03 + +### Fixed + +- Fix tests for async version of connector. + ## [1.15.0] - 2024-09-28 ### Changed diff --git a/pyproject.toml b/pyproject.toml index 454dee1..520dc79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "sekoia-automation-sdk" -version = "1.15.0" +version = "1.15.1" description = "SDK to create Sekoia.io playbook modules" license = "MIT" readme = "README.md" diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py index d54052b..2129dd8 100644 --- a/sekoia_automation/aio/connector.py +++ b/sekoia_automation/aio/connector.py @@ -1,7 +1,7 @@ """Contains connector with async version.""" +import asyncio from abc import ABC -from asyncio import AbstractEventLoop, get_event_loop from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import datetime @@ -21,8 +21,6 @@ class AsyncConnector(Connector, ABC): configuration: DefaultConnectorConfiguration - _event_loop: AbstractEventLoop - _session: ClientSession | None = None _rate_limiter: AsyncLimiter | None = None @@ -30,7 +28,6 @@ def __init__( self, module: Module | None = None, data_path: Path | None = None, - event_loop: AbstractEventLoop | None = None, *args, **kwargs, ): @@ -47,55 +44,49 @@ def __init__( self.max_concurrency_tasks = kwargs.pop("max_concurrency_tasks", 1000) super().__init__(module=module, data_path=data_path, *args, **kwargs) - self._event_loop = event_loop or get_event_loop() - - @classmethod - def set_client_session(cls, session: ClientSession) -> None: + def set_client_session(self, session: ClientSession) -> None: """ Set client session. Args: session: ClientSession """ - cls._session = session + self._session = session - @classmethod - def set_rate_limiter(cls, rate_limiter: AsyncLimiter) -> None: + def set_rate_limiter(self, rate_limiter: AsyncLimiter) -> None: """ Set rate limiter. Args: rate_limiter: """ - cls._rate_limiter = rate_limiter + self._rate_limiter = rate_limiter - @classmethod - def get_rate_limiter(cls) -> AsyncLimiter: + def get_rate_limiter(self) -> AsyncLimiter: """ Get or initialize rate limiter. Returns: AsyncLimiter: """ - if cls._rate_limiter is None: - cls._rate_limiter = AsyncLimiter(1, 1) + if self._rate_limiter is None: + self._rate_limiter = AsyncLimiter(1, 1) - return cls._rate_limiter + return self._rate_limiter - @classmethod @asynccontextmanager - async def session(cls) -> AsyncGenerator[ClientSession, None]: # pragma: no cover + async def session(self) -> AsyncGenerator[ClientSession, None]: # pragma: no cover """ Get or initialize client session if it is not initialized yet. Returns: ClientSession: """ - if cls._session is None: - cls._session = ClientSession() + if self._session is None: + self._session = ClientSession() - async with cls.get_rate_limiter(): - yield cls._session + async with self.get_rate_limiter(): + yield self._session async def _async_send_chunk( self, session: ClientSession, url: str, chunk_index: int, chunk: list[str] @@ -169,3 +160,14 @@ async def push_data_to_intakes( result_ids.extend(ids) return result_ids + + def stop(self, *args, **kwargs): + """ + Stop the connector + """ + loop = asyncio.get_event_loop() + + if self._session: + loop.run_until_complete(self._session.close()) + + super().stop(*args, **kwargs) diff --git a/tests/aio/test_connector.py b/tests/aio/test_connector.py index 14c9e2d..dcbe67e 100644 --- a/tests/aio/test_connector.py +++ b/tests/aio/test_connector.py @@ -52,16 +52,13 @@ async def test_async_connector_rate_limiter(async_connector: DummyAsyncConnector assert async_connector._rate_limiter is None assert other_instance._rate_limiter is None - assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter() + assert async_connector.get_rate_limiter() != other_instance.get_rate_limiter() async_connector.set_rate_limiter(rate_limiter_mock) - assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter() + assert async_connector.get_rate_limiter() != other_instance.get_rate_limiter() assert async_connector._rate_limiter == rate_limiter_mock - DummyAsyncConnector.set_rate_limiter(None) - DummyAsyncConnector.set_client_session(None) - @pytest.mark.asyncio async def test_async_connector_client_session(async_connector: DummyAsyncConnector): @@ -78,7 +75,7 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect async with async_connector.session() as session_1: async with other_instance.session() as session_2: - assert session_1 == session_2 + assert session_1 != session_2 assert async_connector._rate_limiter is not None and isinstance( async_connector._rate_limiter, AsyncLimiter @@ -88,9 +85,6 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect other_instance._rate_limiter, AsyncLimiter ) - DummyAsyncConnector.set_rate_limiter(None) - other_instance.set_client_session(None) - @pytest.mark.asyncio async def test_async_connector_push_single_event(