Skip to content

Commit

Permalink
simplify code and add UTs
Browse files — browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jul 12, 2024
1 parent 0581237 commit 86216b4
Show file tree
Hide file tree
Showing 5 changed files with 2,592 additions and 50 deletions.
4 changes: 4 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class ShouldNotHappenException(Exception):
pass


class MustNotModifyOutputException(Exception):
    """Raised when code attempts to write to a simulation output file.

    Output files are read-only: the ``dump`` methods of output filesystem
    nodes raise this exception (with the offending file name as argument)
    instead of performing the write.
    """


# ============================================================
# Exceptions related to the study configuration (`.ini` files)
# ============================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from pandas import DataFrame

from antarest.core.exceptions import MustNotModifyOutputException
from antarest.core.model import JSON
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
from antarest.study.storage.rawstudy.model.filesystem.context import ContextServer
Expand Down Expand Up @@ -94,29 +95,6 @@ def parse(
matrix = self.parse_dataframe(file_path, tmp_dir)
return cast(JSON, matrix.to_dict(orient="split"))

def _dump_json(self, data: JSON) -> None:
    """Serialize *data* to this node's matrix file on disk.

    ``data`` is expected to be the "split" dict form produced by
    ``DataFrame.to_dict(orient="split")`` (the shape returned by ``parse``
    above), since it is unpacked directly into the ``pd.DataFrame``
    constructor.

    The written file layout is: a textual header, then rows of
    tab-separated values where the first column(s) hold the serialized
    dates and the first data row holds the column names.
    """
    # Rebuild the DataFrame from its "split" representation.
    df = pd.DataFrame(**data)

    # Turn the column names into a single header row (transpose of a
    # one-column frame), then stack it on top of the raw values so the
    # names are written as the first data row.
    headers = pd.DataFrame(df.columns.values.tolist()).T
    matrix = pd.concat([headers, pd.DataFrame(df.values)], axis=0)

    # Build the date column(s) from the index.
    # NOTE(review): `date_serializer.build_date` presumably returns a
    # frame with one extra leading row matching the header row above —
    # confirm against its implementation.
    time = self.date_serializer.build_date(df.index)
    # Align indexes so the horizontal concat below pairs rows correctly.
    matrix.index = time.index

    # Prepend the date column(s) to the header+values block.
    matrix = pd.concat([time, matrix], axis=1)

    # File header describes the matrix dimensions (variables x rows).
    head = self.head_writer.build(var=df.columns.size, end=df.index.size)
    with self.config.path.open(mode="w", newline="\n") as fd:
        fd.write(head)
        # An empty DataFrame would still emit a spurious line via to_csv,
        # so only write the body when there is something to write.
        if not matrix.empty:
            matrix.to_csv(
                fd,
                sep="\t",
                header=False,
                index=False,
                float_format="%.6f",
            )

def check_errors(
self,
data: JSON,
Expand Down Expand Up @@ -160,11 +138,7 @@ def load(
) from e

def dump(self, data: Union[bytes, JSON], url: Optional[List[str]] = None) -> None:
if isinstance(data, bytes):
self.config.path.parent.mkdir(exist_ok=True, parents=True)
self.config.path.write_bytes(data)
else:
self._dump_json(data)
raise MustNotModifyOutputException(self.config.path.name)

def normalize(self) -> None:
pass # no external store in this node
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import csv
import typing as t

import pandas as pd

from antarest.core.exceptions import MustNotModifyOutputException
from antarest.core.model import JSON
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
from antarest.study.storage.rawstudy.model.filesystem.context import ContextServer
Expand Down Expand Up @@ -31,7 +31,7 @@ def get_lazy_content(
depth: int = -1,
expanded: bool = False,
) -> str:
return f"matrix://{self.config.path.name}"
return f"matrix://{self.config.path.name}" # prefix used by the front to parse the back-end response

def load(
self,
Expand All @@ -47,8 +47,7 @@ def load(
return t.cast(JSON, output)

def dump(self, data: bytes, url: t.Optional[t.List[str]] = None) -> None:
self.config.path.parent.mkdir(exist_ok=True, parents=True)
self.config.path.write_bytes(data)
raise MustNotModifyOutputException(self.config.path.name)

def check_errors(self, data: str, url: t.Optional[t.List[str]] = None, raising: bool = False) -> t.List[str]:
if not self.config.path.exists():
Expand All @@ -59,10 +58,10 @@ def check_errors(self, data: str, url: t.Optional[t.List[str]] = None, raising:
return []

def normalize(self) -> None:
pass # no external store in this node
pass # shouldn't be normalized as it's an output file

def denormalize(self) -> None:
pass # no external store in this node
pass # shouldn't be denormalized as it's an output file


class DigestSynthesis(OutputSynthesis):
Expand All @@ -78,23 +77,11 @@ def load(
) -> JSON:
file_path = self.config.path
with open(file_path, "r") as f:
csv_file = csv.reader(f, delimiter="\t")
longest_row = 0
for row in csv_file:
row_length = len(row)
if row_length > longest_row:
longest_row = row_length
lines = f.read().splitlines()
splitted_rows = [row.split("\t") for row in lines]
longest_row = max(len(row) for row in splitted_rows)
new_rows = [row + [""] * (longest_row - len(row)) for row in splitted_rows]

with open(file_path, "r") as f:
csv_file = csv.reader(f, delimiter="\t")
new_rows = []
for row in csv_file:
new_row = row
n = len(row)
if n < longest_row:
difference = longest_row - n
new_row += [""] * difference
new_rows.append(row)
df = pd.DataFrame(data=new_rows)
output = df.to_dict(orient="split")
del output["index"]
Expand Down
Loading

0 comments on commit 86216b4

Please sign in to comment.