From 8689a0966c4894f0f5e7ffc468a8161e51c22b95 Mon Sep 17 00:00:00 2001 From: Kamforka Date: Thu, 26 Sep 2024 17:07:24 +0200 Subject: [PATCH] #340 - Add retry mechanism --- pyproject.toml | 2 +- thehive4py/client.py | 4 +++- thehive4py/session.py | 21 +++++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 68228988..99f373d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "thehive4py" description = "Python client for TheHive5" version = "2.0.0b6" requires-python = ">=3.8" -dependencies = ["requests>=2.27"] +dependencies = ["requests~=2.27"] readme = "README.md" keywords = ["thehive5", "api", "client"] license = { text = "MIT" } diff --git a/thehive4py/client.py b/thehive4py/client.py index 923c8e67..f731d633 100644 --- a/thehive4py/client.py +++ b/thehive4py/client.py @@ -17,7 +17,7 @@ from thehive4py.endpoints.custom_field import CustomFieldEndpoint from thehive4py.endpoints.observable_type import ObservableTypeEndpoint from thehive4py.endpoints.query import QueryEndpoint -from thehive4py.session import TheHiveSession +from thehive4py.session import DEFAULT_RETRY, RetryValue, TheHiveSession class TheHiveApi: @@ -29,6 +29,7 @@ def __init__( password: Optional[str] = None, organisation: Optional[str] = None, verify=None, + max_retries: RetryValue = DEFAULT_RETRY, ): self.session = TheHiveSession( url=url, @@ -36,6 +37,7 @@ def __init__( username=username, password=password, verify=verify, + max_retries=max_retries, ) self.session_organisation = organisation diff --git a/thehive4py/session.py b/thehive4py/session.py index cd0b2d26..e3e261d2 100644 --- a/thehive4py/session.py +++ b/thehive4py/session.py @@ -4,11 +4,24 @@ from typing import Any, Optional, Union import requests +import requests.adapters import requests.auth +from urllib3 import Retry from thehive4py.__version__ import __version__ from thehive4py.errors import TheHiveError +DEFAULT_RETRY = Retry( + total=5, + backoff_factor=1, + status_forcelist=[500, 502, 503, 504], + allowed_methods=["GET", "POST", "PUT", "PATCH", "DELETE"], + raise_on_status=False, +) + + +RetryValue = Union[Retry, int, None] + class SessionJSONEncoder(jsonlib.JSONEncoder): """Custom JSON encoder class for TheHive session.""" @@ -27,11 +40,13 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, verify=None, + max_retries: RetryValue = DEFAULT_RETRY, ): super().__init__() self.hive_url = self._sanitize_hive_url(url) self.verify = verify self.headers["User-Agent"] = f"thehive4py/{__version__}" + self._set_retries(max_retries=max_retries) if username and password: self.headers["Authorization"] = requests.auth._basic_auth_str( @@ -44,6 +59,12 @@ def __init__( "Either apikey or the username/password combination must be provided!" ) + def _set_retries(self, max_retries: RetryValue): + """Configure the session to retry.""" + retry_adapter = requests.adapters.HTTPAdapter(max_retries=max_retries) + self.mount("http://", retry_adapter) + self.mount("https://", retry_adapter) + def _sanitize_hive_url(self, hive_url: str) -> str: """Sanitize the base url for the client.""" if hive_url.endswith("/"):