Skip to content

Commit

Permalink
refactor(ssh-connection): use a "retry" decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Dec 19, 2023
1 parent 5f84889 commit c321b55
Showing 1 changed file with 74 additions and 44 deletions.
118 changes: 74 additions & 44 deletions antareslauncher/remote_environnement/ssh_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import fnmatch
import functools
import logging
import socket
import stat
Expand All @@ -20,14 +21,56 @@
DIRECTORY_NOT_FOUND_ERROR = "Directory not found error"


def retry(
    exception: t.Type[Exception],
    *exceptions: t.Type[Exception],
    delay_sec: float = 5,
    max_retry: int = 5,
    msg_fmt: str = "Retrying in {delay_sec} seconds...",
):
    """
    Decorator to retry a function call if it raises an exception.

    The decorated function is called once and, if it raises one of the
    given exceptions, it is called again up to `max_retry` more times
    (so at most `max_retry + 1` calls in total). A warning is logged and
    the wrapper sleeps `delay_sec` seconds before each retry. If the very
    last call fails as well, its exception is propagated to the caller.
    Exceptions which are not listed are never caught.

    Args:
        exception: The exception to catch.
        exceptions: Additional exceptions to catch.
        delay_sec: The delay (in seconds) between each retry.
        max_retry: The maximum number of retries.
        msg_fmt: The message to display when retrying, with the following format keys:
            - delay_sec: The delay (in seconds) between each retry.
            - remaining: The number of remaining retries (including the
              one about to be performed).

    Returns:
        The decorated function.
    """
    # Build the tuple of caught exception types once, not on every failure.
    caught = (exception, *exceptions)

    def decorator(func):  # type: ignore
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            for attempt in range(max_retry):
                try:
                    return func(*args, **kwargs)
                except caught:
                    # `attempt + 1` calls have failed so far; the loop still has
                    # `max_retry - attempt - 1` iterations left, plus the final
                    # call after the loop, hence `max_retry - attempt` retries.
                    remaining = max_retry - attempt
                    logger = logging.getLogger(__name__)
                    logger.warning(msg_fmt.format(delay_sec=delay_sec, remaining=remaining))
                    time.sleep(delay_sec)
            # Last attempt: any exception raised here reaches the caller.
            return func(*args, **kwargs)

        return wrapper

    return decorator


class SshConnectionError(Exception):
    """Base exception for errors related to the SSH connection."""


class InvalidConfigError(SshConnectionError):
def __init__(self, config, msg=""):
def __init__(self, config: t.Mapping[str, t.Any], msg: str = ""):
err_msg = f"Invalid configuration error {config}"
if msg:
err_msg += f": {msg}"
Expand Down Expand Up @@ -112,7 +155,7 @@ def __str__(self) -> str:
return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]"
return f"{self.msg:<20} ETA: ??? [{rate:.0%}]"

def accumulate(self):
def accumulate(self) -> None:
"""
Accumulates the quantity transferred by the previous transfer and
the current transfer.
Expand Down Expand Up @@ -157,7 +200,7 @@ def __init__(self, config: t.Mapping[str, t.Any]):
self.initialize_home_dir()
self.logger.info(f"Connection created with host = {self.host} and username = {self.username}")

def _init_public_key(self, key_file_name, key_password):
def _init_public_key(self, key_file_name: str, key_password: str) -> bool:
"""Initialises self.private_key
Args:
Expand Down Expand Up @@ -253,7 +296,7 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
finally:
client.close()

def execute_command(self, command: str) -> t.Tuple[str, str]:
def execute_command(self, command: str) -> t.Tuple[t.Optional[str], str]:
"""
Runs an SSH command with a retry logic.
Expand All @@ -267,37 +310,45 @@ def execute_command(self, command: str) -> t.Tuple[str, str]:
Returns:
output: The standard output of the command
error: The standard error of the command
"""
output = None
error = ""
amount_of_retries = 5
time_to_sleep = 5

for attempt in range(amount_of_retries):
if attempt != 0:
self.logger.info(f"An SSH Error occurred, so the command {command} did not succeed. The command will "
f"be re-executed {amount_of_retries - attempt} times until it succeeds.")
time.sleep(time_to_sleep)
try:
output, error = self._exec_command(command)
break
except (socket.timeout, paramiko.SSHException, ConnectionFailedException) as e:
error = self._handle_exception(command, e)

try:
output, error = self._exec_command(command)
except socket.timeout:
error = f"SSH command timed out: [{command}]"
except paramiko.SSHException as e:
error = f"SSH command failed to execute [{command}]: {e}"
except ConnectionFailedException as e:
error = f"SSH connection failed: {e}"

if error:
self.logger.error(error)

return output, error

@retry(
socket.timeout,
paramiko.SSHException,
ConnectionFailedException,
delay_sec=5,
max_retry=5,
msg_fmt=(
"An SSH Error occurred, so the command did not succeed."
" The command will be re-executed {remaining} times until it succeeds."
" Retrying in {delay_sec} seconds..."
),
)
def _exec_command(self, command: str) -> t.Tuple[str, str]:
"""
Executes a command on the remote host.
Puts stderr and stdout in self.ssh_error and self.ssh_output respectively
Args:
command: String containing the command that will be executed through the ssh connection
Returns:
output: The standard output of the command
error: The standard error of the command
"""
with self.ssh_client() as client:
Expand All @@ -309,32 +360,11 @@ def _exec_command(self, command: str) -> t.Tuple[str, str]:
self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
return output, error

def _handle_exception(self, command: str, exception) -> str:
"""
Handles SSH Exceptions that can occur.
Logs an appropriate error message according to the raised Exception
Args:
command: String containing the command that will be executed through the ssh connection
Returns:
The error message
"""
if isinstance(exception, socket.timeout):
error_msg = f"SSH command timed out: [{command}]"
elif isinstance(exception, paramiko.SSHException):
error_msg = f"SSH command failed to execute [{command}]: {exception}"
else:
error_msg = f"SSH connection failed: {exception}"
self.logger.error(error_msg)
return error_msg

def upload_file(self, src: str, dst: str):
"""Uploads a file to a remote server via sftp protocol
Args:
src: Local file to upload
dst: Remote directory where the file will be uploaded
Returns:
Expand All @@ -358,7 +388,7 @@ def upload_file(self, src: str, dst: str):
result_flag = False
return result_flag

def download_file(self, src: str, dst: str):
def download_file(self, src: str, dst: str) -> bool:
"""Downloads a file from a remote server via sftp protocol
Args:
Expand Down Expand Up @@ -633,7 +663,7 @@ def remove_dir(self, dir_path):
result_flag = False
return result_flag

def test_connection(self):
def test_connection(self) -> bool:
try:
with self.ssh_client():
return True
Expand Down

0 comments on commit c321b55

Please sign in to comment.