Skip to content

Commit

Permalink
feat(zip extractor): add support for -z option (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle authored Nov 21, 2023
1 parent 516f94d commit 6a22daa
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 62 deletions.
42 changes: 24 additions & 18 deletions antareslauncher/remote_environnement/ssh_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
import stat
import textwrap
import time
from pathlib import Path, PurePosixPath
import typing as t
from pathlib import Path, PurePosixPath

import paramiko

RemotePath = PurePosixPath
LocalPath = Path

PARAMIKO_SSH_ERROR = "Paramiko SSH Exception"
REMOTE_CONNECTION_ERROR = "Failed to connect to remote host"
IO_ERROR = "IO Error"
FILE_NOT_FOUND_ERROR = "File not found error"
DIRECTORY_NOT_FOUND_ERROR = "Directory not found error"


class SshConnectionError(Exception):
"""
Expand Down Expand Up @@ -264,7 +270,7 @@ def execute_command(self, command: str):
with self.ssh_client() as client:
# fmt: off
self.logger.info(f"Running SSH command [{command}]...")
stdin, stdout, stderr = client.exec_command(command, timeout=30)
_, stdout, stderr = client.exec_command(command, timeout=30)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
self.logger.info(f"SSH command stdout:\n{textwrap.indent(output, 'SSH OUTPUT> ')}")
Expand Down Expand Up @@ -301,13 +307,13 @@ def upload_file(self, src: str, dst: str):
sftp_client.put(src, dst)
sftp_client.close()
except paramiko.SSHException:
self.logger.debug("Paramiko SSH Exception", exc_info=True)
self.logger.debug(PARAMIKO_SSH_ERROR, exc_info=True)
result_flag = False
except IOError:
self.logger.debug("IO Error", exc_info=True)
self.logger.debug(IO_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand All @@ -330,10 +336,10 @@ def download_file(self, src: str, dst: str):
sftp_client.close()
result_flag = True
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
self.logger.error(PARAMIKO_SSH_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -370,10 +376,10 @@ def download_files(
self.logger.error(f"Timeout: {exc}", exc_info=True)
return []
except paramiko.SSHException:
self.logger.error("Paramiko SSH Exception", exc_info=True)
self.logger.error(PARAMIKO_SSH_ERROR, exc_info=True)
return []
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
return []

def _download_files(
Expand Down Expand Up @@ -449,10 +455,10 @@ def check_remote_dir_exists(self, dir_path):
else:
raise IOError
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
self.logger.debug(FILE_NOT_FOUND_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -480,10 +486,10 @@ def check_file_not_empty(self, file_path):
else:
raise IOError(f"Not a regular file: '{file_path}'")
except FileNotFoundError:
self.logger.debug("FileNotFoundError", exc_info=True)
self.logger.debug(FILE_NOT_FOUND_ERROR, exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -516,7 +522,7 @@ def make_dir(self, dir_path):
self.logger.debug("Paramiko SSHException", exc_info=True)
result_flag = False
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -549,7 +555,7 @@ def remove_file(self, file_path):
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand Down Expand Up @@ -577,12 +583,12 @@ def remove_dir(self, dir_path):
sftp_client.rmdir(dir_path)
result_flag = True
except FileNotFoundError:
self.logger.debug("DirNotFound nothing to remove", exc_info=True)
self.logger.debug(DIRECTORY_NOT_FOUND_ERROR, exc_info=True)
result_flag = True
finally:
sftp_client.close()
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
result_flag = False
return result_flag

Expand All @@ -591,5 +597,5 @@ def test_connection(self):
with self.ssh_client():
return True
except ConnectionFailedException:
self.logger.error("Failed to connect to remote host", exc_info=True)
self.logger.error(REMOTE_CONNECTION_ERROR, exc_info=True)
return False
77 changes: 44 additions & 33 deletions antareslauncher/use_cases/retrieve/final_zip_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,54 @@ def extract_final_zip(self, study: StudyDTO) -> None:
Args:
study: The current study
"""
if study.finished and not study.with_error and study.local_final_zipfile_path and not study.final_zip_extracted:
zip_path = Path(study.local_final_zipfile_path)
try:
with zipfile.ZipFile(zip_path) as zf:
names = zf.namelist()
if len(names) > 1 and os.path.commonpath(names):
# If all files are in the same directory, we can extract the ZIP
# file directly in the target directory.
target_dir = zip_path.parent
else:
# Otherwise, we need to create a directory to store the results.
# This situation occurs when the ZIP file contains
# only the simulation results and not the entire study.
target_dir = zip_path.with_suffix("")
if not study.finished or study.with_error or not study.local_final_zipfile_path or study.final_zip_extracted:
return
zip_path = Path(study.local_final_zipfile_path)
try:
# First, we detect the ZIP layout by looking at the names of the files it contains.
with zipfile.ZipFile(zip_path) as zf:
names = zf.namelist()
file_count = len(names)
has_unique_folder = file_count > 1 and os.path.commonpath(names)

if has_unique_folder:
# If the ZIP file contains a unique folder, it contains the whole study.
# We can extract it directly in the target directory.
with zipfile.ZipFile(zip_path) as zf:
target_dir = zip_path.parent
progress_bar = self._display.generate_progress_bar(
names, desc="Extracting archive:", total=len(names)
names, desc="Extracting archive:", total=file_count
)
for file in progress_bar:
zf.extract(member=file, path=target_dir)

except (OSError, zipfile.BadZipFile) as exc:
# If we cannot extract the final ZIP file, either because the file
# doesn't exist or the ZIP file is corrupted, we find ourselves
# in a situation where the results are unusable.
# In such cases, it's best to consider the simulation as failed,
# enabling the user to restart its simulation.
study.final_zip_extracted = False
study.with_error = True
self._display.show_error(
f'"{study.name}": Final zip not extracted: {exc}',
LOG_NAME,
)

else:
study.final_zip_extracted = True
self._display.show_message(
f'"{study.name}": Final zip extracted',
LOG_NAME,
)
# The directory is already an output and does not need to be unzipped.
# All we have to do is rename it by removing the prefix "finished_"
# and the suffix "_{job_id}" that lies before the ".zip".
# e.g.: "finished_Foo-Study_123456.zip" -> "Foo-Study.zip".
# or: "finished_XPANSION_Foo-Study_123456.zip" -> "Foo-Study_123456.zip".
new_name = zip_path.name.lstrip("finished_")
new_name = new_name.lstrip("XPANSION_")
new_name = new_name.split("_", 1)[0] + ".zip"
zip_path.rename(zip_path.parent / new_name)

except (OSError, zipfile.BadZipFile) as exc:
# If we cannot extract the final ZIP file, either because the file
# doesn't exist or the ZIP file is corrupted, we find ourselves
# in a situation where the results are unusable.
# In such cases, it's best to consider the simulation as failed,
# enabling the user to restart its simulation.
study.final_zip_extracted = False
study.with_error = True
self._display.show_error(
f'"{study.name}": Final zip not extracted: {exc}',
LOG_NAME,
)

else:
study.final_zip_extracted = True
self._display.show_message(
f'"{study.name}": Final zip extracted',
LOG_NAME,
)
34 changes: 23 additions & 11 deletions tests/unit/retriever/test_final_zip_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ def create_final_zip(study: StudyDTO, *, scenario: str = "nominal_study") -> str
"""Prepare a final ZIP."""
dst_dir = Path(study.output_dir) # must exist
dst_dir.mkdir(parents=True, exist_ok=True)
out_path = dst_dir.joinpath(f"finished_{study.name}_{study.job_id}.zip")
if scenario == "nominal_study":
if "xpansion" in scenario:
out_path = dst_dir.joinpath(f"finished_XPANSION_{study.name}_{study.job_id}.zip")
else:
out_path = dst_dir.joinpath(f"finished_{study.name}_{study.job_id}.zip")
if scenario in {"nominal_study", "xpansion_study"}:
# Case where the ZIP contains all the study files.
with zipfile.ZipFile(
out_path,
mode="w",
Expand All @@ -29,16 +33,19 @@ def create_final_zip(study: StudyDTO, *, scenario: str = "nominal_study") -> str
f"{study.name}/output/20230922-1601eco/simulation.log",
data=b"Simulation OK",
)
elif scenario == "nominal_results":
elif scenario in {"nominal_results", "xpansion_results"}:
# Case where the ZIP contains only the results.
with zipfile.ZipFile(
out_path,
mode="w",
compression=zipfile.ZIP_DEFLATED,
) as zf:
zf.writestr("simulation.log", data=b"Simulation OK")
elif scenario == "corrupted":
# Case where the ZIP is corrupted.
out_path.write_bytes(b"PK corrupted content")
elif scenario == "missing":
# Case where the ZIP is missing.
pass
else:
raise NotImplementedError(scenario)
Expand Down Expand Up @@ -86,12 +93,13 @@ def test_extract_final_zip__finished_study__no_output(self, finished_study: Stud
assert not finished_study.final_zip_extracted

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__nominal_study(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_study", "xpansion_study"])
def test_extract_final_zip__finished_study__nominal_study(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario="nominal_study")
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction
extractor = FinalZipExtractor(display=display)
Expand All @@ -113,12 +121,13 @@ def test_extract_final_zip__finished_study__nominal_study(self, finished_study:
assert result_dir.joinpath(file).is_file()

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__nominal_results(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_results", "xpansion_results"])
def test_extract_final_zip__finished_study__nominal_results(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario="nominal_results")
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction
extractor = FinalZipExtractor(display=display)
Expand All @@ -131,16 +140,19 @@ def test_extract_final_zip__finished_study__nominal_results(self, finished_study
assert finished_study.final_zip_extracted
assert not finished_study.with_error

result_dir = Path(finished_study.local_final_zipfile_path).with_suffix("")
assert result_dir.joinpath("simulation.log").is_file()
result_dir = (Path(finished_study.local_final_zipfile_path).parent / finished_study.name).with_suffix(".zip")
assert result_dir.exists()
with zipfile.ZipFile(result_dir, "r") as zf:
assert zf.namelist() == ["simulation.log"]

@pytest.mark.unit_test
def test_extract_final_zip__finished_study__reentrancy(self, finished_study: StudyDTO) -> None:
@pytest.mark.parametrize("scenario", ["nominal_study", "xpansion_study"])
def test_extract_final_zip__finished_study__reentrancy(self, finished_study: StudyDTO, scenario: str) -> None:
display = mock.Mock(spec=DisplayTerminal)
display.generate_progress_bar = lambda names, *args, **kwargs: names

# Prepare a valid final ZIP
finished_study.local_final_zipfile_path = create_final_zip(finished_study)
finished_study.local_final_zipfile_path = create_final_zip(finished_study, scenario=scenario)

# Initialize and execute the ZIP extraction twice
extractor = FinalZipExtractor(display=display)
Expand Down

0 comments on commit 6a22daa

Please sign in to comment.