Skip to content

Commit

Permalink
Fix mypy problem
Browse files Browse the repository at this point in the history
  • Loading branch information
TOUFIKIzakarya committed Dec 17, 2024
1 parent e253bf6 commit 2aaa62e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
19 changes: 10 additions & 9 deletions Stormshield/stormshield_module/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import cached_property
from posixpath import join as urljoin
from typing import Any
import re
import requests
from requests import RequestException, Response
Expand All @@ -17,19 +18,19 @@ class StormshieldAction(GenericAPIAction):
endpoint: str

@cached_property
def api_token(self):
return self.module.configuration["api_token"]
def api_token(self): # type: ignore
return self.module.configuration.get("api_token")

@cached_property
def base_url(self):
def base_url(self) -> str:
config_url = self.module.configuration["url"].rstrip("/")
api_path = "rest/api/v1"
return urljoin(config_url, api_path)

def get_headers(self):
def get_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self.api_token}"}

def treat_failed_response(self, response: Response):
def treat_failed_response(self, response: Response) -> None:
errors = {
401: "Authentication failed: Invalid API key provided.",
403: "Access denied: Insufficient permissions to access this resource.",
Expand All @@ -43,15 +44,15 @@ def treat_failed_response(self, response: Response):
if message:
raise Exception(f"Error : {message}")

def get_url(self, arguments):
def get_url(self, arguments: dict[str, Any]) -> str:
match = re.findall("{(.*?)}", self.endpoint)
for replacement in match:
self.endpoint = self.endpoint.replace(f"{{{replacement}}}", str(arguments.pop(replacement)), 1)

path = urljoin(self.base_url, self.endpoint.lstrip("/"))

if self.query_parameters:
query_arguments: list = []
query_arguments: list[str] = []

for k in self.query_parameters:
if k in arguments:
Expand All @@ -65,10 +66,10 @@ def get_url(self, arguments):

return path

def get_response(self, url, body, headers) -> Response:
def get_response(self, url: str, body: dict[str, Any] | None, headers:dict[str, Any]) -> Response:
return requests.request(self.verb, url, json=body, headers=headers, timeout=self.timeout)

def run(self, arguments) -> dict | None:
def run(self, arguments: dict[str, Any]) -> dict[str, Any] | None:
headers = self.get_headers()
url = self.get_url(arguments)
body = self.get_body(arguments)
Expand Down
3 changes: 2 additions & 1 deletion Stormshield/stormshield_module/wait_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import requests
from requests import Response
from typing import Any

from stormshield_module.base import StormshieldAction
from stormshield_module.exceptions import RemoteTaskExecutionFailedError
Expand All @@ -10,7 +11,7 @@ class WaitForTaskCompletionAction(StormshieldAction):
endpoint = "/agents/tasks/{task_id}"
query_parameters: list[str] = []

def get_response(self, url, body, headers) -> Response:
def get_response(self, url: str, body: dict[str, Any] | None, headers:dict[str, Any]) -> Response:
result = requests.request(self.verb, url, json=body, headers=headers, timeout=self.timeout)
execution_state = result.json()["status"]

Expand Down
6 changes: 3 additions & 3 deletions Stormshield/tests/test_wait_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_integration_wait_task_with_CD(symphony_storage):
action = WaitForTaskCompletionAction(data_path=symphony_storage)
action.module.configuration = module_configuration

arguments = {"id": os.environ["STORMSHIELD_AGENT_ID"]}
arguments = {"task_id": os.environ["STORMSHIELD_AGENT_ID"]}

response = action.run(arguments)

Expand Down Expand Up @@ -64,7 +64,7 @@ def test_integration_wait_task_failed(symphony_storage, wait_task_failed_message
json=wait_task_failed_message,
)

arguments = {"id": "foo"}
arguments = {"task_id": "foo"}

with pytest.raises(Exception) as excinfo:
action.run(arguments)
Expand All @@ -86,7 +86,7 @@ def test_integration_wait_task_succeeded(symphony_storage, wait_task_succeded_me
json=wait_task_succeded_message,
)

arguments = {"id": "foo"}
arguments = {"task_id": "foo"}
response = action.run(arguments)

assert response is not None
Expand Down

0 comments on commit 2aaa62e

Please sign in to comment.