diff --git a/CHANGELOG.md b/CHANGELOG.md index 430a71c..3a1c520 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.11.2] - 2023-02-13 + +### Fixed + +- Fixes rate limiter in async connector. Make it more configurable + ## [1.11.1] - 2023-01-29 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index 43037ce..bc89d7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "sekoia-automation-sdk" -version = "1.11.1" +version = "1.11.2" description = "SDK to create Sekoia.io playbook modules" license = "MIT" readme = "README.md" diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py index a8a3152..dc328d6 100644 --- a/sekoia_automation/aio/connector.py +++ b/sekoia_automation/aio/connector.py @@ -1,5 +1,6 @@ """Contains connector with async version.""" +import os from abc import ABC from asyncio import AbstractEventLoop, get_event_loop from collections.abc import AsyncGenerator @@ -68,7 +69,7 @@ def set_rate_limiter(cls, rate_limiter: AsyncLimiter) -> None: cls._rate_limiter = rate_limiter @classmethod - def get_rate_limiter(cls) -> AsyncLimiter: + def get_rate_limiter(cls) -> AsyncLimiter | None: """ Get or initialize rate limiter. @@ -76,7 +77,13 @@ def get_rate_limiter(cls) -> AsyncLimiter: AsyncLimiter: """ if cls._rate_limiter is None: - cls._rate_limiter = AsyncLimiter(1, 1) + requests_limit = os.getenv("REQUESTS_PER_SECOND_TO_INTAKE") + + if requests_limit is not None and int(requests_limit) > 0: + cls._rate_limiter = AsyncLimiter(int(requests_limit), 1) + + if requests_limit is None: + cls._rate_limiter = AsyncLimiter(1, 1) return cls._rate_limiter @@ -92,7 +99,12 @@ async def session(cls) -> AsyncGenerator[ClientSession, None]: # pragma: no cov if cls._session is None: cls._session = ClientSession() - async with cls.get_rate_limiter(): + rate_limiter = cls.get_rate_limiter() + + if rate_limiter: + async with rate_limiter: + yield cls._session + else: yield cls._session async def push_data_to_intakes( diff --git a/tests/aio/test_connector.py b/tests/aio/test_connector.py index 14c9e2d..b9b1666 100644 --- a/tests/aio/test_connector.py +++ b/tests/aio/test_connector.py @@ -1,6 +1,7 @@ """Test async connector.""" -from unittest.mock import Mock, patch +import os +from unittest.mock import AsyncMock, Mock, patch from urllib.parse import urljoin import pytest @@ -221,3 +222,57 @@ async def test_async_connector_raise_error( except Exception as e: assert isinstance(e, RuntimeError) assert str(e) == expected_error + + +@pytest.mark.asyncio +async def test_session(): + async with AsyncConnector.session() as session: + assert session is not None + + +@pytest.mark.asyncio +async def test_session_reuses_existing_session(): + session_mock = Mock() + AsyncConnector._session = session_mock + + async with AsyncConnector.session() as session: + assert session == session_mock + + +@pytest.mark.asyncio +async def test_session_with_rate_limiter(): + mock_rate_limiter = AsyncMock() + AsyncConnector._rate_limiter = mock_rate_limiter + + async with AsyncConnector.session() as session: + assert session is not None + mock_rate_limiter.__aenter__.assert_called_once() + + +@pytest.mark.asyncio +async def test_session_with_rate_limiter_none(): + AsyncConnector._rate_limiter = None + + async with AsyncConnector.session() as session: + assert session is not None + assert AsyncConnector._rate_limiter.max_rate == 1 + + +@pytest.mark.asyncio +async def test_session_with_rate_limiter_from_env_variable(): + os.environ["REQUESTS_PER_SECOND_TO_INTAKE"] = str(100) + AsyncConnector._rate_limiter = None + + async with AsyncConnector.session() as session: + assert session is not None + assert AsyncConnector._rate_limiter.max_rate == 100 + + +@pytest.mark.asyncio +async def test_session_with_rate_limiter_from_env_variable_with_zero(): + os.environ["REQUESTS_PER_SECOND_TO_INTAKE"] = str(0) + AsyncConnector._rate_limiter = None + + async with AsyncConnector.session() as session: + assert session is not None + assert AsyncConnector._rate_limiter is None