Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vladyslav-huriev committed Oct 4, 2024
1 parent 00eb91c commit 8d59af2
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 78 deletions.
62 changes: 19 additions & 43 deletions sekoia_automation/aio/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import time
from abc import ABC
from asyncio import AbstractEventLoop, get_event_loop
from collections.abc import AsyncGenerator, Sequence
from contextlib import asynccontextmanager
from datetime import datetime
Expand All @@ -27,16 +26,13 @@ class AsyncConnector(Connector, ABC):

configuration: DefaultConnectorConfiguration

_event_loop: AbstractEventLoop

_session: ClientSession | None = None
_rate_limiter: AsyncLimiter | None = None

def __init__(
self,
module: Module | None = None,
data_path: Path | None = None,
event_loop: AbstractEventLoop | None = None,
*args,
**kwargs,
):
Expand All @@ -53,60 +49,31 @@ def __init__(
self.max_concurrency_tasks = kwargs.pop("max_concurrency_tasks", 1000)
super().__init__(module=module, data_path=data_path, *args, **kwargs)

self._event_loop = event_loop or get_event_loop()

@classmethod
def set_client_session(cls, session: ClientSession) -> None:
"""
Set client session.
Args:
session: ClientSession
"""
cls._session = session

@classmethod
def set_rate_limiter(cls, rate_limiter: AsyncLimiter) -> None:
"""
Set rate limiter.
Args:
rate_limiter:
"""
cls._rate_limiter = rate_limiter

@classmethod
def get_rate_limiter(cls) -> AsyncLimiter:
def get_rate_limiter(self) -> AsyncLimiter:
"""
Get or initialize rate limiter.
Returns:
AsyncLimiter:
"""
if cls._rate_limiter is None:
cls._rate_limiter = AsyncLimiter(1, 1)
if self._rate_limiter is None:
self._rate_limiter = AsyncLimiter(1, 1)

return cls._rate_limiter
return self._rate_limiter

@classmethod
@asynccontextmanager
async def session(cls) -> AsyncGenerator[ClientSession, None]: # pragma: no cover
async def session(self) -> AsyncGenerator[ClientSession, None]: # pragma: no cover
"""
Get or initialize client session if it is not initialized yet.
Returns:
ClientSession:
"""
if cls._session is None:
cls._session = ClientSession()
if self._session is None:
self._session = ClientSession()

async with cls.get_rate_limiter():
yield cls._session

async def async_close(self) -> None:
"""Close session."""
if self._session:
await self._session.close()
async with self.get_rate_limiter():
yield self._session

async def _async_send_chunk(
self, session: ClientSession, url: str, chunk_index: int, chunk: list[str]
Expand Down Expand Up @@ -245,8 +212,17 @@ async def async_run(self) -> None: # pragma: no cover
if self.frequency:
await asyncio.sleep(self.frequency)

def stop(self, *args, **kwargs):
"""
Stop the connector
"""
super().stop(*args, **kwargs)
loop = asyncio.get_event_loop()

if self._session:
loop.run_until_complete(self._session.close())

def run(self) -> None: # pragma: no cover
"""Runs Connector."""
loop = asyncio.get_event_loop()
loop.run_until_complete(self.async_run())
loop.run_until_complete(self.async_close())
3 changes: 0 additions & 3 deletions sekoia_automation/http/aio/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from aiohttp import ClientResponse, ClientResponseError, ClientSession
from aiohttp.web_response import Response
from aiolimiter import AsyncLimiter
from loguru import logger

from sekoia_automation.http.http_client import AbstractHttpClient, Method
from sekoia_automation.http.rate_limiter import RateLimiterConfig
Expand Down Expand Up @@ -192,8 +191,6 @@ async def request_retry(

for attempt in range(attempts):
try:
logger.debug("Attempt {0} to do {1} on {2}", attempt, method.value, url)

async with self.session() as session:
async with session.request(
method.value, url, *args, **kwargs
Expand Down
6 changes: 5 additions & 1 deletion sekoia_automation/http/aio/token_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ async def _schedule_token_refresh(self, expires_in: int) -> None:
Args:
expires_in: int
"""
await self.close()
if self._token_refresh_task:
self._token_refresh_task.cancel()

Check warning on line 122 in sekoia_automation/http/aio/token_refresher.py

View check run for this annotation

Codecov / codecov/patch

sekoia_automation/http/aio/token_refresher.py#L122

Added line #L122 was not covered by tests

async def _refresh() -> None:
await asyncio.sleep(expires_in)
Expand All @@ -133,6 +134,9 @@ async def close(self) -> None:
if self._token_refresh_task:
self._token_refresh_task.cancel()

if self._session:
await self._session.close()

@asynccontextmanager
async def with_access_token(self) -> AsyncGenerator[RefreshedTokenT, None]:
"""
Expand Down
30 changes: 1 addition & 29 deletions tests/aio/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,6 @@ def async_connector(storage, mocked_trigger_logs, faker: Faker):
async_connector.stop()


@pytest.mark.asyncio
async def test_async_connector_rate_limiter(async_connector: DummyAsyncConnector):
"""
Test async connector rate limiter.
Args:
async_connector: DummyAsyncConnector
"""
other_instance = DummyAsyncConnector()
rate_limiter_mock = AsyncLimiter(max_rate=100)

assert async_connector._rate_limiter is None
assert other_instance._rate_limiter is None

assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter()

async_connector.set_rate_limiter(rate_limiter_mock)

assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter()
assert async_connector._rate_limiter == rate_limiter_mock

DummyAsyncConnector.set_rate_limiter(None)
DummyAsyncConnector.set_client_session(None)


@pytest.mark.asyncio
async def test_async_connector_client_session(async_connector: DummyAsyncConnector):
"""
Expand All @@ -89,7 +64,7 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect

async with async_connector.session() as session_1:
async with other_instance.session() as session_2:
assert session_1 == session_2
assert session_1 != session_2

assert async_connector._rate_limiter is not None and isinstance(
async_connector._rate_limiter, AsyncLimiter
Expand All @@ -99,9 +74,6 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect
other_instance._rate_limiter, AsyncLimiter
)

DummyAsyncConnector.set_rate_limiter(None)
other_instance.set_client_session(None)


@pytest.mark.asyncio
async def test_async_connector_push_single_event(
Expand Down
1 change: 0 additions & 1 deletion tests/connectors/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def configure_intake_url(config_storage):


class DummyConnector(Connector):

events: list[list[str]] | None = None

def set_events(self, events: list[list[str]]) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/http/aio/examples/test_bearer_token_auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,5 @@ async def test_get_events_example_method(session_faker: Faker):
mocked_responses.get(request_url, status=402)
mocked_responses.get(request_url, status=200, payload=data)
assert await client.get_events_retry_example({"key": "value"}) == data

await client.close()
2 changes: 2 additions & 0 deletions tests/http/aio/examples/test_oauth_token_auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,5 @@ async def test_get_events_example_method(session_faker: Faker):
)

assert await client.get_events() == data

await client.http_client.close()
22 changes: 21 additions & 1 deletion tests/http/aio/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ async def test_simple_workflow_async_http_client(base_url: str, session_faker: F
assert response.status == 200
assert await response.json() == json.loads(data)

await client.close()


@pytest.mark.asyncio
async def test_rate_limited_workflow_async_http_client(
Expand Down Expand Up @@ -74,6 +76,8 @@ async def test_rate_limited_workflow_async_http_client(
assert response.status == 200
assert await response.json() == json.loads(data)

await client.close()

time_end = time.time()

assert time_end - time_start >= 2
Expand Down Expand Up @@ -112,6 +116,8 @@ async def test_rate_limited_workflow_async_http_client_1(
assert response.status == 200
assert await response.json() == json.loads(data)

await client.close()

time_end = time.time()

assert time_end - time_start >= 2
Expand Down Expand Up @@ -156,6 +162,8 @@ async def test_retry_workflow_get_async_http_client(
# As a result, the last response should be 412
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_retry_workflow_post_async_http_client(
Expand Down Expand Up @@ -201,6 +209,8 @@ async def test_retry_workflow_post_async_http_client(
async with client.post(base_url, json=data) as response:
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_retry_workflow_put_async_http_client(
Expand Down Expand Up @@ -246,6 +256,8 @@ async def test_retry_workflow_put_async_http_client(
async with client.put(base_url, json=data) as response:
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_retry_workflow_head_async_http_client(
Expand Down Expand Up @@ -284,6 +296,8 @@ async def test_retry_workflow_head_async_http_client(
async with client.head(base_url) as response:
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_retry_workflow_delete_async_http_client(
Expand Down Expand Up @@ -322,6 +336,8 @@ async def test_retry_workflow_delete_async_http_client(
async with client.delete(base_url) as response:
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_retry_workflow_patch_async_http_client(
Expand Down Expand Up @@ -360,6 +376,8 @@ async def test_retry_workflow_patch_async_http_client(
async with client.patch(base_url) as response:
assert response.status == status_3

await client.close()


@pytest.mark.asyncio
async def test_complete_configurable_async_http_client(
Expand Down Expand Up @@ -415,4 +433,6 @@ async def test_complete_configurable_async_http_client(
assert response.status == status_1
end_time = time.time()

assert end_time - start_time >= 3
await client.close()

assert end_time - start_time >= 3

0 comments on commit 8d59af2

Please sign in to comment.