Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v1.3.2 #70

Merged
merged 5 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ npx auto-changelog -l false --hide-empty-releases -v v1.3.1 -o CHANGES.out.md
```
-->

## [1.3.2] - 2024-04-11

### Build

- build: add a script to bump the version

### Changed

- feat(ssh): add retry loop around SSH Exceptions [`#68`](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/68)

### Fixes

- fix(retriever): avoid infinite loop if sbatch command fails [`#69`](https://github.com/AntaresSimulatorTeam/antares-launcher/pull/69)


## [1.3.1] - 2023-09-26

### Changed
Expand Down
4 changes: 2 additions & 2 deletions antareslauncher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

# Standard project metadata

__version__ = "1.3.1"
__version__ = "1.3.2"
__author__ = "RTE, Antares Web Team"
__date__ = "2023-09-26"
__date__ = "2024-04-11"
# noinspection SpellCheckingInspection
__credits__ = "(c) Réseau de Transport de l’Électricité (RTE)"

Expand Down
197 changes: 134 additions & 63 deletions antareslauncher/remote_environnement/ssh_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import fnmatch
import functools
import logging
import socket
import stat
Expand All @@ -14,14 +15,56 @@
LocalPath = Path


def retry(
exception: t.Type[Exception],
*exceptions: t.Type[Exception],
delay_sec: float = 5,
max_retry: int = 5,
msg_fmt: str = "Retrying in {delay_sec} seconds...",
):
"""
Decorator to retry a function call if it raises an exception.

Args:
exception: The exception to catch.
exceptions: Additional exceptions to catch.
delay_sec: The delay (in seconds) between each retry.
max_retry: The maximum number of retries.
msg_fmt: The message to display when retrying, with the following format keys:
- delay_sec: The delay (in seconds) between each retry.
- remaining: The number of remaining retries.

Returns:
The decorated function.
"""

def decorator(func): # type: ignore
@functools.wraps(func)
def wrapper(*args, **kwargs): # type: ignore
for attempt in range(max_retry):
try:
return func(*args, **kwargs)
except (exception, *exceptions):
logger = logging.getLogger(__name__)
remaining = max_retry - attempt - 1
logger.warning(msg_fmt.format(delay_sec=delay_sec, remaining=remaining))
time.sleep(delay_sec)
# Last attempt
return func(*args, **kwargs)

return wrapper

return decorator


class SshConnectionError(Exception):
"""
SSH Connection Error
"""


class InvalidConfigError(SshConnectionError):
def __init__(self, config, msg=""):
def __init__(self, config: t.Mapping[str, t.Any], msg: str = ""):
err_msg = f"Invalid configuration error {config}"
if msg:
err_msg += f": {msg}"
Expand Down Expand Up @@ -106,7 +149,7 @@ def __str__(self) -> str:
return f"{self.msg:<20} ETA: {eta}s [{rate:.0%}]"
return f"{self.msg:<20} ETA: ??? [{rate:.0%}]"

def accumulate(self):
def accumulate(self) -> None:
"""
Accumulates the quantity transferred by the previous transfer and
the current transfer.
Expand Down Expand Up @@ -151,7 +194,7 @@ def __init__(self, config: t.Mapping[str, t.Any]):
self.initialize_home_dir()
self.logger.info(f"Connection created with host = {self.host} and username = {self.username}")

def _init_public_key(self, key_file_name, key_password):
def _init_public_key(self, key_file_name: str, key_password: str) -> bool:
"""Initialises self.private_key

Args:
Expand Down Expand Up @@ -234,7 +277,7 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
self.logger.exception(f"paramiko.AuthenticationException: {paramiko.AuthenticationException}")
raise ConnectionFailedException(self.host, self.port, self.username) from e
except paramiko.SSHException as e:
self.logger.exception(f"paramiko.SSHException: {paramiko.SSHException}")
self.logger.exception(f"Paramiko SSH Exception: {e!r}")
raise ConnectionFailedException(self.host, self.port, self.username) from e
except socket.timeout as e:
self.logger.exception(f"socket.timeout: {socket.timeout}")
Expand All @@ -247,47 +290,75 @@ def ssh_client(self) -> t.Generator[paramiko.SSHClient, None, None]:
finally:
client.close()

def execute_command(self, command: str):
"""Executes a command on the remote host. Puts stderr and stdout in
self.ssh_error and self.ssh_output respectively
def execute_command(self, command: str) -> t.Tuple[t.Optional[str], str]:
"""
Runs an SSH command with a retry logic.

If it encounters an SSH Exception, it's going to sleep for 5 seconds.
The command will then be re-executed a maximum of 5 times.
It allows us to wait for the connection to be re-established.
This way, we avoid having a simulation failure due to an SSH error.

Args:
command: String containing the command that will be executed through the ssh connection

Returns:
output: The standard output of the command

error: The standard error of the command
"""
output = None

try:
with self.ssh_client() as client:
# fmt: off
self.logger.info(f"Running SSH command [{command}]...")
stdin, stdout, stderr = client.exec_command(command, timeout=30)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
# fmt: on
output, error = self._exec_command(command)
except socket.timeout:
error = f"SSH command timed out: [{command}]"
self.logger.error(error)
except paramiko.SSHException as e:
error = f"SSH command failed to execute [{command}]: {e}"
self.logger.error(error)
except ConnectionFailedException as e:
error = f"SSH connection failed: {e}"

if error:
self.logger.error(error)

return output, error

@retry(
socket.timeout,
paramiko.SSHException,
ConnectionFailedException,
delay_sec=5,
max_retry=5,
msg_fmt=(
"An SSH Error occurred, so the command did not succeed."
" The command will be re-executed {remaining} times until it succeeds."
" Retrying in {delay_sec} seconds..."
),
)
def _exec_command(self, command: str) -> t.Tuple[str, str]:
"""
Executes a command on the remote host.

Args:
command: String containing the command that will be executed through the ssh connection

Returns:
output: The standard output of the command
error: The standard error of the command
"""
with self.ssh_client() as client:
self.logger.info(f"Running SSH command [{command}]...")
_, stdout, stderr = client.exec_command(command, timeout=30)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
self.logger.info(f"SSH command stderr:\n{textwrap.indent(error, 'SSH ERROR> ')}")
return output, error

def upload_file(self, src: str, dst: str):
"""Uploads a file to a remote server via sftp protocol

Args:
src: Local file to upload

dst: Remote directory where the file will be uploaded

Returns:
Expand All @@ -300,18 +371,18 @@ def upload_file(self, src: str, dst: str):
sftp_client = client.open_sftp()
sftp_client.put(src, dst)
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.debug(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except IOError:
self.logger.debug("IO Error", exc_info=True)
except IOError as e:
self.logger.debug(f"IO Error: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def download_file(self, src: str, dst: str):
def download_file(self, src: str, dst: str) -> bool:
"""Downloads a file from a remote server via sftp protocol

Args:
Expand All @@ -329,21 +400,21 @@ def download_file(self, src: str, dst: str):
sftp_client.get(src, dst)
sftp_client.close()
result_flag = True
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.error(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def download_files(
self,
src_dir: RemotePath,
dst_dir: LocalPath,
pattern: str,
*patterns: str,
remove: bool = True,
self,
src_dir: RemotePath,
dst_dir: LocalPath,
pattern: str,
*patterns: str,
remove: bool = True,
) -> t.Sequence[LocalPath]:
"""
Download files matching the specified patterns from the remote
Expand All @@ -369,20 +440,20 @@ def download_files(
except TimeoutError as exc:
self.logger.error(f"Timeout: {exc}", exc_info=True)
return []
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
except paramiko.SSHException as e:
self.logger.error(f"Paramiko SSH Exception: {e!r}", exc_info=True)
return []
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
return []

def _download_files(
self,
src_dir: RemotePath,
dst_dir: LocalPath,
patterns: t.Tuple[str, ...],
*,
remove: bool = True,
self,
src_dir: RemotePath,
dst_dir: LocalPath,
patterns: t.Tuple[str, ...],
*,
remove: bool = True,
) -> t.Sequence[LocalPath]:
"""
Download files matching the specified patterns from the remote
Expand Down Expand Up @@ -447,12 +518,12 @@ def check_remote_dir_exists(self, dir_path):
if stat.S_ISDIR(sftp_stat.st_mode):
result_flag = True
else:
raise IOError
raise IOError(f"Not a directory: '{dir_path}'")
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -482,8 +553,8 @@ def check_file_not_empty(self, file_path):
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -512,11 +583,11 @@ def make_dir(self, dir_path):
result_flag = True
finally:
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSHException", exc_info=True)
except paramiko.SSHException as e:
self.logger.debug(f"Paramiko SSH Exception: {e!r}", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -548,8 +619,8 @@ def remove_file(self, file_path):
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -581,15 +652,15 @@ def remove_dir(self, dir_path):
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
result_flag = False
return result_flag

def test_connection(self):
def test_connection(self) -> bool:
try:
with self.ssh_client():
return True
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
except ConnectionFailedException as e:
self.logger.error(f"Failed to connect to remote host: {e!r}", exc_info=True)
return False
2 changes: 1 addition & 1 deletion antareslauncher/use_cases/retrieve/state_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run(self, study: StudyDTO) -> None:
Args:
study: The study data transfer object
"""
if not study.done:
if not study.done and not study.with_error:
# set current study job state flags
if study.job_id:
s, f, e = self._env.get_job_state_flags(study)
Expand Down
Loading
Loading