Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/1646 refactoring request for study function #1669

Merged
merged 13 commits into from
Jul 26, 2023
Merged
24 changes: 3 additions & 21 deletions antarest/study/common/studystorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,27 +244,6 @@ def export_output(self, metadata: T, output_id: str, target: Path) -> None:
"""
raise NotImplementedError()

@abstractmethod
def export_study_flat(
self,
metadata: T,
dst_path: Path,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
denormalize: bool = True,
) -> None:
"""
Export study to destination

Args:
metadata: study.
dst_path: destination path.
outputs: list of outputs to keep.
output_list_filter: list of outputs to keep (None indicate all outputs).
denormalize: denormalize the study (replace matrix links by real matrices).
"""
raise NotImplementedError()

@abstractmethod
def get_synthesis(
self, metadata: T, params: Optional[RequestParameters] = None
Expand Down Expand Up @@ -292,3 +271,6 @@ def unarchive_study_output(
self, study: T, output_id: str, keep_src_zip: bool
) -> bool:
raise NotImplementedError()

def unarchive(self, study: T) -> None:
raise NotImplementedError()
47 changes: 41 additions & 6 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import json
import logging
import shutil
import os
from datetime import datetime, timedelta
from http import HTTPStatus
Expand Down Expand Up @@ -165,6 +166,7 @@
remove_from_cache,
study_matcher,
)
from antarest.study.storage.abstract_storage_service import export_study_flat
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.command.replace_matrix import (
ReplaceMatrix,
Expand Down Expand Up @@ -1089,9 +1091,23 @@ def export_study(
def export_task(notifier: TaskUpdateNotifier) -> TaskResult:
try:
target_study = self.get_study(uuid)
self.storage_service.get_storage(target_study).export_study(
target_study, export_path, outputs
)
if isinstance(target_study, RawStudy):
if target_study.archived:
self.storage_service.get_storage(
target_study
).unarchive(target_study)
try:
self.storage_service.get_storage(
target_study
).export_study(target_study, export_path, outputs)
finally:
if target_study.archived:
shutil.rmtree(target_study.path)
else:
self.storage_service.get_storage(
target_study
).export_study(target_study, export_path, outputs)

self.file_transfer_manager.set_ready(export_id)
return TaskResult(
success=True, message=f"Study {uuid} successfully exported"
Expand Down Expand Up @@ -1201,9 +1217,28 @@ def export_study_flat(
study = self.get_study(uuid)
assert_permission(params.user, study, StudyPermissionType.READ)
self._assert_study_unarchived(study)

return self.storage_service.get_storage(study).export_study_flat(
study, dest, len(output_list or []) > 0, output_list
path_study = Path(study.path)
if isinstance(study, RawStudy):
if study.archived:
self.storage_service.get_storage(study).unarchive(study)
try:
return export_study_flat(
path_study=path_study,
dest=dest,
outputs=len(output_list or []) > 0,
output_list_filter=output_list,
)
finally:
if study.archived:
shutil.rmtree(study.path)
snapshot_path = path_study / "snapshot"
output_src_path = path_study / "output"
export_study_flat(
path_study=snapshot_path,
dest=dest,
outputs=len(output_list or []) > 0,
output_list_filter=output_list,
output_src_path=output_src_path,
)

def delete_study(
Expand Down
70 changes: 69 additions & 1 deletion antarest/study/storage/abstract_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from pathlib import Path
from typing import List, Union, Optional, IO
from uuid import uuid4
import time
from zipfile import ZipFile
import os


from antarest.core.config import Config
from antarest.core.exceptions import BadOutputError, StudyOutputNotFoundError
Expand Down Expand Up @@ -39,6 +43,7 @@
StudyFactory,
FileStudy,
)
from antarest.study.model import RawStudy
from antarest.study.storage.rawstudy.model.helpers import FileStudyHelpers
from antarest.study.storage.utils import (
fix_study_root,
Expand All @@ -49,6 +54,60 @@
logger = logging.getLogger(__name__)


def export_study_flat(
path_study: Path,
dest: Path,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
output_src_path: Optional[Path] = None,
) -> None:
"""
Export study to destination

Args:
path_study: Study source path
dest: Destination path.
outputs: List of outputs to keep.
output_list_filter: List of outputs to keep (None indicate all outputs).
output_src_path: Denormalize the study (replace matrix links by real matrices).

