diff --git a/antareslauncher/remote_environnement/ssh_connection.py b/antareslauncher/remote_environnement/ssh_connection.py index 5109eb8..389151a 100644 --- a/antareslauncher/remote_environnement/ssh_connection.py +++ b/antareslauncher/remote_environnement/ssh_connection.py @@ -1,5 +1,6 @@ import contextlib import fnmatch +import functools import logging import socket import stat @@ -20,6 +21,48 @@ DIRECTORY_NOT_FOUND_ERROR = "Directory not found error" +def retry( + exception: t.Type[Exception], + *exceptions: t.Type[Exception], + delay_sec: float = 5, + max_retry: int = 5, + msg_fmt: str = "Retrying in {delay_sec} seconds...", +): + """ + Decorator to retry a function call if it raises an exception. + + Args: + exception: The exception to catch. + exceptions: Additional exceptions to catch. + delay_sec: The delay (in seconds) between each retry. + max_retry: The maximum number of retries. + msg_fmt: The message to display when retrying, with the following format keys: + - delay_sec: The delay (in seconds) between each retry. + - remaining: The number of remaining retries. + + Returns: + The decorated function. + """ + + def decorator(func): # type: ignore + @functools.wraps(func) + def wrapper(*args, **kwargs): # type: ignore + for attempt in range(max_retry): + try: + return func(*args, **kwargs) + except (exception, *exceptions): + logger = logging.getLogger(__name__) + remaining = max_retry - attempt - 1 + logger.warning(msg_fmt.format(delay_sec=delay_sec, remaining=remaining)) + time.sleep(delay_sec) + # Last attempt + return func(*args, **kwargs) + + return wrapper + + return decorator + + class SshConnectionError(Exception): """ SSH Connection Error @@ -27,7 +70,7 @@ class SshConnectionError(Exception): class InvalidConfigError(SshConnectionError): - def __init__(self, config, msg=""): + def __init__(self, config: t.Mapping[str, t.Any], msg: str = ""): err_msg = f"Invalid configuration error {config}" if msg: err_msg += f": {msg}" @@ -112,7 +155,7 @@ def __str__(self) -> str: return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]" return f"{self.msg:<20} ETA: ??? [{rate:.0%}]" - def accumulate(self): + def accumulate(self) -> None: """ Accumulates the quantity transferred by the previous transfer and the current transfer. @@ -157,7 +200,7 @@ def __init__(self, config: t.Mapping[str, t.Any]): self.initialize_home_dir() self.logger.info(f"Connection created with host = {self.host} and username = {self.username}") - def _init_public_key(self, key_file_name, key_password): + def _init_public_key(self, key_file_name: str, key_password: str) -> bool: """Initialises self.private_key Args: @@ -253,7 +296,7 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]: finally: client.close() - def execute_command(self, command: str) -> t.Tuple[str, str]: + def execute_command(self, command: str) -> t.Tuple[t.Optional[str], str]: """ Runs an SSH command with a retry logic. @@ -267,37 +310,45 @@ def execute_command(self, command: str) -> t.Tuple[str, str]: Returns: output: The standard output of the command - error: The standard error of the command """ output = None - error = "" - amount_of_retries = 5 - time_to_sleep = 5 - - for attempt in range(amount_of_retries): - if attempt != 0: - self.logger.info(f"An SSH Error occurred, so the command {command} did not succeed. The command will " - f"be re-executed {amount_of_retries - attempt} times until it succeeds.") - time.sleep(time_to_sleep) - try: - output, error = self._exec_command(command) - break - except (socket.timeout, paramiko.SSHException, ConnectionFailedException) as e: - error = self._handle_exception(command, e) + + try: + output, error = self._exec_command(command) + except socket.timeout: + error = f"SSH command timed out: [{command}]" + except paramiko.SSHException as e: + error = f"SSH command failed to execute [{command}]: {e}" + except ConnectionFailedException as e: + error = f"SSH connection failed: {e}" + + if error: + self.logger.error(error) + return output, error + @retry( + socket.timeout, + paramiko.SSHException, + ConnectionFailedException, + delay_sec=5, + max_retry=5, + msg_fmt=( + "An SSH Error occurred, so the command did not succeed." + " The command will be re-executed {remaining} times until it succeeds." + " Retrying in {delay_sec} seconds..." + ), + ) def _exec_command(self, command: str) -> t.Tuple[str, str]: """ Executes a command on the remote host. - Puts stderr and stdout in self.ssh_error and self.ssh_output respectively Args: command: String containing the command that will be executed through the ssh connection Returns: output: The standard output of the command - error: The standard error of the command """ with self.ssh_client() as client: @@ -309,32 +360,11 @@ def _exec_command(self, command: str) -> t.Tuple[str, str]: self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}") return output, error - def _handle_exception(self, command: str, exception) -> str: - """ - Handles SSH Exceptions that can occur. - Logs an appropriate error message according to the raised Exception - - Args: - command: String containing the command that will be executed through the ssh connection - - Returns: - The error message - """ - if isinstance(exception, socket.timeout): - error_msg = f"SSH command timed out: [{command}]" - elif isinstance(exception, paramiko.SSHException): - error_msg = f"SSH command failed to execute [{command}]: {exception}" - else: - error_msg = f"SSH connection failed: {exception}" - self.logger.error(error_msg) - return error_msg - def upload_file(self, src: str, dst: str): """Uploads a file to a remote server via sftp protocol Args: src: Local file to upload - dst: Remote directory where the file will be uploaded Returns: @@ -358,7 +388,7 @@ def upload_file(self, src: str, dst: str): result_flag = False return result_flag - def download_file(self, src: str, dst: str): + def download_file(self, src: str, dst: str) -> bool: """Downloads a file from a remote server via sftp protocol Args: @@ -633,7 +663,7 @@ def remove_dir(self, dir_path): result_flag = False return result_flag - def test_connection(self): + def test_connection(self) -> bool: try: with self.ssh_client(): return True