Skip to content

Commit

Permalink
chore(study-filesystem): improve implementation of BucketNode
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro authored and MartinBelthle committed Mar 5, 2024
1 parent 16d6669 commit 6de7e70
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
46 changes: 20 additions & 26 deletions antarest/study/storage/rawstudy/model/filesystem/bucket_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional
import typing as t

from antarest.core.model import JSON, SUB_JSON
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
Expand All @@ -12,7 +12,7 @@ class RegisteredFile:
def __init__(
self,
key: str,
node: Optional[Callable[[ContextServer, FileStudyTreeConfig], INode[Any, Any, Any]]],
node: t.Optional[t.Callable[[ContextServer, FileStudyTreeConfig], INode[t.Any, t.Any, t.Any]]],
filename: str = "",
):
self.key = key
Expand All @@ -29,42 +29,36 @@ def __init__(
self,
context: ContextServer,
config: FileStudyTreeConfig,
registered_files: Optional[List[RegisteredFile]] = None,
default_file_node: Callable[..., INode[Any, Any, Any]] = RawFileNode,
registered_files: t.Optional[t.List[RegisteredFile]] = None,
default_file_node: t.Callable[..., INode[t.Any, t.Any, t.Any]] = RawFileNode,
):
super().__init__(context, config)
self.registered_files: List[RegisteredFile] = registered_files or []
self.default_file_node: Callable[..., INode[Any, Any, Any]] = default_file_node
self.registered_files: t.List[RegisteredFile] = registered_files or []
self.default_file_node: t.Callable[..., INode[t.Any, t.Any, t.Any]] = default_file_node

def _get_registered_file(self, key: str) -> Optional[RegisteredFile]:
for registered_file in self.registered_files:
if registered_file.key == key:
return registered_file
return None
def _get_registered_file_by_key(self, key: str) -> t.Optional[RegisteredFile]:
return next((rf for rf in self.registered_files if rf.key == key), None)

def _get_registered_file_from_filename(self, filename: str) -> Optional[RegisteredFile]:
for registered_file in self.registered_files:
if registered_file.filename == filename:
return registered_file
return None
def _get_registered_file_by_filename(self, filename: str) -> t.Optional[RegisteredFile]:
return next((rf for rf in self.registered_files if rf.filename == filename), None)

def save(
self,
data: SUB_JSON,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
) -> None:
self._assert_not_in_zipped_file()
if not self.config.path.exists():
self.config.path.mkdir()

if url is None or len(url) == 0:
assert isinstance(data, Dict)
if not url:
assert isinstance(data, dict)
for key, value in data.items():
self._save(value, key)
else:
key = url[0]
if len(url) > 1:
registered_file = self._get_registered_file(key)
registered_file = self._get_registered_file_by_key(key)
if registered_file:
node = registered_file.node or self.default_file_node
node(self.context, self.config.next_file(key)).save(data, url[1:])
Expand All @@ -74,7 +68,7 @@ def save(
self._save(data, key)

def _save(self, data: SUB_JSON, key: str) -> None:
registered_file = self._get_registered_file(key)
registered_file = self._get_registered_file_by_key(key)
if registered_file:
node, filename = (
registered_file.node or self.default_file_node,
Expand All @@ -88,12 +82,12 @@ def _save(self, data: SUB_JSON, key: str) -> None:
BucketNode(self.context, self.config.next_file(key)).save(data)

def build(self) -> TREE:
if not self.config.path.exists():
return dict()
if not self.config.path.is_dir():
return {}

children: TREE = {}
for item in sorted(self.config.path.iterdir()):
registered_file = self._get_registered_file_from_filename(item.name)
registered_file = self._get_registered_file_by_filename(item.name)
if registered_file:
node = registered_file.node or self.default_file_node
children[registered_file.key] = node(self.context, self.config.next_file(item.name))
Expand All @@ -107,7 +101,7 @@ def build(self) -> TREE:
def check_errors(
self,
data: JSON,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
raising: bool = False,
) -> List[str]:
) -> t.List[str]:
return []
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@

class Expansion(BucketNode):
registered_files = [
RegisteredFile(
key="candidates",
node=ExpansionCandidates,
filename="candidates.ini",
),
RegisteredFile(key="candidates", node=ExpansionCandidates, filename="candidates.ini"),
RegisteredFile(key="settings", node=ExpansionSettings, filename="settings.ini"),
RegisteredFile(key="capa", node=ExpansionMatrixResources),
RegisteredFile(key="weights", node=ExpansionMatrixResources),
Expand Down

0 comments on commit 6de7e70

Please sign in to comment.