"""
start_time = time.time()
output_src_path = output_src_path or path_study / "output"
output_dest_path = dest / "output"
ignore_patterns = (
lambda directory, contents: ["output"]
if str(directory) == str(path_study)
else []
)

shutil.copytree(src=path_study, dst=dest, ignore=ignore_patterns)
if outputs and output_src_path.is_dir():
if output_dest_path.exists():
shutil.rmtree(output_dest_path)
if output_list_filter is not None:
os.mkdir(output_dest_path)
for output in output_list_filter:
zip_path = output_src_path / f"{output}.zip"
if zip_path.exists():
with ZipFile(zip_path) as zf:
zf.extractall(output_dest_path / output)
else:
shutil.copytree(
src=output_src_path / output,
dst=output_dest_path / output,
)
else:
shutil.copytree(
src=output_src_path,
dst=output_dest_path,
)

stop_time = time.time()
duration = "{:.3f}".format(stop_time - start_time)
logger.info(f"Study {path_study} exported (flat mode) in {duration}s")


class AbstractStorageService(IStudyStorageService[T], ABC):
def __init__(
self,
Expand Down Expand Up @@ -272,7 +331,16 @@ def export_study(
logger.info(f"Exporting study {metadata.id} to tmp path {tmpdir}")
assert_this(target.name.endswith(".zip"))
tmp_study_path = Path(tmpdir) / "tmp_copy"
self.export_study_flat(metadata, tmp_study_path, outputs)
if not isinstance(metadata, RawStudy):
snapshot_path = path_study / "snapshot"
output_src_path = path_study / "output"
export_study_flat(
path_study=snapshot_path,
dest=tmp_study_path,
outputs=outputs,
output_src_path=output_src_path,
)
export_study_flat(path_study, tmp_study_path, outputs)
stopwatch = StopWatch()
zip_dir(tmp_study_path, target)
stopwatch.log_elapsed(
Expand Down
26 changes: 0 additions & 26 deletions antarest/study/storage/rawstudy/raw_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
is_managed,
remove_from_cache,
create_new_empty_study,
export_study_flat,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -362,31 +361,6 @@ def import_study(self, metadata: RawStudy, stream: IO[bytes]) -> Study:
metadata.path = str(path_study)
return metadata

def export_study_flat(
self,
metadata: RawStudy,
dst_path: Path,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
denormalize: bool = True,
) -> None:
path_study = Path(metadata.path)

if metadata.archived:
self.unarchive(metadata)
try:
export_study_flat(
path_study,
dst_path,
self.study_factory,
outputs,
output_list_filter,
denormalize,
)
finally:
if metadata.archived:
shutil.rmtree(metadata.path)

def check_errors(
self,
metadata: RawStudy,
Expand Down
52 changes: 0 additions & 52 deletions antarest/study/storage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,55 +367,3 @@ def get_start_date(
first_week_size=first_week_size,
level=level,
)


def export_study_flat(
path_study: Path,
dest: Path,
study_factory: StudyFactory,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
denormalize: bool = True,
output_src_path: Optional[Path] = None,
) -> None:
start_time = time.time()

output_src_path = output_src_path or path_study / "output"
output_dest_path = dest / "output"
ignore_patterns = (
lambda directory, contents: ["output"]
if str(directory) == str(path_study)
else []
)

shutil.copytree(src=path_study, dst=dest, ignore=ignore_patterns)

if outputs and output_src_path.is_dir():
if output_dest_path.is_dir():
shutil.rmtree(output_dest_path)
if output_list_filter is not None:
os.mkdir(output_dest_path)
for output in output_list_filter:
zip_path = output_src_path / f"{output}.zip"
if zip_path.exists():
with ZipFile(zip_path) as zf:
zf.extractall(output_dest_path / output)
else:
shutil.copytree(
src=output_src_path / output,
dst=output_dest_path / output,
)
else:
shutil.copytree(
src=output_src_path,
dst=output_dest_path,
)

stop_time = time.time()
duration = "{:.3f}".format(stop_time - start_time)
logger.info(f"Study {path_study} exported (flat mode) in {duration}s")
study = study_factory.create_from_fs(dest, "", use_cache=False)
if denormalize:
study.tree.denormalize()
duration = "{:.3f}".format(time.time() - stop_time)
logger.info(f"Study {path_study} denormalized in {duration}s")
Loading