Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(zip extractor): add support for -z option #67

Merged
merged 10 commits into from
Nov 21, 2023
42 changes: 24 additions & 18 deletions antareslauncher/remote_environnement/ssh_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
import stat
import textwrap
import time
from pathlib import Path, PurePosixPath
import typing as t
from pathlib import Path, PurePosixPath

import paramiko

RemotePath = PurePosixPath
LocalPath = Path

PARAMIKO_SSH_ERROR = "Paramiko SSH Exception"
REMOTE_CONNECTION_ERROR = "Failed to connect to remote host"
IO_ERROR = "IO Error"
FILE_NOT_FOUND_ERROR = "File not found error"
DIRECTORY_NOT_FOUND_ERROR = "Directory not found error"


class SshConnectionError(Exception):
"""
Expand Down Expand Up @@ -264,7 +270,7 @@ def execute_command(self, command: str):
with self.ssh_client() as client:
# fmt: off
self.logger.info(f"Running SSH command [{command}]...")
stdin, stdout, stderr = client.exec_command(command, timeout=30)
_, stdout, stderr = client.exec_command(command, timeout=30)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
Expand Down Expand Up @@ -301,13 +307,13 @@ def upload_file(self, src: str, dst: str):
sftp_client.put(src, dst)
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSH Exception", exc_info=True)
self.logger.debug(PARAMIKO_SSH_ERROR, exc_info=True)
result_flag = False
except IOError:
self.logger.debug("IO Error", exc_info=True)
self.logger.debug(IO_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand All @@ -330,10 +336,10 @@ def download_file(self, src: str, dst: str):
sftp_client.close()
result_flag = True
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
self.logger.error(PARAMIKO_SSH_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -370,10 +376,10 @@ def download_files(
self.logger.error(f"Timeout: {exc}", exc_info=True)
return []
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
self.logger.error(PARAMIKO_SSH_ERROR, exc_info=True)
return []
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
return []

def _download_files(
Expand Down Expand Up @@ -449,10 +455,10 @@ def check_remote_dir_exists(self, dir_path):
else:
raise IOError
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
self.logger.debug(FILE_NOT_FOUND_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -480,10 +486,10 @@ def check_file_not_empty(self, file_path):
else:
raise IOError(f"Not a regular file: '{file_path}'")
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
self.logger.debug(FILE_NOT_FOUND_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -516,7 +522,7 @@ def make_dir(self, dir_path):
self.logger.debug("Paramiko SSHException", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -549,7 +555,7 @@ def remove_file(self, file_path):
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -577,12 +583,12 @@ def remove_dir(self, dir_path):
sftp_client.rmdir(dir_path)
result_flag = True
except FileNotFoundError:
self.logger.debug("DirNotFound nothing to remove", exc_info=True)
self.logger.debug(DIRECTORY_NOT_FOUND_ERROR, exc_info=True)
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand All @@ -591,5 +597,5 @@ def test_connection(self):
with self.ssh_client():
return True
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
return False
77 changes: 44 additions & 33 deletions antareslauncher/use_cases/retrieve/final_zip_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,54 @@ def extract_final_zip(self, study: StudyDTO) -> None:
Args:
study: The current study
"""
if study.finished and not study.with_error and study.local_final_zipfile_path and not study.final_zip_extracted:
zip_path = Path(study.local_final_zipfile_path)
try:
with zipfile.ZipFile(zip_path) as zf:
names = zf.namelist()
if len(names) > 1 and os.path.commonpath(names):
# If all files are in the same directory, we can extract the ZIP
# file directly in the target directory.
target_dir = zip_path.parent
else:
# Otherwise, we need to create a directory to store the results.
# This situation occurs when the ZIP file contains
# only the simulation results and not the entire study.
target_dir = zip_path.with_suffix("")
if not study.finished or study.with_error or not study.local_final_zipfile_path or study.final_zip_extracted:
return
zip_path = Path(study.local_final_zipfile_path)
try:
# First, we detect the ZIP layout by looking at the names of the files it contains.
with zipfile.ZipFile(zip_path) as zf:
names = zf.namelist()
file_count = len(names)
has_unique_folder = file_count > 1 and os.path.commonpath(names)

if has_unique_folder:
# If the ZIP file contains a unique folder, it contains the whole study.
# We can extract it directly in the target directory.
with zipfile.ZipFile(zip_path) as zf:
target_dir = zip_path.parent
progress_bar = self._display.generate_progress_bar(
names, desc="Extracting archive:", total=len(names)
names, desc="Extracting archive:", total=file_count
)
for file in progress_bar:
zf.extract(member=file, path=target_dir)

except (OSError, zipfile.BadZipFile) as exc:
# If we cannot extract the final ZIP file, either because the file
# doesn't exist or the ZIP file is corrupted, we find ourselves
# in a situation where the results are unusable.
# In such cases, it's best to consider the simulation as failed,
# enabling the user to restart their simulation.
study.final_zip_extracted = False
study.with_error = True
self._display.show_error(
f'"{study.name}": Final zip not extracted: {exc}',
LOG_NAME,
)

