diff --git a/Stormshield/stormshield_module/base.py b/Stormshield/stormshield_module/base.py index 7ffe3852..230eae6e 100644 --- a/Stormshield/stormshield_module/base.py +++ b/Stormshield/stormshield_module/base.py @@ -1,5 +1,6 @@ from functools import cached_property from posixpath import join as urljoin +from typing import Any import re import requests from requests import RequestException, Response @@ -17,19 +18,19 @@ class StormshieldAction(GenericAPIAction): endpoint: str @cached_property - def api_token(self): - return self.module.configuration["api_token"] + def api_token(self): # type: ignore + return self.module.configuration.get("api_token") @cached_property - def base_url(self): + def base_url(self) -> str: config_url = self.module.configuration["url"].rstrip("/") api_path = "rest/api/v1" return urljoin(config_url, api_path) - def get_headers(self): + def get_headers(self) -> dict[str, str]: return {"Authorization": f"Bearer {self.api_token}"} - def treat_failed_response(self, response: Response): + def treat_failed_response(self, response: Response) -> None: errors = { 401: "Authentication failed: Invalid API key provided.", 403: "Access denied: Insufficient permissions to access this resource.", @@ -43,7 +44,7 @@ def treat_failed_response(self, response: Response): if message: raise Exception(f"Error : {message}") - def get_url(self, arguments): + def get_url(self, arguments: dict[str, Any]) -> str: match = re.findall("{(.*?)}", self.endpoint) for replacement in match: self.endpoint = self.endpoint.replace(f"{{{replacement}}}", str(arguments.pop(replacement)), 1) @@ -51,7 +52,7 @@ def get_url(self, arguments): path = urljoin(self.base_url, self.endpoint.lstrip("/")) if self.query_parameters: - query_arguments: list = [] + query_arguments: list[str] = [] for k in self.query_parameters: if k in arguments: @@ -65,10 +66,10 @@ def get_url(self, arguments): return path - def get_response(self, url, body, headers) -> Response: + def get_response(self, url: str, body: dict[str, Any] | None, headers:dict[str, Any]) -> Response: return requests.request(self.verb, url, json=body, headers=headers, timeout=self.timeout) - def run(self, arguments) -> dict | None: + def run(self, arguments: dict[str, Any]) -> dict[str, Any] | None: headers = self.get_headers() url = self.get_url(arguments) body = self.get_body(arguments) diff --git a/Stormshield/stormshield_module/wait_task.py b/Stormshield/stormshield_module/wait_task.py index 9a160fef..522929d7 100644 --- a/Stormshield/stormshield_module/wait_task.py +++ b/Stormshield/stormshield_module/wait_task.py @@ -1,5 +1,6 @@ import requests from requests import Response +from typing import Any from stormshield_module.base import StormshieldAction from stormshield_module.exceptions import RemoteTaskExecutionFailedError @@ -10,7 +11,7 @@ class WaitForTaskCompletionAction(StormshieldAction): endpoint = "/agents/tasks/{task_id}" query_parameters: list[str] = [] - def get_response(self, url, body, headers) -> Response: + def get_response(self, url: str, body: dict[str, Any] | None, headers:dict[str, Any]) -> Response: result = requests.request(self.verb, url, json=body, headers=headers, timeout=self.timeout) execution_state = result.json()["status"] diff --git a/Stormshield/tests/test_wait_tasks.py b/Stormshield/tests/test_wait_tasks.py index dd463e36..9b7d219a 100644 --- a/Stormshield/tests/test_wait_tasks.py +++ b/Stormshield/tests/test_wait_tasks.py @@ -16,7 +16,7 @@ def test_integration_wait_task_with_CD(symphony_storage): action = WaitForTaskCompletionAction(data_path=symphony_storage) action.module.configuration = module_configuration - arguments = {"id": os.environ["STORMSHIELD_AGENT_ID"]} + arguments = {"task_id": os.environ["STORMSHIELD_AGENT_ID"]} response = action.run(arguments) @@ -64,7 +64,7 @@ def test_integration_wait_task_failed(symphony_storage, wait_task_failed_message json=wait_task_failed_message, ) - arguments = {"id": "foo"} + arguments = {"task_id": "foo"} with pytest.raises(Exception) as excinfo: action.run(arguments) @@ -86,7 +86,7 @@ def test_integration_wait_task_succeeded(symphony_storage, wait_task_succeded_me json=wait_task_succeded_message, ) - arguments = {"id": "foo"} + arguments = {"task_id": "foo"} response = action.run(arguments) assert response is not None