From f19d0dc6d9cff366b97b3d8504d32983c4bd4ad3 Mon Sep 17 00:00:00 2001
From: "vladyslav.guriev"
Date: Thu, 3 Oct 2024 20:25:27 +0300
Subject: [PATCH 1/7] Feature: Connectors workflow

---
 CHANGELOG.md                                  |   9 +
 pyproject.toml                                |   2 +-
 sekoia_automation/aio/connector.py            |  83 +++++-
 .../aio/helpers/http/http_client.py           |  91 -------
 .../aio/helpers/http/token_refresher.py       | 150 -----------
 sekoia_automation/connector/__init__.py       | 163 ++++++++++--
 sekoia_automation/http/aio/http_client.py     |  26 +-
 sekoia_automation/http/aio/token_refresher.py |   4 +-
 sekoia_automation/http/http_client.py         |  14 +
 .../helpers/http/test_http_client_session.py  | 200 ---------------
 .../helpers/http/test_http_token_refresher.py | 241 ------------------
 tests/aio/test_connector.py                   |  61 ++++-
 tests/connectors/test_connector.py            |  38 ++-
 .../examples/test_oauth_token_auth_client.py  | 134 ++++++++++
 tests/http/aio/test_http_client.py            |  13 +-
 15 files changed, 500 insertions(+), 729 deletions(-)
 delete mode 100644 sekoia_automation/aio/helpers/http/http_client.py
 delete mode 100644 sekoia_automation/aio/helpers/http/token_refresher.py
 delete mode 100644 tests/aio/helpers/http/test_http_client_session.py
 delete mode 100644 tests/aio/helpers/http/test_http_token_refresher.py
 create mode 100644 tests/http/aio/examples/test_oauth_token_auth_client.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c45d6a..a7874b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.16.0] - 2024-10-02
+
+### Changed
+
+- Improvements for `AsyncConnector`.
+- Improvements for the async HTTP workflow.
+- Removed duplicated parts and made the async HTTP workflow code more uniform.
+
+
 ## [1.15.0] - 2024-09-28
 
 ### Changed
diff --git a/pyproject.toml b/pyproject.toml
index 454dee1..315df8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "sekoia-automation-sdk"
-version = "1.15.0"
+version = "1.16.0"
 description = "SDK to create Sekoia.io playbook modules"
 license = "MIT"
 readme = "README.md"
diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py
index d54052b..e149a8a 100644
--- a/sekoia_automation/aio/connector.py
+++ b/sekoia_automation/aio/connector.py
@@ -1,8 +1,10 @@
 """Contains connector with async version."""
 
+import asyncio
+import time
 from abc import ABC
 from asyncio import AbstractEventLoop, get_event_loop
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Sequence
 from contextlib import asynccontextmanager
 from datetime import datetime
 from pathlib import Path
@@ -12,7 +14,11 @@
 from aiolimiter import AsyncLimiter
 
 from sekoia_automation.aio.helpers import limit_concurrency
-from sekoia_automation.connector import Connector, DefaultConnectorConfiguration
+from sekoia_automation.connector import (
+    Connector,
+    DefaultConnectorConfiguration,
+    EventType,
+)
 from sekoia_automation.module import Module
 
 
@@ -139,7 +145,7 @@ async def _async_send_chunk(
         return events_ids
 
     async def push_data_to_intakes(
-        self, events: list[str]
+        self, events: Sequence[EventType]
     ) -> list[str]:  # pragma: no cover
         """
         Custom method to push events to intakes.
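A quick sketch of what the widened signature accepts (illustrative caller code only; `connector` and `MyEvent` are hypothetical, and serialization follows the `_chunk_events` logic shown further below):

    # Dicts and pydantic models are serialized before sending;
    # plain strings pass through unchanged.
    events = [
        '{"source": "raw-string"}',   # str
        {"source": "dict"},           # dict[str, Any]
        MyEvent(source="model"),      # pydantic BaseModel (hypothetical class)
    ]
    event_ids = await connector.push_data_to_intakes(events)
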
@@ -150,7 +156,6 @@ async def push_data_to_intakes(
         Returns:
             list[str]:
         """
-        self._last_events_time = datetime.utcnow()
         if intake_server := self.configuration.intake_server:
             batch_api = urljoin(intake_server, "batch")
         else:
@@ -169,3 +174,73 @@ async def push_data_to_intakes(
             result_ids.extend(ids)
 
         return result_ids
+
+    async def async_iterate(
+        self,
+    ) -> AsyncGenerator[tuple[list[EventType], datetime | None], None]:
+        """Iterate over events."""
+        yield [], None  # makes this an async generator for type checking
+
+    async def async_next_run(self) -> None:
+        processing_start = time.time()
+
+        result_last_event_date: datetime | None = None
+        total_number_of_events = 0
+        async for data in self.async_iterate():
+            events, last_event_date = data
+            if last_event_date:
+                if (
+                    not result_last_event_date
+                    or last_event_date > result_last_event_date
+                ):
+                    result_last_event_date = last_event_date
+
+            if events:
+                total_number_of_events += len(events)
+                await self.push_data_to_intakes(events)
+
+        processing_end = time.time()
+        processing_time = processing_end - processing_start
+
+        # Metric about processing time
+        self._forward_events_duration.labels(
+            intake_key=self.configuration.intake_key
+        ).observe(processing_time)
+
+        # Metric about processing count
+        self._outcoming_events.labels(intake_key=self.configuration.intake_key).inc(
+            total_number_of_events
+        )
+
+        # Metric about events lag
+        if result_last_event_date:
+            lag = (datetime.utcnow() - result_last_event_date).total_seconds()
+            self._events_lag.labels(intake_key=self.configuration.intake_key).set(lag)
+
+        # Compute the remaining sleeping time.
+        # If greater than 0 and no messages were fetched, pause the connector
+        delta_sleep = (self.frequency or 0) - processing_time
+        if total_number_of_events == 0 and delta_sleep > 0:
+            self.log(message=f"Next batch in the future. Waiting {delta_sleep} seconds")
+
+            await asyncio.sleep(delta_sleep)
+
+    # The loop body is kept in async_next_run to make testing easier
+    async def async_run(self) -> None:  # pragma: no cover
+        """Runs Connector."""
+        while self.running:
+            try:
+                await self.async_next_run()
+            except Exception as e:
+                self.log_exception(
+                    e,
+                    message=f"Error while running connector {self.connector_name}",
+                )
+
+            if self.frequency:
+                await asyncio.sleep(self.frequency)
+
+    def run(self) -> None:  # pragma: no cover
+        """Runs Connector."""
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.async_run())
diff --git a/sekoia_automation/aio/helpers/http/http_client.py b/sekoia_automation/aio/helpers/http/http_client.py
deleted file mode 100644
index 0458403..0000000
--- a/sekoia_automation/aio/helpers/http/http_client.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""HttpClient with ratelimiter."""
-
-from collections.abc import AsyncGenerator
-from contextlib import asynccontextmanager
-
-from aiohttp import ClientSession
-from aiolimiter import AsyncLimiter
-
-
-class HttpClient:
-    """
-    Http client with optional rate limiting.
- - Example: - >>> from sekoia_automation.http.aio.http_client import HttpClient - >>> class CustomHttpClient(HttpClient): - >>> def __init__(self): - >>> super().__init__() - >>> - >>> async def load_data(self, url: str) -> str: - >>> async with self.session() as session: - >>> async with session.get(url) as response: - >>> return await response.text() - >>> - >>> client = CustomHttpClient() - >>> # If rate limiter is set, it will be used - >>> client.set_rate_limit(max_rate=10, time_period=60) - >>> # or - >>> client.set_rate_limiter(AsyncLimiter(max_rate=10, time_period=60)) - >>> - >>> result = await client.load_data("https://example.com") - """ - - _session: ClientSession | None = None - _rate_limiter: AsyncLimiter | None = None - - def __init__( - self, - max_rate: float | None = None, - time_period: float | None = None, - rate_limiter: AsyncLimiter | None = None, - ): - """ - Initialize HttpClient. - - Args: - max_rate: float | None - time_period: float | None - rate_limiter: AsyncLimiter | None - """ - if max_rate and time_period: - self.set_rate_limit(max_rate, time_period) # pragma: no cover - - if rate_limiter: - self.set_rate_limiter(rate_limiter) # pragma: no cover - - def set_rate_limit(self, max_rate: float, time_period: float = 60) -> None: - """ - Set rate limiter. - - Args: - max_rate: float - time_period: float - """ - self._rate_limiter = AsyncLimiter(max_rate=max_rate, time_period=time_period) - - def set_rate_limiter(self, rate_limiter: AsyncLimiter) -> None: # pragma: no cover - """ - Set rate limiter. - - Args: - rate_limiter: - """ - self._rate_limiter = rate_limiter - - @asynccontextmanager - async def session(self) -> AsyncGenerator[ClientSession, None]: - """ - Get configured session with rate limiter. - - Yields: - AsyncGenerator[ClientSession, None]: - """ - if self._session is None: - self._session = ClientSession() - - if self._rate_limiter: - async with self._rate_limiter: - yield self._session - else: - yield self._session diff --git a/sekoia_automation/aio/helpers/http/token_refresher.py b/sekoia_automation/aio/helpers/http/token_refresher.py deleted file mode 100644 index ef3cb30..0000000 --- a/sekoia_automation/aio/helpers/http/token_refresher.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Auth token refresher wrapper with token schema.""" - -import asyncio -import time -from asyncio import Task -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from typing import Generic, TypeVar - -from aiohttp import ClientSession -from pydantic import BaseModel -from pydantic.generics import GenericModel - -HttpTokenT = TypeVar("HttpTokenT", bound=BaseModel) - - -class RefreshedToken(GenericModel, Generic[HttpTokenT]): - """Model to work with auth token with additional info.""" - - token: HttpTokenT - created_at: int - ttl: int - - def is_valid(self) -> bool: - """ - Check if token is not expired yet and valid. - - Returns: - bool: - """ - return not self.is_expired() - - def is_expired(self) -> bool: - """ - Check if token is expired. - - Returns: - bool: - """ - return self.created_at + self.ttl < (time.time() - 1) - - -RefreshedTokenT = TypeVar("RefreshedTokenT", bound=RefreshedToken) - - -class GenericTokenRefresher(Generic[RefreshedTokenT]): - """ - Contains access token refresher logic. 
- - Example of usage: - >>> # Define schema for token response from server - >>> class HttpToken(BaseModel): - >>> access_token: str - >>> signature: str - >>> - >>> # Define TokenRefresher class with necessary logic - >>> class CustomTokenRefresher(GenericTokenRefresher): - >>> def __init__(self, client_id: str, client_secret: str, auth_url: str): - >>> super().__init__() - >>> ... - >>> def get_token(self) -> RefreshedToken[HttpToken]: - >>> ... - >>> - >>> token_refresher = GenericTokenRefresher[RefreshedToken[HttpToken]]() - >>> - >>> async with token_refresher.with_access_token() as access_token: - >>> print(access_token) - """ - - _session: ClientSession | None = None - - def __init__(self): - """Initialize GenericTokenRefresher.""" - - self._token: RefreshedTokenT | None = None - self._token_refresh_task: Task[None] | None = None - - def session(self) -> ClientSession: - """ - Initialize client session. - - Singleton client session to work with token refresh logic. - - Returns: - ClientSession: - """ - if not self._session: - self._session = ClientSession() - - return self._session - - async def get_token(self) -> RefreshedTokenT: - """ - Get new token logic. - - Main method to get new token. - - Returns: - RefreshedTokenT: instance of RefreshedToken - """ - raise NotImplementedError( - "You should implement `get_token` method in child class" - ) - - async def _refresh_token(self) -> None: - """ - Refresh token logic. - - Also triggers token refresh task. - """ - self._token = await self.get_token() - await self._schedule_token_refresh(self._token.ttl) - - async def _schedule_token_refresh(self, expires_in: int) -> None: - """ - Schedule token refresh. - - Args: - expires_in: int - """ - await self.close() - - async def _refresh() -> None: - await asyncio.sleep(expires_in) - await self._refresh_token() - - self._token_refresh_task = asyncio.create_task(_refresh()) - - async def close(self) -> None: - """ - Cancel token refresh task. - """ - if self._token_refresh_task: - self._token_refresh_task.cancel() - - @asynccontextmanager - async def with_access_token(self) -> AsyncGenerator[RefreshedTokenT, None]: - """ - Get access token. 
- - Yields: - RefreshedTokenT: - """ - if self._token is None: - await self._refresh_token() - - if not self._token: - raise ValueError("Token can not be initialized") - - yield self._token diff --git a/sekoia_automation/connector/__init__.py b/sekoia_automation/connector/__init__.py index 392530a..3a3330e 100644 --- a/sekoia_automation/connector/__init__.py +++ b/sekoia_automation/connector/__init__.py @@ -1,9 +1,11 @@ +import time import uuid from abc import ABC from collections.abc import Generator, Sequence from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait as wait_futures -from datetime import datetime, time +from datetime import datetime +from datetime import time as datetime_time from functools import cached_property from os.path import join as urljoin from typing import Any @@ -11,28 +13,22 @@ import orjson import requests import sentry_sdk +from prometheus_client import Counter, Gauge, Histogram from pydantic import BaseModel from requests import Response -from tenacity import ( - Retrying, - stop_after_delay, - wait_exponential, -) +from tenacity import Retrying, stop_after_delay, wait_exponential from sekoia_automation.constants import CHUNK_BYTES_MAX_SIZE, EVENT_BYTES_MAX_SIZE -from sekoia_automation.exceptions import ( - TriggerConfigurationError, -) +from sekoia_automation.exceptions import TriggerConfigurationError from sekoia_automation.trigger import Trigger -from sekoia_automation.utils import ( - get_annotation_for, - get_as_model, -) +from sekoia_automation.utils import get_annotation_for, get_as_model # Connector are a kind of trigger that fetch events from remote sources. # We should add the content of push_events_to_intakes # so that we are able to send events directly from connectors +EventType = dict[str, Any] | str | BaseModel + class DefaultConnectorConfiguration(BaseModel): intake_server: str | None = None @@ -46,6 +42,46 @@ class Connector(Trigger, ABC): # Required for Pydantic to correctly type the configuration object configuration: DefaultConnectorConfiguration + _prometheus_namespace = "symphony_module_common" + + _outcoming_events = Counter( + name="forwarded_events", + documentation="Number of events forwarded to Sekoia.io", + namespace=_prometheus_namespace, + labelnames=["intake_key"], + ) + + _forward_events_duration = Histogram( + name="forward_events_duration", + documentation="Duration to collect and forward events from eventhub", + namespace=_prometheus_namespace, + labelnames=["intake_key"], + ) + + _discarded_events = Counter( + name="discarded_events", + documentation="Number of events discarded from the collect", + namespace=_prometheus_namespace, + labelnames=["intake_key"], + ) + + _events_lag = Gauge( + name="events_lags", + documentation="The delay, in seconds, from the date of the last event", + namespace=_prometheus_namespace, + labelnames=["intake_key"], + ) + + @property + def connector_name(self) -> str: + """ + Get connector name. + + Returns: + str: + """ + return self.__class__.__name__ + @property # type: ignore[override, no-redef] def configuration(self) -> DefaultConnectorConfiguration: if self._configuration is None: @@ -150,8 +186,18 @@ def _send_chunk( self.log(message=message, level="error") self.log_exception(ex, message=message) + @property + def frequency(self) -> int: + """ + Get frequency. 
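+
+        Illustrative override (a sketch; ``MyConnector`` is hypothetical):
+        >>> class MyConnector(Connector):
+        >>>     @property
+        >>>     def frequency(self) -> int:
+        >>>         return 60  # aim for at most one run per minute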
+ + Returns: + int: + """ + return 0 + def push_events_to_intakes( - self, events: list[str], sync: bool = False + self, events: list[EventType], sync: bool = False ) -> list[str]: """ Push events to intakes. @@ -247,7 +293,9 @@ def send_records( remove_directory=True, ) - def _chunk_events(self, events: Sequence) -> Generator[list[Any], None, None]: + def _chunk_events( + self, events: Sequence[EventType] + ) -> Generator[list[Any], None, None]: """ Group events by chunk. @@ -263,20 +311,29 @@ def _chunk_events(self, events: Sequence) -> Generator[list[Any], None, None]: # iter over the events for event in events: - if len(event) > EVENT_BYTES_MAX_SIZE: + result_event = str(event) + + if isinstance(event, BaseModel): + result_event = orjson.dumps(event.dict()).decode("utf-8") + elif isinstance(event, dict): + result_event = orjson.dumps(event).decode("utf-8") + + event_len = len(result_event) + + if event_len > EVENT_BYTES_MAX_SIZE: nb_discarded_events += 1 continue # if the chunk is full - if chunk_bytes + len(event) > CHUNK_BYTES_MAX_SIZE: + if chunk_bytes + event_len > CHUNK_BYTES_MAX_SIZE: # yield the current chunk and create a new one yield chunk chunk = [] chunk_bytes = 0 # add the event to the current chunk - chunk.append(event) - chunk_bytes += len(event) + chunk.append(result_event) + chunk_bytes += event_len # if the last chunk is not empty if len(chunk) > 0: @@ -285,12 +342,16 @@ def _chunk_events(self, events: Sequence) -> Generator[list[Any], None, None]: # if events were discarded, log it if nb_discarded_events > 0: + self._discarded_events.labels(intake_key=self.configuration.intake_key).inc( + nb_discarded_events + ) + self.log( message=f"{nb_discarded_events} too long events " "were discarded (length > 250kb)" ) - def forward_events(self, events) -> None: + def forward_events(self, events: Sequence[EventType]) -> None: try: chunks = self._chunk_events(events) _name = self.name or "" # mypy complains about NoneType in annotation @@ -298,7 +359,67 @@ def forward_events(self, events) -> None: self.log(message=f"Forwarding {len(records)} records", level="info") self.send_records( records=list(records), - event_name=f"{_name.lower().replace(' ', '-')}_{time()!s}", + event_name=f"{_name.lower().replace(' ', '-')}_{datetime_time()!s}", ) except Exception as ex: self.log_exception(ex, message="Failed to forward events") + + def iterate(self) -> Generator[tuple[list[EventType], datetime | None], None]: + """Iterate over events.""" + yield [], None + + def next_run(self) -> None: + processing_start = time.time() + + result_last_event_date: datetime | None = None + total_number_of_events = 0 + for data in self.iterate(): + events, last_event_date = data + if last_event_date: + if ( + not result_last_event_date + or last_event_date > result_last_event_date + ): + result_last_event_date = last_event_date + + if events: + total_number_of_events += len(events) + self.push_events_to_intakes(events) + + processing_end = time.time() + processing_time = processing_end - processing_start + + # Metric about processing time + self._forward_events_duration.labels( + intake_key=self.configuration.intake_key + ).observe(processing_time) + + # Metric about processing count + self._outcoming_events.labels(intake_key=self.configuration.intake_key).inc( + total_number_of_events + ) + + # Metric about events lag + if result_last_event_date: + lag = (datetime.utcnow() - result_last_event_date).total_seconds() + self._events_lag.labels(intake_key=self.configuration.intake_key).set(lag) + + # Compute the 
remaining sleeping time.
+        # If greater than 0 and no messages were fetched, pause the connector
+        delta_sleep = self.frequency - processing_time
+        if total_number_of_events == 0 and delta_sleep > 0:
+            self.log(message=f"Next batch in the future. Waiting {delta_sleep} seconds")
+
+            time.sleep(delta_sleep)
+
+    def run(self) -> None:  # pragma: no cover
+        while self.running:
+            try:
+                self.next_run()
+            except Exception as e:
+                self.log_exception(
+                    e,
+                    message=f"Error while running connector {self.connector_name}",
+                )
+
+            time.sleep(self.frequency)
diff --git a/sekoia_automation/http/aio/http_client.py b/sekoia_automation/http/aio/http_client.py
index a472874..b3cf23c 100644
--- a/sekoia_automation/http/aio/http_client.py
+++ b/sekoia_automation/http/aio/http_client.py
@@ -8,8 +8,9 @@
 from aiohttp import ClientResponse, ClientResponseError, ClientSession
 from aiohttp.web_response import Response
 from aiolimiter import AsyncLimiter
+from loguru import logger
 
-from sekoia_automation.http.http_client import AbstractHttpClient
+from sekoia_automation.http.http_client import AbstractHttpClient, Method
 from sekoia_automation.http.rate_limiter import RateLimiterConfig
 from sekoia_automation.http.retry import RetryPolicy
 
@@ -69,7 +70,7 @@ async def get(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("GET", url, *args, **kwargs) as result:
+        async with self.request_retry(Method.GET, url, *args, **kwargs) as result:
             yield result
 
     @asynccontextmanager
@@ -87,7 +88,7 @@ async def post(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("POST", url, *args, **kwargs) as result:
+        async with self.request_retry(Method.POST, url, *args, **kwargs) as result:
             yield result
 
     @asynccontextmanager
@@ -105,7 +106,7 @@ async def put(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("PUT", url, *args, **kwargs) as response:
+        async with self.request_retry(Method.PUT, url, *args, **kwargs) as response:
             yield response
 
     @asynccontextmanager
@@ -123,7 +124,7 @@ async def delete(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("DELETE", url, *args, **kwargs) as response:
+        async with self.request_retry(Method.DELETE, url, *args, **kwargs) as response:
             yield response
 
     @asynccontextmanager
@@ -141,7 +142,7 @@ async def patch(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("PATCH", url, *args, **kwargs) as response:
+        async with self.request_retry(Method.PATCH, url, *args, **kwargs) as response:
             yield response
 
     @asynccontextmanager
@@ -159,12 +160,17 @@ async def head(
         Returns:
             ClientResponse:
         """
-        async with self.request_retry("HEAD", url, *args, **kwargs) as response:
+        async with self.request_retry(Method.HEAD, url, *args, **kwargs) as response:
             yield response
 
+    async def close(self) -> None:  # pragma: no cover
+        """Close the session if it exists."""
+        if self._session:
+            await self._session.close()
+
     @asynccontextmanager
     async def request_retry(
-        self, method: str, url: str, *args: Any, **kwargs: Any | None
+        self, method: Method, url: str, *args: Any, **kwargs: Any | None
     ) -> AsyncGenerator[ClientResponse, None]:
         """
         Request callable.
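For reference, a sketch of the typed call style these hunks introduce (the `create` parameter names are inferred from the example tests added later in this patch):

    from sekoia_automation.http.aio.http_client import AsyncHttpClient
    from sekoia_automation.http.http_client import Method

    client = AsyncHttpClient.create(
        max_retries=3, backoff_factor=0.1, status_forcelist=[429]
    )

    async def probe(url: str) -> int:
        # Method.GET replaces the bare "GET" string accepted before this change.
        async with client.request_retry(Method.GET, url) as response:
            return response.status
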
@@ -186,9 +192,11 @@ async def request_retry( for attempt in range(attempts): try: + logger.debug("Attempt {0} to do {1} on {2}", attempt, method.value, url) + async with self.session() as session: async with session.request( - method, url, *args, **kwargs + method.value, url, *args, **kwargs ) as response: if ( self._retry_policy is not None diff --git a/sekoia_automation/http/aio/token_refresher.py b/sekoia_automation/http/aio/token_refresher.py index ef3cb30..42efa7e 100644 --- a/sekoia_automation/http/aio/token_refresher.py +++ b/sekoia_automation/http/aio/token_refresher.py @@ -37,7 +37,7 @@ def is_expired(self) -> bool: Returns: bool: """ - return self.created_at + self.ttl < (time.time() - 1) + return (self.created_at + self.ttl) < (int(time.time()) - 1) RefreshedTokenT = TypeVar("RefreshedTokenT", bound=RefreshedToken) @@ -141,7 +141,7 @@ async def with_access_token(self) -> AsyncGenerator[RefreshedTokenT, None]: Yields: RefreshedTokenT: """ - if self._token is None: + if self._token is None or not self._token.is_valid(): await self._refresh_token() if not self._token: diff --git a/sekoia_automation/http/http_client.py b/sekoia_automation/http/http_client.py index 99dd8dc..c26b54d 100644 --- a/sekoia_automation/http/http_client.py +++ b/sekoia_automation/http/http_client.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Generic, TypeVar from sekoia_automation.http.rate_limiter import RateLimiterConfig @@ -6,6 +7,19 @@ TResult = TypeVar("TResult") +class Method(Enum): + """ + Enum for http methods. + """ + + GET = "GET" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + HEAD = "HEAD" + + class AbstractHttpClient(Generic[TResult]): """ Abstract class for http client. diff --git a/tests/aio/helpers/http/test_http_client_session.py b/tests/aio/helpers/http/test_http_client_session.py deleted file mode 100644 index 8436dfc..0000000 --- a/tests/aio/helpers/http/test_http_client_session.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Tests for sekoia_automation.helpers.aio.http.http_client.""" - -import time - -import pytest -from aioresponses import aioresponses -from pydantic import BaseModel - -from sekoia_automation.aio.helpers.http.http_client import HttpClient -from sekoia_automation.aio.helpers.http.token_refresher import ( - GenericTokenRefresher, - RefreshedToken, -) - - -class TokenResponse(BaseModel): - """Test implementation of token response.""" - - access_token: str - - -class TokenRefresher(GenericTokenRefresher): - """Test implementation of GenericTokenRefresher.""" - - async def get_token(self) -> RefreshedToken[TokenResponse]: - """ - Test implementation of get_token. 
- - Returns: - RefreshedToken[TokenResponse]: - """ - async with self.session().post(url=self.auth_url, json={}) as response: - response_data = await response.json() - - return RefreshedToken( - token=TokenResponse(**response_data), - created_at=int(time.time()), - ttl=3600, - ) - - def __init__(self, client_id: str, client_secret: str, auth_url: str): - """Initialize TokenRefresher.""" - super().__init__() - self.client_id = client_id - self.client_secret = client_secret - self.auth_url = auth_url - - -class CustomHttpClient(HttpClient): - """Complete test implementation of HttpClient with TokenRefresher.""" - - def __init__( - self, - client_id: str, - client_secret: str, - auth_url: str, - base_url: str, - ) -> None: - """Initialize CustomHttpClient.""" - super().__init__() - self.base_url = base_url - self.token_refresher = TokenRefresher( - client_id=client_id, - client_secret=client_secret, - auth_url=auth_url, - ) - - async def get_test_data(self, url: str) -> dict[str, str]: - """ - Test method to get data from server with authentication. - - Args: - url: str - - Returns: - dict[str, str]: - """ - async with self.token_refresher.with_access_token() as access_token: - async with self.session() as session: - async with session.get( - url=url, - headers={ - "Authorization": f"Bearer {access_token.token.access_token}" - }, - ) as response: - return await response.json() - - -@pytest.fixture -def auth_url(session_faker): - """ - Fixture to initialize auth_url. - - Returns: - str: - """ - return session_faker.uri() - - -@pytest.fixture -def base_url(session_faker): - """ - Fixture to initialize base_url. - - Returns: - str: - """ - return session_faker.uri() - - -@pytest.fixture -def http_client(session_faker, auth_url, base_url): - """ - Fixture to initialize HttpClient. - - Returns: - CustomHttpClient: - """ - return CustomHttpClient( - client_id=session_faker.word(), - client_secret=session_faker.word(), - auth_url=auth_url, - base_url=base_url, - ) - - -@pytest.mark.asyncio -async def test_http_client_get_data(session_faker, http_client, base_url, auth_url): - """ - Test http_client get data. - - Args: - session_faker: Faker - http_client: CustomHttpClient - base_url: str - auth_url: str - """ - token_response = TokenResponse(access_token=session_faker.word()) - - get_test_data_response = { - session_faker.word(): session_faker.word(), - session_faker.word(): session_faker.word(), - } - - with aioresponses() as mocked_responses: - mocked_responses.post(auth_url, payload=token_response.dict()) - mocked_responses.get(f"{base_url}/test", payload=get_test_data_response) - - test_data = await http_client.get_test_data(url=f"{base_url}/test") - - assert test_data == get_test_data_response - - await http_client.token_refresher.close() - await http_client.token_refresher._session.close() - await http_client._session.close() - - -@pytest.mark.asyncio -async def test_http_client_get_data_async_limiter( - session_faker, - http_client, - base_url, - auth_url, -): - """ - Test http_client get data with async_limiter. 
- - Args: - session_faker: Faker - http_client: CustomHttpClient - base_url: str - auth_url: str - """ - token_response = TokenResponse(access_token=session_faker.word()) - - # 1 request per 3 seconds - http_client.set_rate_limit(1, 3) - - get_test_data_response = { - session_faker.word(): session_faker.word(), - session_faker.word(): session_faker.word(), - } - - with aioresponses() as mocked_responses: - start_query_time = time.time() - mocked_responses.post(auth_url, payload=token_response.dict()) - mocked_responses.get(f"{base_url}/test", payload=get_test_data_response) - await http_client.get_test_data(url=f"{base_url}/test") - - mocked_responses.post(auth_url, payload=token_response.dict()) - mocked_responses.get(f"{base_url}/test", payload=get_test_data_response) - await http_client.get_test_data(url=f"{base_url}/test") - - end_query_time = time.time() - - assert int(end_query_time - start_query_time) == 3 - - await http_client.token_refresher.close() - await http_client.token_refresher._session.close() - await http_client._session.close() diff --git a/tests/aio/helpers/http/test_http_token_refresher.py b/tests/aio/helpers/http/test_http_token_refresher.py deleted file mode 100644 index fbb3fd2..0000000 --- a/tests/aio/helpers/http/test_http_token_refresher.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Tests for `sekoia_automation.aio.helpers.http.token_refresher`.""" - -import asyncio -import time -from asyncio import Lock -from typing import ClassVar - -import pytest -from aioresponses import aioresponses -from pydantic import BaseModel - -from sekoia_automation.aio.helpers.http.token_refresher import ( - GenericTokenRefresher, - RefreshedToken, -) - - -class CustomTokenResponse(BaseModel): - """Test implementation of token response.""" - - access_token: str - signature: str - random_field: str - - -class CustomTokenRefresher(GenericTokenRefresher): - """ - Test implementation of GenericTokenRefresher. - - Contains some additional feature like default ttl and singleton impl. - """ - - _instances: ClassVar[dict[str, "CustomTokenRefresher"]] = {} - _locks: ClassVar[dict[str, Lock]] = {} - - def __init__( - self, - client_id: str, - client_secret: str, - auth_url: str, - ttl: int = 60, - ) -> None: - """ - Initialize CustomTokenRefresher. - - Args: - client_id: str - client_secret: str - auth_url: str - ttl: int - """ - super().__init__() - self.client_id = client_id - self.client_secret = client_secret - self.auth_url = auth_url - self.ttl = ttl - - @classmethod - async def get_instance( - cls, - client_id: str, - client_secret: str, - auth_url: str, - ) -> "CustomTokenRefresher": - """ - Get instance of CustomTokenRefresher. - - Totally safe to use in async environment. Use lock to prevent multiple - instances creation. Get instance from cls._instances if it already exists - based on client_id, client_secret and auth_url. 
- - Args: - client_id: str - client_secret: str - auth_url: str - - Returns: - CustomTokenRefresher: - """ - refresher_unique_key = str((client_id, client_secret, auth_url)) - if not cls._locks.get(refresher_unique_key): - cls._locks[refresher_unique_key] = asyncio.Lock() - - if not cls._instances.get(refresher_unique_key): - async with cls._locks[refresher_unique_key]: - if not cls._instances.get(refresher_unique_key): - cls._instances[refresher_unique_key] = CustomTokenRefresher( - client_id, - client_secret, - auth_url, - ) - - return cls._instances[refresher_unique_key] - - async def get_token(self) -> RefreshedToken[CustomTokenResponse]: - """ - Get token from server test implementation. - - Returns: - RefreshedToken[CustomTokenResponse]: - """ - - async with self.session().post(url=self.auth_url, json={}) as response: - response_data = await response.json() - ttl = self.ttl - if (response_data.get("expires_in") or 0) > 0: - ttl = int(response_data.get("expires_in")) - - return RefreshedToken( - token=CustomTokenResponse(**response_data), - created_at=int(time.time()), - ttl=ttl, - ) - - -@pytest.mark.asyncio -async def test_token_refresher_1(session_faker): - """ - Test token refresher with ttl from class. - - Args: - session_faker: Faker - """ - token_response = CustomTokenResponse( - access_token=session_faker.word(), - signature=session_faker.word(), - random_field=session_faker.word(), - ) - - client_id = session_faker.word() - client_secret = session_faker.word() - auth_url = session_faker.uri() - - token_refresher = await CustomTokenRefresher.get_instance( - client_id, - client_secret, - auth_url, - ) - - assert token_refresher == await CustomTokenRefresher.get_instance( - client_id, - client_secret, - auth_url, - ) - - with aioresponses() as mocked_responses: - mocked_responses.post(auth_url, payload=token_response.dict()) - await token_refresher._refresh_token() - - assert token_refresher._token is not None - assert token_refresher._token.token.access_token == token_response.access_token - assert token_refresher._token.token.signature == token_response.signature - assert token_refresher._token.token.random_field == token_response.random_field - assert token_refresher._token.ttl == 60 - - await token_refresher._session.close() - await token_refresher.close() - - -@pytest.mark.asyncio -async def test_token_refresher_2(session_faker): - """ - Test token refresher with ttl from server response. 
- - Args: - session_faker: Faker - """ - token_response = { - "access_token": session_faker.word(), - "signature": session_faker.word(), - "random_field": session_faker.word(), - "expires_in": session_faker.pyint(), - } - - client_id = session_faker.word() - client_secret = session_faker.word() - auth_url = session_faker.uri() - - with aioresponses() as mocked_responses: - token_refresher = CustomTokenRefresher( - client_id, - client_secret, - auth_url, - ) - - mocked_responses.post(auth_url, payload=token_response) - await token_refresher._refresh_token() - - assert token_refresher._token is not None - assert token_refresher._token.token.access_token == token_response.get( - "access_token" - ) - assert token_refresher._token.token.signature == token_response.get("signature") - assert token_refresher._token.token.random_field == token_response.get( - "random_field" - ) - assert token_refresher._token.ttl == token_response.get("expires_in") - - await token_refresher._session.close() - await token_refresher.close() - - -@pytest.mark.asyncio -async def test_token_refresher_with_token(session_faker): - """ - Test token refresher with ttl from server response. - - Args: - session_faker: Faker - """ - token_response = { - "access_token": session_faker.word(), - "signature": session_faker.word(), - "random_field": session_faker.word(), - "expires_in": session_faker.pyint(), - } - - client_id = session_faker.word() - client_secret = session_faker.word() - auth_url = session_faker.uri() - - with aioresponses() as mocked_responses: - token_refresher = CustomTokenRefresher( - client_id, - client_secret, - auth_url, - ) - - mocked_responses.post(auth_url, payload=token_response) - async with token_refresher.with_access_token() as generated_token: - assert generated_token.token.access_token == token_response.get( - "access_token" - ) - assert generated_token.token.signature == token_response.get("signature") - assert generated_token.token.random_field == token_response.get( - "random_field" - ) - assert generated_token.ttl == token_response.get("expires_in") - - await token_refresher._session.close() - await token_refresher.close() diff --git a/tests/aio/test_connector.py b/tests/aio/test_connector.py index 14c9e2d..489ab9f 100644 --- a/tests/aio/test_connector.py +++ b/tests/aio/test_connector.py @@ -1,5 +1,7 @@ """Test async connector.""" +from collections.abc import AsyncGenerator +from datetime import datetime from unittest.mock import Mock, patch from urllib.parse import urljoin @@ -15,8 +17,17 @@ class DummyAsyncConnector(AsyncConnector): trigger_activation: str | None = None - def run(self): - raise NotImplementedError + events: list[list[str]] | None = None + + def set_events(self, events: list[list[str]]) -> None: + self.events = events + + async def iterate(self) -> AsyncGenerator[tuple[list[str], datetime | None], None]: + if self.events is None: + raise RuntimeError("Events are not set") + + for event in self.events: + yield event, None @pytest.fixture @@ -221,3 +232,49 @@ async def test_async_connector_raise_error( except Exception as e: assert isinstance(e, RuntimeError) assert str(e) == expected_error + + +@pytest.mark.asyncio +async def test_async_connector_async_next_run( + async_connector: DummyAsyncConnector, faker: Faker +): + """ + Test async connector push events. 
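+
+    Expected flow, per the mocks below: the connector yields three chunks
+    (2, 3 and 4 events), and async_next_run pushes each chunk to the batch
+    endpoint (async_next_run reads from async_iterate).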
+ + Args: + async_connector: DummyAsyncConnector + faker: Faker + """ + single_event_id = faker.uuid4() + + # We expect 3 chunks of events + test_events = [ + [faker.uuid4(), faker.uuid4()], + [faker.uuid4(), faker.uuid4(), faker.uuid4()], + [faker.uuid4(), faker.uuid4(), faker.uuid4(), faker.uuid4()], + ] + + async_connector.set_events(test_events) + + request_url = urljoin(async_connector.configuration.intake_server, "batch") + + with aioresponses() as mocked_responses: + mocked_responses.post( + request_url, + status=200, + payload={"received": True, "event_ids": [single_event_id]}, + ) + + mocked_responses.post( + request_url, + status=200, + payload={"received": True, "event_ids": [single_event_id]}, + ) + + mocked_responses.post( + request_url, + status=200, + payload={"received": True, "event_ids": [single_event_id]}, + ) + + await async_connector.async_next_run() diff --git a/tests/connectors/test_connector.py b/tests/connectors/test_connector.py index b7fd5cb..93c9b16 100644 --- a/tests/connectors/test_connector.py +++ b/tests/connectors/test_connector.py @@ -1,3 +1,5 @@ +from collections.abc import Generator +from datetime import datetime from unittest.mock import Mock, PropertyMock, patch import pytest @@ -20,8 +22,18 @@ def configure_intake_url(config_storage): class DummyConnector(Connector): - def run(self): - raise NotImplementedError + + events: list[list[str]] | None = None + + def set_events(self, events: list[list[str]]) -> None: + self.events = events + + def iterate(self) -> Generator[tuple[list[str], datetime | None], None, None]: + if self.events is None: + raise RuntimeError("Events are not set") + + for data in self.events: + yield data, None @pytest.fixture @@ -280,3 +292,25 @@ def test_connector_configuration_file_not_found(test_connector): Trigger, "configuration", new_callable=PropertyMock, return_value=config ): assert test_connector.configuration == config + + +def test_connector_next_run(faker, test_connector, requests_mock): + requests_mock.post( + "https://intake.sekoia.io/batch", + [ + {"json": {"event_ids": ["001", "002"]}}, + ], + ) + + events = [ + [faker.word(), faker.word()], + [faker.word(), faker.word(), faker.word()], + [faker.word(), faker.word(), faker.word(), faker.word()], + ] + + test_connector.set_events(events) + + test_connector.next_run() + + # because we expect 3 iterations + assert len(requests_mock.request_history) == 3 diff --git a/tests/http/aio/examples/test_oauth_token_auth_client.py b/tests/http/aio/examples/test_oauth_token_auth_client.py new file mode 100644 index 0000000..508c6ca --- /dev/null +++ b/tests/http/aio/examples/test_oauth_token_auth_client.py @@ -0,0 +1,134 @@ +"""Example implementation with tests for AsyncOauthTokenAuthClient.""" + +import json +import time +from typing import Any + +import pytest +from aioresponses import aioresponses +from faker import Faker +from pydantic.main import BaseModel + +from sekoia_automation.http.aio.http_client import AsyncHttpClient +from sekoia_automation.http.aio.token_refresher import ( + GenericTokenRefresher, + RefreshedToken, +) + + +class DummyOAuthResponse(BaseModel): + access_token: str + + +class SampleTokenRefresher(GenericTokenRefresher): + def __init__(self, auth_url: str, client_id: str, client_secret: str) -> None: + super().__init__() + + self.auth_url = auth_url + self.client_id = client_id + self.client_secret = client_secret + + async def get_token(self) -> RefreshedToken[DummyOAuthResponse]: + data = { + "client_id": self.client_id, + "client_secret": 
self.client_secret, + } + + async with self.session().post(self.auth_url, json=data) as response: + result = await response.json() + + return RefreshedToken( + token=DummyOAuthResponse(**result), + created_at=int(time.time()), + ttl=3600, + ) + + +class AsyncOauthClientExample: + def __init__( + self, + base_url: str, + token_refresher: SampleTokenRefresher, + http_client: AsyncHttpClient, + ) -> None: + self.base_url = base_url + self.http_client = http_client + self.token_refresher = token_refresher + + @classmethod + async def instance( + cls, + client_id: str, + client_secret: str, + oauth_url: str, + base_url: str, + max_retries: int | None = None, + backoff_factor: float | None = None, + status_forcelist: list[int] | None = None, + max_rate: float | None = None, + time_period: float | None = None, + ) -> "AsyncOauthClientExample": + token_refresher = SampleTokenRefresher(oauth_url, client_id, client_secret) + http_client = AsyncHttpClient.create( + max_retries, backoff_factor, status_forcelist, max_rate, time_period + ) + + return cls(base_url, token_refresher, http_client) + + async def get_events(self) -> list[dict[str, Any]]: + async with self.token_refresher.with_access_token() as token: + headers = { + "Authorization": f"Bearer {token.token.access_token}", + } + + async with self.http_client.get( + f"{self.base_url}/test/events", headers=headers + ) as response: + return await response.json() + + +@pytest.mark.asyncio +async def test_get_events_example_method(session_faker: Faker): + """ + Test get_events_example_base_method. + + Args: + session_faker: Faker + """ + base_url = str(session_faker.uri()) + auth_url = str(base_url + "/auth") + client_id = session_faker.word() + client_secret = session_faker.word() + test_token = session_faker.word() + + client = await AsyncOauthClientExample.instance( + client_id=client_id, + client_secret=client_secret, + oauth_url=auth_url, + base_url=base_url, + max_retries=3, + backoff_factor=0.1, + status_forcelist=[400], + ) + + data = json.loads( + session_faker.json( + data_columns={"test": ["name", "name", "name"]}, + num_rows=10, + ) + ) + + with aioresponses() as mocked_responses: + request_url = f"{base_url}/test/events" + + auth_headers = { + "Authorization": f"Bearer {test_token}", + } + + mocked_responses.post(auth_url, payload={"access_token": test_token}) + mocked_responses.get(request_url, headers=auth_headers, status=400) + mocked_responses.get( + request_url, headers=auth_headers, payload=data, status=200 + ) + + assert await client.get_events() == data diff --git a/tests/http/aio/test_http_client.py b/tests/http/aio/test_http_client.py index dcbc46d..023499a 100644 --- a/tests/http/aio/test_http_client.py +++ b/tests/http/aio/test_http_client.py @@ -8,6 +8,7 @@ from faker import Faker from sekoia_automation.http.aio.http_client import AsyncHttpClient +from sekoia_automation.http.http_client import Method @pytest.mark.asyncio @@ -143,7 +144,7 @@ async def test_retry_workflow_get_async_http_client( mocked_responses.get(url=base_url, status=status_2) mocked_responses.get(url=base_url, status=status_3) - async with client.request_retry("GET", base_url) as response: + async with client.request_retry(Method.GET, base_url) as response: # As a result, the last response should be 412 assert response.status == status_3 @@ -190,7 +191,7 @@ async def test_retry_workflow_post_async_http_client( mocked_responses.post(url=base_url, payload=data, status=status_2) mocked_responses.post(url=base_url, payload=data, status=status_3) - async with 
client.request_retry("POST", base_url, json=data) as response: + async with client.request_retry(Method.POST, base_url, json=data) as response: assert response.status == status_3 mocked_responses.post(url=base_url, payload=data, status=status_1) @@ -235,7 +236,7 @@ async def test_retry_workflow_put_async_http_client( mocked_responses.put(url=base_url, payload=data, status=status_2) mocked_responses.put(url=base_url, payload=data, status=status_3) - async with client.request_retry("PUT", base_url, json=data) as response: + async with client.request_retry(Method.PUT, base_url, json=data) as response: assert response.status == status_3 mocked_responses.put(url=base_url, payload=data, status=status_1) @@ -273,7 +274,7 @@ async def test_retry_workflow_head_async_http_client( mocked_responses.head(url=base_url, status=status_2) mocked_responses.head(url=base_url, status=status_3) - async with client.request_retry("HEAD", base_url) as response: + async with client.request_retry(Method.HEAD, base_url) as response: assert response.status == status_3 mocked_responses.head(url=base_url, status=status_1) @@ -311,7 +312,7 @@ async def test_retry_workflow_delete_async_http_client( mocked_responses.delete(url=base_url, status=status_2) mocked_responses.delete(url=base_url, status=status_3) - async with client.request_retry("delete", base_url) as response: + async with client.request_retry(Method.DELETE, base_url) as response: assert response.status == status_3 mocked_responses.delete(url=base_url, status=status_1) @@ -349,7 +350,7 @@ async def test_retry_workflow_patch_async_http_client( mocked_responses.patch(url=base_url, status=status_2) mocked_responses.patch(url=base_url, status=status_3) - async with client.request_retry("PATCH", base_url) as response: + async with client.request_retry(Method.PATCH, base_url) as response: assert response.status == status_3 mocked_responses.patch(url=base_url, status=status_1) From 03f1bb77a2ab5dce5022f93c61664895c2482bc5 Mon Sep 17 00:00:00 2001 From: "vladyslav.guriev" Date: Fri, 4 Oct 2024 09:57:46 +0300 Subject: [PATCH 2/7] fix linting --- sekoia_automation/connector/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sekoia_automation/connector/__init__.py b/sekoia_automation/connector/__init__.py index 3a3330e..b69d53f 100644 --- a/sekoia_automation/connector/__init__.py +++ b/sekoia_automation/connector/__init__.py @@ -8,7 +8,7 @@ from datetime import time as datetime_time from functools import cached_property from os.path import join as urljoin -from typing import Any +from typing import Any, TypeAlias import orjson import requests @@ -27,7 +27,7 @@ # We should add the content of push_events_to_intakes # so that we are able to send events directly from connectors -EventType = dict[str, Any] | str | BaseModel +EventType: TypeAlias = dict[str, Any] | str | BaseModel class DefaultConnectorConfiguration(BaseModel): From 00eb91c09070a43aa9da8318aee20a90f9e325a4 Mon Sep 17 00:00:00 2001 From: "vladyslav.guriev" Date: Fri, 4 Oct 2024 10:23:51 +0300 Subject: [PATCH 3/7] fix linting --- sekoia_automation/aio/connector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py index e149a8a..110b20e 100644 --- a/sekoia_automation/aio/connector.py +++ b/sekoia_automation/aio/connector.py @@ -103,6 +103,11 @@ async def session(cls) -> AsyncGenerator[ClientSession, None]: # pragma: no cov async with cls.get_rate_limiter(): yield cls._session + async def async_close(self) 
-> None: + """Close session.""" + if self._session: + await self._session.close() + async def _async_send_chunk( self, session: ClientSession, url: str, chunk_index: int, chunk: list[str] ) -> list[str]: @@ -244,3 +249,4 @@ def run(self) -> None: # pragma: no cover """Runs Connector.""" loop = asyncio.get_event_loop() loop.run_until_complete(self.async_run()) + loop.run_until_complete(self.async_close()) From 8d59af27ce68e5036fa136ef0c31972e2bb96701 Mon Sep 17 00:00:00 2001 From: "vladyslav.guriev" Date: Fri, 4 Oct 2024 12:25:33 +0300 Subject: [PATCH 4/7] Fix tests --- sekoia_automation/aio/connector.py | 62 ++++++------------- sekoia_automation/http/aio/http_client.py | 3 - sekoia_automation/http/aio/token_refresher.py | 6 +- tests/aio/test_connector.py | 30 +-------- tests/connectors/test_connector.py | 1 - .../examples/test_bearer_token_auth_client.py | 2 + .../examples/test_oauth_token_auth_client.py | 2 + tests/http/aio/test_http_client.py | 22 ++++++- 8 files changed, 50 insertions(+), 78 deletions(-) diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py index 110b20e..e8fc3ec 100644 --- a/sekoia_automation/aio/connector.py +++ b/sekoia_automation/aio/connector.py @@ -3,7 +3,6 @@ import asyncio import time from abc import ABC -from asyncio import AbstractEventLoop, get_event_loop from collections.abc import AsyncGenerator, Sequence from contextlib import asynccontextmanager from datetime import datetime @@ -27,8 +26,6 @@ class AsyncConnector(Connector, ABC): configuration: DefaultConnectorConfiguration - _event_loop: AbstractEventLoop - _session: ClientSession | None = None _rate_limiter: AsyncLimiter | None = None @@ -36,7 +33,6 @@ def __init__( self, module: Module | None = None, data_path: Path | None = None, - event_loop: AbstractEventLoop | None = None, *args, **kwargs, ): @@ -53,60 +49,31 @@ def __init__( self.max_concurrency_tasks = kwargs.pop("max_concurrency_tasks", 1000) super().__init__(module=module, data_path=data_path, *args, **kwargs) - self._event_loop = event_loop or get_event_loop() - - @classmethod - def set_client_session(cls, session: ClientSession) -> None: - """ - Set client session. - - Args: - session: ClientSession - """ - cls._session = session - - @classmethod - def set_rate_limiter(cls, rate_limiter: AsyncLimiter) -> None: - """ - Set rate limiter. - - Args: - rate_limiter: - """ - cls._rate_limiter = rate_limiter - - @classmethod - def get_rate_limiter(cls) -> AsyncLimiter: + def get_rate_limiter(self) -> AsyncLimiter: """ Get or initialize rate limiter. Returns: AsyncLimiter: """ - if cls._rate_limiter is None: - cls._rate_limiter = AsyncLimiter(1, 1) + if self._rate_limiter is None: + self._rate_limiter = AsyncLimiter(1, 1) - return cls._rate_limiter + return self._rate_limiter - @classmethod @asynccontextmanager - async def session(cls) -> AsyncGenerator[ClientSession, None]: # pragma: no cover + async def session(self) -> AsyncGenerator[ClientSession, None]: # pragma: no cover """ Get or initialize client session if it is not initialized yet. 
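+
+        Illustrative usage (a sketch):
+        >>> async with self.session() as session:
+        >>>     async with session.get(url) as response:
+        >>>         payload = await response.json()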
Returns: ClientSession: """ - if cls._session is None: - cls._session = ClientSession() + if self._session is None: + self._session = ClientSession() - async with cls.get_rate_limiter(): - yield cls._session - - async def async_close(self) -> None: - """Close session.""" - if self._session: - await self._session.close() + async with self.get_rate_limiter(): + yield self._session async def _async_send_chunk( self, session: ClientSession, url: str, chunk_index: int, chunk: list[str] @@ -245,8 +212,17 @@ async def async_run(self) -> None: # pragma: no cover if self.frequency: await asyncio.sleep(self.frequency) + def stop(self, *args, **kwargs): + """ + Stop the connector + """ + super().stop(*args, **kwargs) + loop = asyncio.get_event_loop() + + if self._session: + loop.run_until_complete(self._session.close()) + def run(self) -> None: # pragma: no cover """Runs Connector.""" loop = asyncio.get_event_loop() loop.run_until_complete(self.async_run()) - loop.run_until_complete(self.async_close()) diff --git a/sekoia_automation/http/aio/http_client.py b/sekoia_automation/http/aio/http_client.py index b3cf23c..d71b248 100644 --- a/sekoia_automation/http/aio/http_client.py +++ b/sekoia_automation/http/aio/http_client.py @@ -8,7 +8,6 @@ from aiohttp import ClientResponse, ClientResponseError, ClientSession from aiohttp.web_response import Response from aiolimiter import AsyncLimiter -from loguru import logger from sekoia_automation.http.http_client import AbstractHttpClient, Method from sekoia_automation.http.rate_limiter import RateLimiterConfig @@ -192,8 +191,6 @@ async def request_retry( for attempt in range(attempts): try: - logger.debug("Attempt {0} to do {1} on {2}", attempt, method.value, url) - async with self.session() as session: async with session.request( method.value, url, *args, **kwargs diff --git a/sekoia_automation/http/aio/token_refresher.py b/sekoia_automation/http/aio/token_refresher.py index 42efa7e..74d935b 100644 --- a/sekoia_automation/http/aio/token_refresher.py +++ b/sekoia_automation/http/aio/token_refresher.py @@ -118,7 +118,8 @@ async def _schedule_token_refresh(self, expires_in: int) -> None: Args: expires_in: int """ - await self.close() + if self._token_refresh_task: + self._token_refresh_task.cancel() async def _refresh() -> None: await asyncio.sleep(expires_in) @@ -133,6 +134,9 @@ async def close(self) -> None: if self._token_refresh_task: self._token_refresh_task.cancel() + if self._session: + await self._session.close() + @asynccontextmanager async def with_access_token(self) -> AsyncGenerator[RefreshedTokenT, None]: """ diff --git a/tests/aio/test_connector.py b/tests/aio/test_connector.py index 489ab9f..b3526ae 100644 --- a/tests/aio/test_connector.py +++ b/tests/aio/test_connector.py @@ -49,31 +49,6 @@ def async_connector(storage, mocked_trigger_logs, faker: Faker): async_connector.stop() -@pytest.mark.asyncio -async def test_async_connector_rate_limiter(async_connector: DummyAsyncConnector): - """ - Test async connector rate limiter. 
- - Args: - async_connector: DummyAsyncConnector - """ - other_instance = DummyAsyncConnector() - rate_limiter_mock = AsyncLimiter(max_rate=100) - - assert async_connector._rate_limiter is None - assert other_instance._rate_limiter is None - - assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter() - - async_connector.set_rate_limiter(rate_limiter_mock) - - assert async_connector.get_rate_limiter() == other_instance.get_rate_limiter() - assert async_connector._rate_limiter == rate_limiter_mock - - DummyAsyncConnector.set_rate_limiter(None) - DummyAsyncConnector.set_client_session(None) - - @pytest.mark.asyncio async def test_async_connector_client_session(async_connector: DummyAsyncConnector): """ @@ -89,7 +64,7 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect async with async_connector.session() as session_1: async with other_instance.session() as session_2: - assert session_1 == session_2 + assert session_1 != session_2 assert async_connector._rate_limiter is not None and isinstance( async_connector._rate_limiter, AsyncLimiter @@ -99,9 +74,6 @@ async def test_async_connector_client_session(async_connector: DummyAsyncConnect other_instance._rate_limiter, AsyncLimiter ) - DummyAsyncConnector.set_rate_limiter(None) - other_instance.set_client_session(None) - @pytest.mark.asyncio async def test_async_connector_push_single_event( diff --git a/tests/connectors/test_connector.py b/tests/connectors/test_connector.py index 93c9b16..abd01f4 100644 --- a/tests/connectors/test_connector.py +++ b/tests/connectors/test_connector.py @@ -22,7 +22,6 @@ def configure_intake_url(config_storage): class DummyConnector(Connector): - events: list[list[str]] | None = None def set_events(self, events: list[list[str]]) -> None: diff --git a/tests/http/aio/examples/test_bearer_token_auth_client.py b/tests/http/aio/examples/test_bearer_token_auth_client.py index 75148d6..161fdb1 100644 --- a/tests/http/aio/examples/test_bearer_token_auth_client.py +++ b/tests/http/aio/examples/test_bearer_token_auth_client.py @@ -123,3 +123,5 @@ async def test_get_events_example_method(session_faker: Faker): mocked_responses.get(request_url, status=402) mocked_responses.get(request_url, status=200, payload=data) assert await client.get_events_retry_example({"key": "value"}) == data + + await client.close() diff --git a/tests/http/aio/examples/test_oauth_token_auth_client.py b/tests/http/aio/examples/test_oauth_token_auth_client.py index 508c6ca..5d6f541 100644 --- a/tests/http/aio/examples/test_oauth_token_auth_client.py +++ b/tests/http/aio/examples/test_oauth_token_auth_client.py @@ -132,3 +132,5 @@ async def test_get_events_example_method(session_faker: Faker): ) assert await client.get_events() == data + + await client.http_client.close() diff --git a/tests/http/aio/test_http_client.py b/tests/http/aio/test_http_client.py index 023499a..5a535d1 100644 --- a/tests/http/aio/test_http_client.py +++ b/tests/http/aio/test_http_client.py @@ -40,6 +40,8 @@ async def test_simple_workflow_async_http_client(base_url: str, session_faker: F assert response.status == 200 assert await response.json() == json.loads(data) + await client.close() + @pytest.mark.asyncio async def test_rate_limited_workflow_async_http_client( @@ -74,6 +76,8 @@ async def test_rate_limited_workflow_async_http_client( assert response.status == 200 assert await response.json() == json.loads(data) + await client.close() + time_end = time.time() assert time_end - time_start >= 2 @@ -112,6 +116,8 @@ async def 
test_rate_limited_workflow_async_http_client_1( assert response.status == 200 assert await response.json() == json.loads(data) + await client.close() + time_end = time.time() assert time_end - time_start >= 2 @@ -156,6 +162,8 @@ async def test_retry_workflow_get_async_http_client( # As a result, the last response should be 412 assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_retry_workflow_post_async_http_client( @@ -201,6 +209,8 @@ async def test_retry_workflow_post_async_http_client( async with client.post(base_url, json=data) as response: assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_retry_workflow_put_async_http_client( @@ -246,6 +256,8 @@ async def test_retry_workflow_put_async_http_client( async with client.put(base_url, json=data) as response: assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_retry_workflow_head_async_http_client( @@ -284,6 +296,8 @@ async def test_retry_workflow_head_async_http_client( async with client.head(base_url) as response: assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_retry_workflow_delete_async_http_client( @@ -322,6 +336,8 @@ async def test_retry_workflow_delete_async_http_client( async with client.delete(base_url) as response: assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_retry_workflow_patch_async_http_client( @@ -360,6 +376,8 @@ async def test_retry_workflow_patch_async_http_client( async with client.patch(base_url) as response: assert response.status == status_3 + await client.close() + @pytest.mark.asyncio async def test_complete_configurable_async_http_client( @@ -415,4 +433,6 @@ async def test_complete_configurable_async_http_client( assert response.status == status_1 end_time = time.time() - assert end_time - start_time >= 3 + await client.close() + + assert end_time - start_time >= 3 From 8c78ab970202d45b315909d05136439e1e2a71c8 Mon Sep 17 00:00:00 2001 From: "vladyslav.guriev" Date: Wed, 16 Oct 2024 19:52:34 +0300 Subject: [PATCH 5/7] Fix comments --- sekoia_automation/aio/connector.py | 14 +++------- sekoia_automation/connector/__init__.py | 35 +++---------------------- sekoia_automation/connector/metrics.py | 33 +++++++++++++++++++++++ 3 files changed, 39 insertions(+), 43 deletions(-) create mode 100644 sekoia_automation/connector/metrics.py diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py index f42be37..c9c462c 100644 --- a/sekoia_automation/aio/connector.py +++ b/sekoia_automation/aio/connector.py @@ -212,18 +212,10 @@ async def async_run(self) -> None: # pragma: no cover if self.frequency: await asyncio.sleep(self.frequency) - def stop(self, *args, **kwargs): - """ - Stop the connector - """ - loop = asyncio.get_event_loop() - - if self._session: - loop.run_until_complete(self._session.close()) - - super().stop(*args, **kwargs) - def run(self) -> None: # pragma: no cover """Runs Connector.""" loop = asyncio.get_event_loop() loop.run_until_complete(self.async_run()) + + if self._session: + loop.run_until_complete(self._session.close()) diff --git a/sekoia_automation/connector/__init__.py b/sekoia_automation/connector/__init__.py index b69d53f..379524d 100644 --- a/sekoia_automation/connector/__init__.py +++ b/sekoia_automation/connector/__init__.py @@ -13,11 +13,11 @@ import orjson import requests import sentry_sdk -from prometheus_client import Counter, Gauge, 
From f4a6a2fa5f4e65729b888b74caacd52f072f0b87 Mon Sep 17 00:00:00 2001
From: vg-svitla
Date: Thu, 28 Nov 2024 10:27:41 +0200
Subject: [PATCH 6/7] Address review comments

---
 sekoia_automation/aio/connector.py      |  34 +++--
 sekoia_automation/connector/__init__.py |  17 +--
 sekoia_automation/connector/metrics.py  | 161 ++++++++++++++++++++----
 tests/aio/test_connector.py             |  21 ++--
 4 files changed, 179 insertions(+), 54 deletions(-)
diff --git a/sekoia_automation/aio/connector.py b/sekoia_automation/aio/connector.py
index 52a12b0..f400aae 100644
--- a/sekoia_automation/aio/connector.py
+++ b/sekoia_automation/aio/connector.py
@@ -155,14 +155,14 @@ async def async_iterate(
         self,
     ) -> AsyncGenerator[tuple[list[EventType], datetime | None], None]:
         """Iterate over events."""
-        yield [], None  # To avoid type checking error
+        raise NotImplementedError  # Concrete connectors must implement this generator
 
     async def async_next_run(self) -> None:
         processing_start = time.time()
 
         result_last_event_date: datetime | None = None
         total_number_of_events = 0
-        async for data in self.async_iterate():
+        async for data in self.async_iterate():  # type: ignore
             events, last_event_date = data
             if last_event_date:
                 if (
@@ -179,19 +179,20 @@ async def async_next_run(self) -> None:
         processing_time = processing_end - processing_start
 
         # Metric about processing time
-        self._forward_events_duration.labels(
-            intake_key=self.configuration.intake_key
-        ).observe(processing_time)
+        self.put_forward_events_duration(
+            intake_key=self.configuration.intake_key,
+            duration=processing_time,
+        )
 
         # Metric about processing count
-        self._outcoming_events.labels(intake_key=self.configuration.intake_key).inc(
-            total_number_of_events
-        )
+        self.put_forwarded_events(
+            intake_key=self.configuration.intake_key, count=total_number_of_events
+        )
 
         # Metric about events lag
         if result_last_event_date:
             lag = (datetime.utcnow() - result_last_event_date).total_seconds()
-            self._events_lag.labels(intake_key=self.configuration.intake_key).set(lag)
+            self.put_events_lag(intake_key=self.configuration.intake_key, lag=lag)
 
         # Compute the remaining sleeping time.
         # If greater than 0 and no messages were fetched, pause the connector
@@ -201,6 +202,15 @@ async def async_next_run(self) -> None:
 
             await asyncio.sleep(delta_sleep)
 
+    async def on_shutdown(self) -> None:
+        """
+        Called when the connector finishes processing.
+
+        Override this hook to release connector-specific resources.
+
+        The default implementation does nothing.
+        """
+
     # Put infinite arg only to have testing easier
     async def async_run(self) -> None:  # pragma: no cover
         """Runs Connector."""
@@ -216,10 +226,12 @@ async def async_run(self) -> None:  # pragma: no cover
             if self.frequency:
                 await asyncio.sleep(self.frequency)
 
+        if self._session:
+            await self._session.close()
+
+        await self.on_shutdown()
+
     def run(self) -> None:  # pragma: no cover
         """Runs Connector."""
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.async_run())
-
-        if self._session:
-            loop.run_until_complete(self._session.close())
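After this change a concrete connector implements the async workflow by overriding async_iterate: yield each batch of events together with the date of the newest event in it, and async_next_run handles pushing, metrics and pacing. A minimal sketch of such an override; PollingConnector and its _fetch_batch helper are illustrative, not part of the SDK:

    from collections.abc import AsyncGenerator
    from datetime import datetime

    from sekoia_automation.aio.connector import AsyncConnector


    class PollingConnector(AsyncConnector):
        name = "PollingConnector"

        async def _fetch_batch(self) -> list[dict]:
            # Illustrative stand-in for a real API call.
            return [{"message": "event A", "date": datetime.utcnow()}]

        async def async_iterate(
            self,
        ) -> AsyncGenerator[tuple[list[str], datetime | None], None]:
            # async_next_run consumes this generator once per run: every
            # non-empty batch is pushed to the intake, and the newest event
            # date seen feeds the events-lag gauge.
            batch = await self._fetch_batch()
            events = [item["message"] for item in batch]
            last_event_date = max((item["date"] for item in batch), default=None)
            yield events, last_event_date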
diff --git a/sekoia_automation/connector/__init__.py b/sekoia_automation/connector/__init__.py
index 379524d..2770ba9 100644
--- a/sekoia_automation/connector/__init__.py
+++ b/sekoia_automation/connector/__init__.py
@@ -313,8 +313,8 @@ def _chunk_events(
 
         # if events were discarded, log it
         if nb_discarded_events > 0:
-            self._discarded_events.labels(intake_key=self.configuration.intake_key).inc(
-                nb_discarded_events
+            self.put_discarded_events(
+                intake_key=self.configuration.intake_key, count=nb_discarded_events
             )
 
             self.log(
@@ -361,19 +361,20 @@ def next_run(self) -> None:
         processing_time = processing_end - processing_start
 
         # Metric about processing time
-        self._forward_events_duration.labels(
-            intake_key=self.configuration.intake_key
-        ).observe(processing_time)
+        self.put_forward_events_duration(
+            intake_key=self.configuration.intake_key,
+            duration=processing_time,
+        )
 
         # Metric about processing count
-        self._outcoming_events.labels(intake_key=self.configuration.intake_key).inc(
-            total_number_of_events
-        )
+        self.put_forwarded_events(
+            intake_key=self.configuration.intake_key, count=total_number_of_events
+        )
 
         # Metric about events lag
         if result_last_event_date:
             lag = (datetime.utcnow() - result_last_event_date).total_seconds()
-            self._events_lag.labels(intake_key=self.configuration.intake_key).set(lag)
+            self.put_events_lag(intake_key=self.configuration.intake_key, lag=lag)
 
         # Compute the remaining sleeping time.
         # If greater than 0 and no messages were fetched, pause the connector
diff --git a/sekoia_automation/connector/metrics.py b/sekoia_automation/connector/metrics.py
index 60c2503..f4950c0 100644
--- a/sekoia_automation/connector/metrics.py
+++ b/sekoia_automation/connector/metrics.py
@@ -1,33 +1,140 @@
+from functools import cached_property
+
 from prometheus_client import Counter, Gauge, Histogram
 
 
 class MetricsMixin:
     _prometheus_namespace = "symphony_module_common"
 
-    _outcoming_events: Counter = Counter(
-        name="forwarded_events",
-        documentation="Number of events forwarded to Sekoia.io",
-        namespace=_prometheus_namespace,
-        labelnames=["intake_key"],
-    )
-
-    _forward_events_duration: Histogram = Histogram(
-        name="forward_events_duration",
-        documentation="Duration to collect and forward events from eventhub",
-        namespace=_prometheus_namespace,
-        labelnames=["intake_key"],
-    )
-
-    _discarded_events: Counter = Counter(
-        name="discarded_events",
-        documentation="Number of events discarded from the collect",
-        namespace=_prometheus_namespace,
-        labelnames=["intake_key"],
-    )
-
-    _events_lag: Gauge = Gauge(
-        name="events_lags",
-        documentation="The delay, in seconds, from the date of the last event",
-        namespace=_prometheus_namespace,
-        labelnames=["intake_key"],
-    )
+    _forwarded_events: Counter | None = None
+    _forward_events_duration: Histogram | None = None
+    _discarded_events: Counter | None = None
+    _events_lag: Gauge | None = None
+
+    @cached_property
+    def forwarded_events_counter(self) -> Counter | None:
+        """
+        Get forwarded events counter.
+
+        Returns:
+            Counter | None:
+        """
+        if self._forwarded_events is None:
+            try:
+                self._forwarded_events = Counter(
+                    name="forwarded_events",
+                    documentation="Number of events forwarded to Sekoia.io",
+                    namespace=self._prometheus_namespace,
+                    labelnames=["intake_key"],
+                )
+            except Exception:
+                return None
+
+        return self._forwarded_events
+
+    @cached_property
+    def forward_events_duration(self) -> Histogram | None:
+        """
+        Get forward events duration histogram.
+
+        Returns:
+            Histogram | None:
+        """
+        if self._forward_events_duration is None:
+            try:
+                self._forward_events_duration = Histogram(
+                    name="forward_events_duration",
+                    documentation="Duration of forwarding events to Sekoia.io",
+                    namespace=self._prometheus_namespace,
+                    labelnames=["intake_key"],
+                )
+            except Exception:
+                return None
+
+        return self._forward_events_duration
+
+    @cached_property
+    def discarded_events_counter(self) -> Counter | None:
+        """
+        Get discarded events counter.
+
+        Returns:
+            Counter | None:
+        """
+        if self._discarded_events is None:
+            try:
+                self._discarded_events = Counter(
+                    name="discarded_events",
+                    documentation="Number of events discarded from the collect",
+                    namespace=self._prometheus_namespace,
+                    labelnames=["intake_key"],
+                )
+            except Exception:
+                return None
+
+        return self._discarded_events
+
+    @cached_property
+    def events_lag(self) -> Gauge | None:
+        """
+        Get events lag gauge.
+
+        Returns:
+            Gauge | None:
+        """
+        if self._events_lag is None:
+            try:
+                self._events_lag = Gauge(
+                    name="events_lags",
+                    documentation="The delay (seconds) from the date of the last event",
+                    namespace=self._prometheus_namespace,
+                    labelnames=["intake_key"],
+                )
+            except Exception:
+                return None
+
+        return self._events_lag
+
+    def put_forward_events_duration(self, intake_key: str, duration: float) -> None:
+        """
+        Record the duration of forwarding a batch of events.
+
+        Args:
+            intake_key: str
+            duration: float
+        """
+        if self.forward_events_duration:
+            self.forward_events_duration.labels(intake_key=intake_key).observe(
+                duration
+            )
+
+    def put_discarded_events(self, intake_key: str, count: int) -> None:
+        """
+        Record the number of discarded events.
+
+        Args:
+            intake_key: str
+            count: int
+        """
+        if self.discarded_events_counter:
+            self.discarded_events_counter.labels(intake_key=intake_key).inc(count)
+
+    def put_forwarded_events(self, intake_key: str, count: int) -> None:
+        """
+        Record the number of forwarded events.
+
+        Args:
+            intake_key: str
+            count: int
+        """
+        if self.forwarded_events_counter:
+            self.forwarded_events_counter.labels(intake_key=intake_key).inc(count)
+
+    def put_events_lag(self, intake_key: str, lag: float) -> None:
+        """
+        Record the events lag.
+
+        Args:
+            intake_key: str
+            lag: float
+        """
+        if self.events_lag:
+            self.events_lag.labels(intake_key=intake_key).set(lag)
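The try/except around each collector exists because prometheus_client keeps a process-wide default registry: instantiating a second collector under an already-registered name raises ValueError ("Duplicated timeseries in CollectorRegistry"). The lazy properties turn that failure into a missing collector, and the put_* helpers then degrade to no-ops instead of crashing. A short sketch of the error being absorbed (the _demo metric name is illustrative):

    from prometheus_client import Counter

    first = Counter(
        name="forwarded_events_demo",
        documentation="demo counter",
        labelnames=["intake_key"],
    )

    try:
        # Same name in the same default registry: refused.
        Counter(
            name="forwarded_events_demo",
            documentation="demo counter",
            labelnames=["intake_key"],
        )
    except ValueError as error:
        print(error)  # Duplicated timeseries in CollectorRegistry: ...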
diff --git a/tests/aio/test_connector.py b/tests/aio/test_connector.py
index 6cc2d97..65c033f 100644
--- a/tests/aio/test_connector.py
+++ b/tests/aio/test_connector.py
@@ -2,8 +2,8 @@
 
 from collections.abc import AsyncGenerator
 from datetime import datetime
-from unittest.mock import Mock, patch
 from posixpath import join as urljoin
+from unittest.mock import Mock, patch
 
 import pytest
 from aiolimiter import AsyncLimiter
@@ -22,7 +22,9 @@ class DummyAsyncConnector(AsyncConnector):
     def set_events(self, events: list[list[str]]) -> None:
         self.events = events
 
-    async def iterate(self) -> AsyncGenerator[tuple[list[str], datetime | None], None]:
+    async def async_iterate(
+        self,
+    ) -> AsyncGenerator[tuple[list[str], datetime | None], None]:
         if self.events is None:
             raise RuntimeError("Events are not set")
 
@@ -251,15 +253,18 @@ async def test_async_connector_async_next_run(
 
     await async_connector.async_next_run()
 
+
 @pytest.mark.parametrize(
-    'base_url,expected_batchapi_url',
+    "base_url,expected_batchapi_url",
     [
-        ('http://intake.fake.url/', 'http://intake.fake.url/batch'),
-        ('http://fake.url/intake/', 'http://fake.url/intake/batch'),
-        ('http://fake.url/intake', 'http://fake.url/intake/batch'),
-    ]
+        ("http://intake.fake.url/", "http://intake.fake.url/batch"),
+        ("http://fake.url/intake/", "http://fake.url/intake/batch"),
+        ("http://fake.url/intake", "http://fake.url/intake/batch"),
+    ],
 )
-def test_async_connector_batchapi_url(storage, mocked_trigger_logs, base_url: str, expected_batchapi_url: str):
+def test_async_connector_batchapi_url(
+    storage, mocked_trigger_logs, base_url: str, expected_batchapi_url: str
+):
     with patch("sentry_sdk.set_tag"):
         async_connector = DummyAsyncConnector(data_path=storage)

From 6c652966f08ab8de9d3de78e4c46fd6668b828b4 Mon Sep 17 00:00:00 2001
From: vg-svitla
Date: Thu, 28 Nov 2024 15:30:38 +0200
Subject: [PATCH 7/7] Update changelog and fix version

---
 CHANGELOG.md   | 2 ++
 pyproject.toml | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97b7500..202925f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.19.0] - 2024-11-28
+
 ### Changed
 
 - Improvements for AsyncConnector.
diff --git a/pyproject.toml b/pyproject.toml
index 80a73bf..0946168 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "sekoia-automation-sdk"
-version = "1.18.0"
+version = "1.19.0"
 description = "SDK to create Sekoia.io playbook modules"
 license = "MIT"
 readme = "README.md"