else:
study.final_zip_extracted = True
self._display.show_message(
f'"{study.name}": Final zip extracted',
LOG_NAME,
)
# The directory is already an output and does not need to be unzipped.
# All we have to do is rename it by removing the prefix "finished_"
# and the suffix "_{job_id}" that lies before the ".zip".
# e.g.: "finished_Foo-Study_123456.zip" -> "Foo-Study.zip".
# or: "finished_XPANSION_Foo-Study_123456.zip" -> "Foo-Study.zip".
laurent-laporte-pro marked this conversation as resolved.
Show resolved Hide resolved
new_name = zip_path.name.lstrip("finished_")
new_name = new_name.lstrip("XPANSION_")
new_name = new_name.split("_", 1)[0] + ".zip"
zip_path.rename(zip_path.parent / new_name)

except (OSError, zipfile.BadZipFile) as exc:
# If we cannot extract the final ZIP file, either because the file
# doesn't exist or the ZIP file is corrupted, we find ourselves
# in a situation where the results are unusable.
# In such cases, it's best to consider the simulation as failed,
# enabling the user to restart their simulation.
study.final_zip_extracted = False
study.with_error = True
self._display.show_error(
f'"{study.name}": Final zip not extracted: {exc}',
LOG_NAME,
)

else:
study.final_zip_extracted = True
self._display.show_message(
f'"{study.name}": Final zip extracted',
LOG_NAME,
)
34 changes: 23 additions & 11 deletions tests/unit/retriever/test_final_zip_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ def create_final_zip(study: StudyDTO, *, scenario: str = "nominal_study") -> str
"""Prepare a final ZIP."""
dst_dir = Path(study.output_dir) # must exist
dst_dir.mkdir(parents=True, exist_ok=True)
out_path = dst_dir.joinpath(f"finished_{study.name}_{study.job_id}.zip")
if scenario == "nominal_study":
if "xpansion" in scenario:
out_path = dst_dir.joinpath(f"finished_XPANSION_{study.name}_{study.job_id}.zip")
else:
out_path = dst_dir.joinpath(f"finished_{study.name}_{study.job_id}.zip")
if scenario in {"nominal_study", "xpansion_study"}:
# Case where the ZIP contains all the study files.
with zipfile.ZipFile(
out_path,
mode="w",
Expand All @@ -29,16 +33,19 @@ def create_final_zip(study: StudyDTO, *, scenario: str = "nominal_study") -> str
f"{study.name}/output/20230922-1601eco/simulation.log",
data=b"Simulation OK",
)
elif scenario == "nominal_results":
elif scenario in {"nominal_results", "xpansion_results"}:
# Case where the ZIP contains only the results.
with zipfile.ZipFile(
out_path,
mode="w",
compression=zipfile.ZIP_DEFLATED,
) as zf:
zf.writestr("simulation.log", data=b"Simulation OK")
elif scenario == "corrupted":
# Case where the ZIP is corrupted.
out_path.write_bytes(b"PK corrupted content")
elif scenario == "missing":
# Case where the ZIP is missing.
pass
else:
raise NotImplementedError(scenario)
Expand Down Expand Up @@ -86,12 +93,13 @@ def test_extract_final_zip__finished_study__no_output(self, finished_study: Stud
assert not finished_study.final_zip_extracted

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__nominal_study(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_study", "xpansion_study"])
def test_extract_final_zip__finished_study__nominal_study(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario="nominal_study")
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction
extractor = FinalZipExtractor(display=display)
Expand All @@ -113,12 +121,13 @@ def test_extract_final_zip__finished_study__nominal_study(self, finished_study:
assert result_dir.joinpath(file).is_file()

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__nominal_results(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_results", "xpansion_results"])
def test_extract_final_zip__finished_study__nominal_results(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario="nominal_results")
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction
extractor = FinalZipExtractor(display=display)
Expand All @@ -131,16 +140,19 @@ def test_extract_final_zip__finished_study__nominal_results(self, finished_study
assert finished_study.final_zip_extracted
assert not finished_study.with_error

result_dir = Path(finished_study.local_final_zipfile_path).with_suffix("")
assert result_dir.joinpath("simulation.log").is_file()
result_dir = (Path(finished_study.local_final_zipfile_path).parent / finished_study.name).with_suffix(".zip")
assert result_dir.exists()
with zipfile.ZipFile(result_dir, "r") as zf:
assert zf.namelist() == ["simulation.log"]

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__reentrancy(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_study", "xpansion_study"])
def test_extract_final_zip__finished_study__reentrancy(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study)
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction twice
extractor = FinalZipExtractor(display=display)
Expand Down