diff --git a/sekoia_automation/account_validator.py b/sekoia_automation/account_validator.py new file mode 100644 index 0000000..d982075 --- /dev/null +++ b/sekoia_automation/account_validator.py @@ -0,0 +1,50 @@ +from abc import abstractmethod +from pathlib import Path + +from sekoia_automation.module import Module, ModuleItem + + +class AccountValidator(ModuleItem): + CALLBACK_URL_FILE_NAME = "validation_callback_url" + + def __init__(self, module: Module | None = None, data_path: Path | None = None): + super().__init__(module, data_path) + self._error: str | None = None + + @abstractmethod + def validate(self) -> bool: + """To define in subclasses. Validates the configuration of the module. + + Returns: + bool: True if the module is valid, False otherwise + """ + + def error(self, message: str) -> None: + """Allow to set an error message explaining why the validation failed.""" + self._error = message + + def execute(self): + """Validates the account (module_configuration) of the module + and sends the result to Symphony.""" + self.set_task_as_running() + # Call the actual validation procedure + success = self.validate() + self.send_results(success) + + def set_task_as_running(self): + """Send a request to indicate the action started.""" + data = {"status": "running"} + response = self._send_request(data, verb="PATCH") + if self.module.has_secrets(): + secrets = { + k: v + for k, v in response.json()["module_configuration"]["value"].items() + if k in self.module.manifest_secrets() + } + self.module.set_secrets(secrets) + + def send_results(self, success: bool): + data = {"status": "finished", "results": {"success": success}} + if self._error: + data["error"] = self._error + self._send_request(data, verb="PATCH") diff --git a/sekoia_automation/module.py b/sekoia_automation/module.py index 5b3012e..35e669d 100644 --- a/sekoia_automation/module.py +++ b/sekoia_automation/module.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from functools import cached_property from pathlib import Path -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import requests import sentry_sdk @@ -25,6 +25,10 @@ get_as_model, ) +if TYPE_CHECKING: # pragma: no cover + from sekoia_automation.account_validator import AccountValidator + + LogLevelStr = Literal["fatal", "critical", "error", "warning", "info", "debug"] @@ -493,36 +497,3 @@ def stop_monitoring(self): """ Stops the background monitoring operations """ - - -class AccountValidator(ModuleItem): - CALLBACK_URL_FILE_NAME = "validation_callback_url" - - @abstractmethod - def validate(self) -> bool: - """To define in subclasses. Validates the configuration of the module. - - Returns: - bool: True if the module is valid, False otherwise - """ - - def execute(self): - """Validates the account (module_configuration) of the module - and sends the result to Symphony.""" - # Call the actual validation procedure - status = self.validate() - - # Return result of validation to Symphony ; ask for module's secrets if needed - data = {"validation_status": status, "need_secrets": self.module.has_secrets()} - - # Send request to Symphony - response = self._send_request(data, verb="PATCH") - - # Set module's secrets if needed - if self.module.has_secrets(): - secrets = { - k: v - for k, v in response.json()["module_configuration"]["value"].items() - if k in self.module.manifest_secrets() - } - self.module.set_secrets(secrets) diff --git a/tests/conftest.py b/tests/conftest.py index 2af1952..71ceb74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from sekoia_automation import config from sekoia_automation import storage as storage_module -from sekoia_automation.module import AccountValidator, Module +from sekoia_automation.module import Module from sekoia_automation.trigger import Trigger @@ -157,12 +157,3 @@ def session_faker(faker_locale: list[str], faker_seed: int) -> Faker: instance.seed_instance(seed=faker_seed) return instance - - -class MockAccountValidator(AccountValidator): - def __init__(self, mock_return_value: bool = True): - super().__init__() - self.mock_return_value = mock_return_value - - def validate(self): - return self.mock_return_value diff --git a/tests/test_account_validator.py b/tests/test_account_validator.py index 76b1565..b060dee 100644 --- a/tests/test_account_validator.py +++ b/tests/test_account_validator.py @@ -1,90 +1,90 @@ -from unittest.mock import call, patch +from unittest.mock import patch -import pytest -import requests import requests_mock -from tests.conftest import MockAccountValidator +from sekoia_automation.account_validator import AccountValidator +from sekoia_automation.module import Module + + +class MockAccountValidator(AccountValidator): + mock_return_value = True + + def validate(self): + if not self.mock_return_value: + self.error("Validation failed") + return self.mock_return_value def test_execute_success(): validator = MockAccountValidator() + validator.mock_return_value = True with ( patch.object( validator.module, "load_config", return_value="http://example.com/callback" - ) as mock_load_config, + ), requests_mock.Mocker() as mock_request, ): mock_request.patch("http://example.com/callback", status_code=200) validator.execute() - assert mock_load_config.call_args_list == [ - call(validator.CALLBACK_URL_FILE_NAME), - call(validator.TOKEN_FILE_NAME), - ] - assert mock_request.called + # Check the callback has been called + assert mock_request.call_count == 2 + assert mock_request.request_history[0].json() == {"status": "running"} assert mock_request.last_request.json() == { - "validation_status": True, - "need_secrets": False, + "results": {"success": True}, + "status": "finished", } def test_execute_failure(): - validator = MockAccountValidator(mock_return_value=False) + validator = MockAccountValidator() + validator.mock_return_value = False with ( patch.object( validator.module, "load_config", return_value="http://example.com/callback" - ) as mock_load_config, + ), requests_mock.Mocker() as mock_request, ): mock_request.patch("http://example.com/callback", status_code=200) validator.execute() - assert mock_load_config.call_args_list == [ - call(validator.CALLBACK_URL_FILE_NAME), - call(validator.TOKEN_FILE_NAME), - ] - assert mock_request.called + # Check the callback has been called + assert mock_request.call_count == 2 + assert mock_request.request_history[0].json() == {"status": "running"} assert mock_request.last_request.json() == { - "validation_status": False, - "need_secrets": False, + "error": "Validation failed", + "results": {"success": False}, + "status": "finished", } -def test_execute_request_failure(): - validator = MockAccountValidator() - - with ( - patch.object( - validator.module, "load_config", return_value="http://example.com/callback" - ) as mock_load_config, - requests_mock.Mocker() as mock_request, - ): - mock_request.patch("http://example.com/callback", status_code=500) - - with pytest.raises(requests.exceptions.HTTPError): - validator.execute() - - assert mock_load_config.call_args_list == [ - call(validator.CALLBACK_URL_FILE_NAME), - call(validator.TOKEN_FILE_NAME), - ] - assert mock_request.called - assert mock_request.last_request.json() == { - "validation_status": True, - "need_secrets": False, +def test_execute_with_secrets(): + module = Module() + module._manifest = { + "configuration": { + "$schema": "http://json-schema.org/draft-07/schema#", + "properties": { + "api_key": {"description": "SEKOIA.IO API key", "type": "string"}, + "base_url": { + "description": "SEKOIA.IO base URL (ex. https://api.sekoia.io)", + "type": "string", + }, + }, + "required": ["api_key"], + "secrets": ["api_key"], + "title": "SEKOIA.IO Configuration", + "type": "object", } - - -def test_retrieve_secrets(): - validator = MockAccountValidator() + } + module._configuration = {"base_url": "https://api.sekoia.io"} + validator = MockAccountValidator(module=module) + validator.mock_return_value = True with ( - patch.object(validator.module, "has_secrets", return_value=True), patch.object( validator.module, "load_config", return_value="http://example.com/callback" ), @@ -92,7 +92,21 @@ def test_retrieve_secrets(): ): mock_request.patch( "http://example.com/callback", - json={"module_configuration": {"value": {"secret_key": "secret_value"}}}, + status_code=200, + json={"module_configuration": {"value": {"api_key": "foo"}}}, ) validator.execute() + + # Check the configuration has been updated with the secrets + assert module.configuration == { + "api_key": "foo", + "base_url": "https://api.sekoia.io", + } + # Check the callback has been called + assert mock_request.call_count == 2 + assert mock_request.request_history[0].json() == {"status": "running"} + assert mock_request.last_request.json() == { + "results": {"success": True}, + "status": "finished", + } diff --git a/tests/test_module.py b/tests/test_module.py index 448da73..6ab6693 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,5 +1,5 @@ # natives -from unittest.mock import patch +from unittest.mock import Mock, patch # third parties import pytest @@ -10,7 +10,7 @@ from sekoia_automation.exceptions import CommandNotFoundError, ModuleConfigurationError from sekoia_automation.module import Module, ModuleItem from sekoia_automation.trigger import Trigger -from tests.conftest import DEFAULT_ARGUMENTS, MockAccountValidator +from tests.conftest import DEFAULT_ARGUMENTS def test_load_config_file_not_exists(): @@ -65,12 +65,10 @@ def test_register_no_command(): def test_register_account_validator(): module = Module() - - with patch.object(module, "register") as mock_register: - module.register_account_validator(MockAccountValidator) - mock_register.assert_called_once_with( - MockAccountValidator, "validate_module_configuration" - ) + validator = Mock() + validator.name = None + module.register_account_validator(validator) + assert module._items["validate_module_configuration"] == validator @patch.object(DummyTrigger, "execute")