From dfcf5e6179dfc1b7c6c11ab22252c472f7f63916 Mon Sep 17 00:00:00 2001
From: MartinBelthle <102529366+MartinBelthle@users.noreply.github.com>
Date: Tue, 19 Dec 2023 16:19:01 +0100
Subject: [PATCH] feat(ssh): add retry loop around SSH Exceptions (#68)

* feat(ssh): add retry loop around ssh exceptions

* refactor(ssh): remove useless stdin

* refactor(ssh-connection): use a "retry" decorator

---------

Co-authored-by: Laurent LAPORTE
(cherry picked from commit 3e0d436891cd635dd9a2c4f6a4ea3df3bebb5f8f)
---
 .../remote_environnement/ssh_connection.py    | 113 ++++++++++++++----
 1 file changed, 92 insertions(+), 21 deletions(-)

diff --git a/antareslauncher/remote_environnement/ssh_connection.py b/antareslauncher/remote_environnement/ssh_connection.py
index 35c5b80..be4622f 100644
--- a/antareslauncher/remote_environnement/ssh_connection.py
+++ b/antareslauncher/remote_environnement/ssh_connection.py
@@ -1,5 +1,6 @@
 import contextlib
 import fnmatch
+import functools
 import logging
 import socket
 import stat
@@ -14,6 +15,48 @@ LocalPath = Path
 
 
 
+def retry(
+    exception: t.Type[Exception],
+    *exceptions: t.Type[Exception],
+    delay_sec: float = 5,
+    max_retry: int = 5,
+    msg_fmt: str = "Retrying in {delay_sec} seconds...",
+):
+    """
+    Decorator to retry a function call if it raises an exception.
+
+    Args:
+        exception: The exception to catch.
+        exceptions: Additional exceptions to catch.
+        delay_sec: The delay (in seconds) between each retry.
+        max_retry: The maximum number of retries (the call runs at most `max_retry + 1` times).
+        msg_fmt: The message to display when retrying, with the following format keys:
+            - delay_sec: The delay (in seconds) between each retry.
+            - remaining: The number of remaining retries.
+
+    Returns:
+        The decorated function.
+    """
+
+    def decorator(func):  # type: ignore
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):  # type: ignore
+            for attempt in range(max_retry):
+                try:
+                    return func(*args, **kwargs)
+                except (exception, *exceptions):
+                    logger = logging.getLogger(__name__)
+                    remaining = max_retry - attempt  # retries still to come, incl. the final one after the loop
+                    logger.warning(msg_fmt.format(delay_sec=delay_sec, remaining=remaining))
+                    time.sleep(delay_sec)
+            # Last attempt
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 class SshConnectionError(Exception):
     """
     SSH Connection Error
@@ -21,7 +64,7 @@ class SshConnectionError(Exception):
 
 
 class InvalidConfigError(SshConnectionError):
-    def __init__(self, config, msg=""):
+    def __init__(self, config: t.Mapping[str, t.Any], msg: str = ""):
         err_msg = f"Invalid configuration error {config}"
         if msg:
             err_msg += f": {msg}"
@@ -106,7 +149,7 @@ def __str__(self) -> str:
             return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]"
         return f"{self.msg:<20} ETA: ??? [{rate:.0%}]"
 
-    def accumulate(self):
+    def accumulate(self) -> None:
         """
         Accumulates the quantity transferred by the previous transfer
         and the current transfer.
@@ -151,7 +194,7 @@ def __init__(self, config: t.Mapping[str, t.Any]):
         self.initialize_home_dir()
         self.logger.info(f"Connection created with host = {self.host} and username = {self.username}")
 
-    def _init_public_key(self, key_file_name, key_password):
+    def _init_public_key(self, key_file_name: str, key_password: str) -> bool:
         """Initialises self.private_key
 
         Args:
@@ -247,47 +290,75 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
         finally:
             client.close()
 
-    def execute_command(self, command: str):
-        """Executes a command on the remote host. Puts stderr and stdout in
-        self.ssh_error and self.ssh_output respectively
+    def execute_command(self, command: str) -> t.Tuple[t.Optional[str], str]:
+        """
+        Runs an SSH command with a retry logic.
+
+        If it encounters an SSH Exception, it's going to sleep for 5 seconds.
+        The command will then be re-executed a maximum of 5 times.
+        It allows us to wait for the connection to be re-established.
+        This way, we avoid having a simulation failure due to an SSH error.
 
         Args:
             command: String containing the command that will be executed through the ssh connection
 
         Returns:
             output: The standard output of the command
-
             error: The standard error of the command
         """
         output = None
+
         try:
-            with self.ssh_client() as client:
-                # fmt: off
-                self.logger.info(f"Running SSH command [{command}]...")
-                _, stdout, stderr = client.exec_command(command, timeout=30)
-                output = stdout.read().decode("utf-8").strip()
-                error = stderr.read().decode("utf-8").strip()
-                self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
-                self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
-                # fmt: on
+            output, error = self._exec_command(command)
         except socket.timeout:
             error = f"SSH command timed out: [{command}]"
-            self.logger.error(error)
         except paramiko.SSHException as e:
             error = f"SSH command failed to execute [{command}]: {e}"
-            self.logger.error(error)
         except ConnectionFailedException as e:
             error = f"SSH connection failed: {e}"
+
+        if error:
             self.logger.error(error)
 
         return output, error
 
+    @retry(
+        socket.timeout,
+        paramiko.SSHException,
+        ConnectionFailedException,
+        delay_sec=5,
+        max_retry=5,
+        msg_fmt=(
+            "An SSH Error occurred, so the command did not succeed."
+            " The command will be re-executed {remaining} times until it succeeds."
+            " Retrying in {delay_sec} seconds..."
+        ),
+    )
+    def _exec_command(self, command: str) -> t.Tuple[str, str]:
+        """
+        Executes a command on the remote host.
+
+        Args:
+            command: String containing the command that will be executed through the ssh connection
+
+        Returns:
+            output: The standard output of the command
+            error: The standard error of the command
+        """
+        with self.ssh_client() as client:
+            self.logger.info(f"Running SSH command [{command}]...")
+            _, stdout, stderr = client.exec_command(command, timeout=30)
+            output = stdout.read().decode("utf-8").strip()
+            error = stderr.read().decode("utf-8").strip()
+            self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
+            self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
+        return output, error
+
     def upload_file(self, src: str, dst: str):
         """Uploads a file to a remote server via sftp protocol
 
         Args:
             src: Local file to upload
-
             dst: Remote directory where the file will be uploaded
 
         Returns:
@@ -311,7 +382,7 @@ def upload_file(self, src: str, dst: str):
             result_flag = False
         return result_flag
 
-    def download_file(self, src: str, dst: str):
+    def download_file(self, src: str, dst: str) -> bool:
         """Downloads a file from a remote server via sftp protocol
 
         Args:
@@ -586,7 +657,7 @@ def remove_dir(self, dir_path):
             result_flag = False
         return result_flag
 
-    def test_connection(self):
+    def test_connection(self) -> bool:
         try:
             with self.ssh_client():
                 return True