Skip to content

Commit

Permalink
style: simplify implementation of RawStudyService and `VariantStudy…
Browse files Browse the repository at this point in the history
…Service`
  • Loading branch information
laurent-laporte-pro committed Feb 23, 2024
1 parent 8899fd7 commit 5ab7bfc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 50 deletions.
20 changes: 10 additions & 10 deletions antarest/study/storage/rawstudy/raw_study_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import shutil
import time
import typing as t
from datetime import datetime
from pathlib import Path
from threading import Thread
from typing import BinaryIO, List, Optional, Sequence
from uuid import uuid4
from zipfile import ZipFile

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
)
self.cleanup_thread.start()

def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: Optional[bool] = False) -> None:
def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: t.Optional[bool] = False) -> None:
"""
Update metadata from study raw metadata
Args:
Expand Down Expand Up @@ -90,7 +90,7 @@ def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: Optional
metadata.version = metadata.version or 0
metadata.created_at = metadata.created_at or datetime.utcnow()
metadata.updated_at = metadata.updated_at or datetime.utcnow()
if not metadata.additional_data:
if metadata.additional_data is None:
metadata.additional_data = StudyAdditionalData()
metadata.additional_data.patch = metadata.additional_data.patch or Patch().json()
metadata.additional_data.author = metadata.additional_data.author or "Unknown"
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_raw(
self,
metadata: RawStudy,
use_cache: bool = True,
output_dir: Optional[Path] = None,
output_dir: t.Optional[Path] = None,
) -> FileStudy:
"""
Fetch a study object and its config
Expand All @@ -163,7 +163,7 @@ def get_raw(
study_path = self.get_study_path(metadata)
return self.study_factory.create_from_fs(study_path, metadata.id, output_dir, use_cache=use_cache)

def get_synthesis(self, metadata: RawStudy, params: Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO:
def get_synthesis(self, metadata: RawStudy, params: t.Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO:
self._check_study_exists(metadata)
study_path = self.get_study_path(metadata)
study = self.study_factory.create_from_fs(study_path, metadata.id)
Expand Down Expand Up @@ -206,7 +206,7 @@ def copy(
self,
src_meta: RawStudy,
dest_name: str,
groups: Sequence[str],
groups: t.Sequence[str],
with_outputs: bool = False,
) -> RawStudy:
"""
Expand All @@ -223,7 +223,7 @@ def copy(
"""
self._check_study_exists(src_meta)

if not src_meta.additional_data:
if src_meta.additional_data is None:
additional_data = StudyAdditionalData()
else:
additional_data = StudyAdditionalData(
Expand Down Expand Up @@ -295,7 +295,7 @@ def delete_output(self, metadata: RawStudy, output_name: str) -> None:
output_path.unlink(missing_ok=True)
remove_from_cache(self.cache, metadata.id)

def import_study(self, metadata: RawStudy, stream: BinaryIO) -> Study:
def import_study(self, metadata: RawStudy, stream: t.BinaryIO) -> Study:
"""
Import study in the directory of the study.
Expand Down Expand Up @@ -329,7 +329,7 @@ def export_study_flat(
metadata: RawStudy,
dst_path: Path,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
output_list_filter: t.Optional[t.List[str]] = None,
denormalize: bool = True,
) -> None:
try:
Expand All @@ -352,7 +352,7 @@ def export_study_flat(
def check_errors(
self,
metadata: RawStudy,
) -> List[str]:
) -> t.List[str]:
"""
Check study antares data integrity
Args:
Expand Down
77 changes: 37 additions & 40 deletions antarest/study/storage/variantstudy/variant_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
import re
import shutil
import typing as t
from datetime import datetime
from functools import reduce
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple, cast
from uuid import uuid4

from fastapi import HTTPException
Expand Down Expand Up @@ -101,11 +101,11 @@ def get_command(self, study_id: str, command_id: str, params: RequestParameters)

try:
index = [command.id for command in study.commands].index(command_id) # Maybe add Try catch for this
return cast(CommandDTO, study.commands[index].to_dto())
return t.cast(CommandDTO, study.commands[index].to_dto())
except ValueError:
raise CommandNotFoundError(f"Command with id {command_id} not found") from None

def get_commands(self, study_id: str, params: RequestParameters) -> List[CommandDTO]:
def get_commands(self, study_id: str, params: RequestParameters) -> t.List[CommandDTO]:
"""
Get command lists
Args:
Expand All @@ -116,8 +116,8 @@ def get_commands(self, study_id: str, params: RequestParameters) -> List[Command
study = self._get_variant_study(study_id, params)
return [command.to_dto() for command in study.commands]

def _check_commands_validity(self, study_id: str, commands: List[CommandDTO]) -> List[ICommand]:
command_objects: List[ICommand] = []
def _check_commands_validity(self, study_id: str, commands: t.List[CommandDTO]) -> t.List[ICommand]:
command_objects: t.List[ICommand] = []
for i, command in enumerate(commands):
try:
command_objects.extend(self.command_factory.to_command(command))
Expand Down Expand Up @@ -157,9 +157,9 @@ def append_command(self, study_id: str, command: CommandDTO, params: RequestPara
def append_commands(
self,
study_id: str,
commands: List[CommandDTO],
commands: t.List[CommandDTO],
params: RequestParameters,
) -> List[str]:
) -> t.List[str]:
"""
Add command to list of commands (at the end)
Args:
Expand Down Expand Up @@ -196,7 +196,7 @@ def append_commands(
def replace_commands(
self,
study_id: str,
commands: List[CommandDTO],
commands: t.List[CommandDTO],
params: RequestParameters,
) -> str:
"""
Expand Down Expand Up @@ -320,13 +320,13 @@ def export_commands_matrices(self, study_id: str, params: RequestParameters) ->
lambda: reduce(
lambda m, c: m + c.get_inner_matrices(),
self.command_factory.to_command(command.to_dto()),
cast(List[str], []),
t.cast(t.List[str], []),
),
lambda e: logger.warning(f"Failed to parse command {command}", exc_info=e),
)
or []
}
return cast(MatrixService, self.command_factory.command_context.matrix_service).download_matrix_list(
return t.cast(MatrixService, self.command_factory.command_context.matrix_service).download_matrix_list(
list(matrices), f"{study.name}_{study.id}_matrices", params
)

Expand Down Expand Up @@ -410,7 +410,7 @@ def get_all_variants_children(
def walk_children(
self,
parent_id: str,
fun: Callable[[VariantStudy], None],
fun: t.Callable[[VariantStudy], None],
bottom_first: bool,
) -> None:
study = self._get_variant_study(
Expand All @@ -426,13 +426,13 @@ def walk_children(
if bottom_first:
fun(study)

def get_variants_parents(self, id: str, params: RequestParameters) -> List[StudyMetadataDTO]:
output_list: List[StudyMetadataDTO] = self._get_variants_parents(id, params)
def get_variants_parents(self, id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]:
output_list: t.List[StudyMetadataDTO] = self._get_variants_parents(id, params)
if output_list:
output_list = output_list[1:]
return output_list

def get_direct_parent(self, id: str, params: RequestParameters) -> Optional[StudyMetadataDTO]:
def get_direct_parent(self, id: str, params: RequestParameters) -> t.Optional[StudyMetadataDTO]:
study = self._get_variant_study(id, params, raw_study_accepted=True)
if study.parent_id is not None:
parent = self._get_variant_study(study.parent_id, params, raw_study_accepted=True)
Expand All @@ -447,7 +447,7 @@ def get_direct_parent(self, id: str, params: RequestParameters) -> Optional[Stud
)
return None

def _get_variants_parents(self, id: str, params: RequestParameters) -> List[StudyMetadataDTO]:
def _get_variants_parents(self, id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]:
study = self._get_variant_study(id, params, raw_study_accepted=True)
metadata = (
self.get_study_information(
Expand All @@ -458,7 +458,7 @@ def _get_variants_parents(self, id: str, params: RequestParameters) -> List[Stud
study,
)
)
output_list: List[StudyMetadataDTO] = [metadata]
output_list: t.List[StudyMetadataDTO] = [metadata]
if study.parent_id is not None:
output_list.extend(
self._get_variants_parents(
Expand Down Expand Up @@ -530,16 +530,15 @@ def create_variant_study(self, uuid: str, name: str, params: RequestParameters)
assert_permission(params.user, study, StudyPermissionType.READ)
new_id = str(uuid4())
study_path = str(self.config.get_workspace_path() / new_id)
if study.additional_data:
# noinspection PyArgumentList
if study.additional_data is None:
additional_data = StudyAdditionalData()
else:
additional_data = StudyAdditionalData(
horizon=study.additional_data.horizon,
author=study.additional_data.author,
patch=study.additional_data.patch,
)
else:
additional_data = StudyAdditionalData()
# noinspection PyArgumentList

variant_study = VariantStudy(
id=new_id,
name=name,
Expand Down Expand Up @@ -653,7 +652,7 @@ def generate_study_config(
self,
variant_study_id: str,
params: RequestParameters,
) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
# Get variant study
variant_study = self._get_variant_study(variant_study_id, params)

Expand All @@ -667,8 +666,8 @@ def _generate_study_config(
self,
original_study: VariantStudy,
metadata: VariantStudy,
config: Optional[FileStudyTreeConfig],
) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
config: t.Optional[FileStudyTreeConfig],
) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
parent_study = self.repository.get(metadata.parent_id)
if parent_study is None:
raise StudyNotFoundError(metadata.parent_id)
Expand Down Expand Up @@ -698,9 +697,9 @@ def _get_commands_and_notifier(
variant_study: VariantStudy,
notifier: TaskUpdateNotifier,
from_index: int = 0,
) -> Tuple[List[List[ICommand]], Callable[[int, bool, str], None]]:
) -> t.Tuple[t.List[t.List[ICommand]], t.Callable[[int, bool, str], None]]:
# Generate
commands: List[List[ICommand]] = self._to_commands(variant_study, from_index)
commands: t.List[t.List[ICommand]] = self._to_commands(variant_study, from_index)

def notify(command_index: int, command_result: bool, command_message: str) -> None:
try:
Expand All @@ -727,8 +726,8 @@ def notify(command_index: int, command_result: bool, command_message: str) -> No

return commands, notify

def _to_commands(self, metadata: VariantStudy, from_index: int = 0) -> List[List[ICommand]]:
commands: List[List[ICommand]] = [
def _to_commands(self, metadata: VariantStudy, from_index: int = 0) -> t.List[t.List[ICommand]]:
commands: t.List[t.List[ICommand]] = [
self.command_factory.to_command(command_block.to_dto())
for index, command_block in enumerate(metadata.commands)
if from_index <= index
Expand All @@ -740,7 +739,7 @@ def _generate_config(
variant_study: VariantStudy,
config: FileStudyTreeConfig,
notifier: TaskUpdateNotifier = noop_notifier,
) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]:
commands, notify = self._get_commands_and_notifier(variant_study=variant_study, notifier=notifier)
return self.generator.generate_config(commands, config, variant_study, notifier=notify)

Expand Down Expand Up @@ -809,7 +808,7 @@ def copy(
self,
src_meta: VariantStudy,
dest_name: str,
groups: Sequence[str],
groups: t.Sequence[str],
with_outputs: bool = False,
) -> VariantStudy:
"""
Expand All @@ -826,16 +825,14 @@ def copy(
"""
new_id = str(uuid4())
study_path = str(self.config.get_workspace_path() / new_id)
if src_meta.additional_data:
# noinspection PyArgumentList
if src_meta.additional_data is None:
additional_data = StudyAdditionalData()
else:
additional_data = StudyAdditionalData(
horizon=src_meta.additional_data.horizon,
author=src_meta.additional_data.author,
patch=src_meta.additional_data.patch,
)
else:
additional_data = StudyAdditionalData()
# noinspection PyArgumentList
dst_meta = VariantStudy(
id=new_id,
name=dest_name,
Expand Down Expand Up @@ -893,7 +890,7 @@ def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_
@staticmethod
def _get_snapshot_last_executed_command_index(
study: VariantStudy,
) -> Optional[int]:
) -> t.Optional[int]:
if study.snapshot and study.snapshot.last_executed_command:
last_executed_command_index = [command.id for command in study.commands].index(
study.snapshot.last_executed_command
Expand All @@ -905,7 +902,7 @@ def get_raw(
self,
metadata: VariantStudy,
use_cache: bool = True,
output_dir: Optional[Path] = None,
output_dir: t.Optional[Path] = None,
) -> FileStudy:
"""
Fetch a study raw tree object and its config
Expand All @@ -925,7 +922,7 @@ def get_raw(
use_cache=use_cache,
)

def get_study_sim_result(self, study: VariantStudy) -> List[StudySimResultDTO]:
def get_study_sim_result(self, study: VariantStudy) -> t.List[StudySimResultDTO]:
"""
Get global result information
Args:
Expand Down Expand Up @@ -988,7 +985,7 @@ def export_study_flat(
metadata: VariantStudy,
dst_path: Path,
outputs: bool = True,
output_list_filter: Optional[List[str]] = None,
output_list_filter: t.Optional[t.List[str]] = None,
denormalize: bool = True,
) -> None:
self._safe_generation(metadata)
Expand All @@ -1009,7 +1006,7 @@ def export_study_flat(
def get_synthesis(
self,
metadata: VariantStudy,
params: Optional[RequestParameters] = None,
params: t.Optional[RequestParameters] = None,
) -> FileStudyTreeConfigDTO:
"""
Return study synthesis
Expand Down

0 comments on commit 5ab7bfc

Please sign in to comment.