fix(matrix): create empty file if the matrix is empty
laurent-laporte-pro committed Sep 28, 2023
1 parent 36410b5 commit 9155566
Showing 5 changed files with 90 additions and 53 deletions.
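
The change makes empty matrices round-trip cleanly through the TSV store: saving an empty array now produces a zero-byte file, and loading a zero-byte file now yields an array that serializes as [[]]. A minimal sketch of the intended behaviour (a standalone illustration, not the repository code itself; the file name is a placeholder):

import numpy as np

# Saving: an empty array is written as a zero-byte file instead of going
# through np.savetxt, which would emit a stray newline for a 1 x 0 array.
empty = np.array([[]], dtype=np.float64)   # shape (1, 0), size 0
if empty.size == 0:
    open("matrix.tsv", mode="wb").close()  # empty file, no trailing "\n"
else:
    np.savetxt("matrix.tsv", empty, delimiter="\t", fmt="%.18f")

# Loading: whatever empty shape np.loadtxt returns for a zero-byte file is
# normalised to (1, 0) so that .tolist() gives [[]].
loaded = np.loadtxt("matrix.tsv", delimiter="\t", dtype=np.float64, ndmin=2)
loaded = loaded.reshape((1, 0)) if loaded.size == 0 else loaded
assert loaded.tolist() == [[]]
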
29 changes: 17 additions & 12 deletions antarest/matrixstore/repository.py
@@ -1,7 +1,7 @@
import hashlib
import logging
import typing as t
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
from filelock import FileLock
@@ -31,19 +31,19 @@ def save(self, matrix_user_metadata: MatrixDataSet) -> MatrixDataSet:
logger.debug(f"Matrix dataset {matrix_user_metadata.id} for user {matrix_user_metadata.owner_id} saved")
return matrix_user_metadata

def get(self, id: str) -> Optional[MatrixDataSet]:
def get(self, id: str) -> t.Optional[MatrixDataSet]:
matrix: MatrixDataSet = db.session.query(MatrixDataSet).get(id)
return matrix

def get_all_datasets(self) -> List[MatrixDataSet]:
matrix_datasets: List[MatrixDataSet] = db.session.query(MatrixDataSet).all()
def get_all_datasets(self) -> t.List[MatrixDataSet]:
matrix_datasets: t.List[MatrixDataSet] = db.session.query(MatrixDataSet).all()
return matrix_datasets

def query(
self,
name: Optional[str],
owner: Optional[int] = None,
) -> List[MatrixDataSet]:
name: t.Optional[str],
owner: t.Optional[int] = None,
) -> t.List[MatrixDataSet]:
"""
Query a list of MatrixUserMetadata by searching for each one separately if a set of filter match
@@ -59,7 +59,7 @@ def query(
query = query.filter(MatrixDataSet.name.ilike(f"%{name}%")) # type: ignore
if owner is not None:
query = query.filter(MatrixDataSet.owner_id == owner)
datasets: List[MatrixDataSet] = query.distinct().all()
datasets: t.List[MatrixDataSet] = query.distinct().all()
return datasets

def delete(self, dataset_id: str) -> None:
@@ -83,7 +83,7 @@ def save(self, matrix: Matrix) -> Matrix:
logger.debug(f"Matrix {matrix.id} saved")
return matrix

def get(self, matrix_hash: str) -> Optional[Matrix]:
def get(self, matrix_hash: str) -> t.Optional[Matrix]:
matrix: Matrix = db.session.query(Matrix).get(matrix_hash)
return matrix

@@ -130,6 +130,7 @@ def get(self, matrix_hash: str) -> MatrixContent:

matrix_file = self.bucket_dir.joinpath(f"{matrix_hash}.tsv")
matrix = np.loadtxt(matrix_file, delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
data = matrix.tolist()
index = list(range(matrix.shape[0]))
columns = list(range(matrix.shape[1]))
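
The reshape added here normalises whatever empty shape np.loadtxt produces for a zero-byte file (typically (0, 1) with ndmin=2, depending on the NumPy version) to (1, 0), so the returned MatrixContent becomes data=[[]], index=[0], columns=[]. A small sketch of that normalisation, using np.empty as a stand-in for the loadtxt result:

import numpy as np

raw = np.empty((0, 1), dtype=np.float64)   # stand-in for loadtxt on an empty TSV
fixed = raw.reshape((1, 0)) if raw.size == 0 else raw

assert fixed.tolist() == [[]]              # data
assert list(range(fixed.shape[0])) == [0]  # index
assert list(range(fixed.shape[1])) == []   # columns
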
@@ -148,7 +149,7 @@ def exists(self, matrix_hash: str) -> bool:
matrix_file = self.bucket_dir.joinpath(f"{matrix_hash}.tsv")
return matrix_file.exists()

def save(self, content: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str:
def save(self, content: t.Union[t.List[t.List[MatrixData]], npt.NDArray[np.float64]]) -> str:
"""
Saves the content of a matrix as a TSV file in the bucket directory
and returns its SHA256 hash.
@@ -188,8 +189,12 @@ def save(self, content: Union[List[List[MatrixData]], npt.NDArray[np.float64]])
# Ensure exclusive access to the matrix file between multiple processes (or threads).
lock_file = matrix_file.with_suffix(".tsv.lock")
with FileLock(lock_file, timeout=15):
# noinspection PyTypeChecker
np.savetxt(matrix_file, matrix, delimiter="\t", fmt="%.18f")
if matrix.size == 0:
# If the array or dataframe is empty, create an empty file instead of
# traditional saving to avoid unwanted line breaks.
open(matrix_file, mode="wb").close()
else:
np.savetxt(matrix_file, matrix, delimiter="\t", fmt="%.18f")

# IMPORTANT: Deleting the lock file under Linux can make locking unreliable.
# See https://github.com/tox-dev/py-filelock/issues/31
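
The explicit empty-file branch exists because np.savetxt still iterates over the rows of a 1 x 0 array and, on the NumPy versions commonly in use, writes a bare newline for it, so the stored TSV would not be truly empty. A minimal sketch of the difference (file names are placeholders):

import numpy as np

empty = np.array([[]], dtype=np.float64)               # shape (1, 0), size 0

np.savetxt("via_savetxt.tsv", empty, delimiter="\t", fmt="%.18f")
print(open("via_savetxt.tsv", "rb").read())            # b"\n", the unwanted line break

open("truly_empty.tsv", mode="wb").close()             # what the fix writes instead
print(open("truly_empty.tsv", "rb").read())            # b""
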
19 changes: 15 additions & 4 deletions antarest/matrixstore/service.py
@@ -187,6 +187,7 @@ def _file_importation(self, file: bytes, is_json: bool = False) -> str:
return self.create(MatrixContent.parse_raw(file).data)
# noinspection PyTypeChecker
matrix = np.loadtxt(BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
return self.create(matrix)

def get_dataset(
@@ -380,8 +381,13 @@ def create_matrix_files(self, matrix_ids: Sequence[str], export_path: Path) -> s
name = f"matrix-{mtx.id}.txt"
filepath = f"{tmpdir}/{name}"
array = np.array(mtx.data, dtype=np.float64)
# noinspection PyTypeChecker
np.savetxt(filepath, array, delimiter="\t", fmt="%.18f")
if array.size == 0:
# If the array or dataframe is empty, create an empty file instead of
# traditional saving to avoid unwanted line breaks.
open(filepath, mode="wb").close()
else:
# noinspection PyTypeChecker
np.savetxt(filepath, array, delimiter="\t", fmt="%.18f")
zip_dir(Path(tmpdir), export_path)
stopwatch.log_elapsed(lambda x: logger.info(f"Matrix dataset exported (zipped mode) in {x}s"))
return str(export_path)
@@ -467,5 +473,10 @@ def download_matrix(
raise UserHasNotPermissionError()
if matrix := self.get(matrix_id):
array = np.array(matrix.data, dtype=np.float64)
# noinspection PyTypeChecker
np.savetxt(filepath, array, delimiter="\t", fmt="%.18f")
if array.size == 0:
# If the array or dataframe is empty, create an empty file instead of
# traditional saving to avoid unwanted line breaks.
open(filepath, mode="wb").close()
else:
# noinspection PyTypeChecker
np.savetxt(filepath, array, delimiter="\t", fmt="%.18f")
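
The same size-check-then-savetxt branch now appears in three places (MatrixContentRepository.save, create_matrix_files, and download_matrix). If it spreads further, a small shared helper could keep the behaviour in one spot; a hypothetical sketch (the helper name and location are not part of this commit):

from pathlib import Path

import numpy as np
import numpy.typing as npt


def save_tsv(array: npt.NDArray[np.float64], path: Path) -> None:
    """Write the array as a TSV file, producing a zero-byte file when it is empty."""
    if array.size == 0:
        # A truly empty file avoids the stray newline np.savetxt would emit.
        path.write_bytes(b"")
    else:
        # noinspection PyTypeChecker
        np.savetxt(path, array, delimiter="\t", fmt="%.18f")
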
1 change: 1 addition & 0 deletions antarest/study/service.py
@@ -1378,6 +1378,7 @@ def _create_edit_study_command(
if isinstance(data, bytes):
# noinspection PyTypeChecker
matrix = np.loadtxt(io.BytesIO(data), delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
return ReplaceMatrix(
target=url,
matrix=matrix.tolist(),
1 change: 1 addition & 0 deletions antarest/tools/lib.py
@@ -77,6 +77,7 @@ def apply_commands(
matrix_dataset: List[str] = []
for matrix_file in matrices_dir.iterdir():
matrix = np.loadtxt(matrix_file, delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
matrix_data = matrix.tolist()
res = self.session.post(self.build_url("/v1/matrix"), json=matrix_data)
res.raise_for_status()
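
With the reshape in place, an empty matrix file is uploaded to the matrix endpoint as the JSON body [[]] rather than [], matching the representation the store now returns. A rough sketch of the resulting request (the URL is a placeholder for self.build_url, and session handling is omitted):

import numpy as np
import requests

matrix = np.loadtxt("matrix-empty.tsv", delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix

# For an empty file this posts the JSON body [[]].
res = requests.post("http://localhost:8080/v1/matrix", json=matrix.tolist())
res.raise_for_status()
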
93 changes: 56 additions & 37 deletions tests/matrixstore/test_repository.py
@@ -1,9 +1,10 @@
import typing as t
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import pytest
from numpy import typing as npt

from antarest.core.config import Config, SecurityConfig
from antarest.core.utils.fastapi_sqlalchemy import db
@@ -12,28 +13,27 @@
from antarest.matrixstore.model import Matrix, MatrixContent, MatrixDataSet, MatrixDataSetRelation
from antarest.matrixstore.repository import MatrixContentRepository, MatrixDataSetRepository, MatrixRepository

ArrayData = t.Union[t.List[t.List[float]], npt.NDArray[np.float64]]


class TestMatrixRepository:
def test_db_lifecycle(self):
def test_db_lifecycle(self) -> None:
with db():
# sourcery skip: extract-method
repo = MatrixRepository()
m = Matrix(
id="hello",
created_at=datetime.now(),
)
m = Matrix(id="hello", created_at=datetime.now())
repo.save(m)
assert m.id
assert m == repo.get(m.id)
assert repo.exists(m.id)
repo.delete(m.id)
assert repo.get(m.id) is None

def test_bucket_lifecycle(self, tmp_path: Path):
def test_bucket_lifecycle(self, tmp_path: Path) -> None:
repo = MatrixContentRepository(tmp_path)

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
a: ArrayData = [[1, 2], [3, 4]]
b: ArrayData = [[5, 6], [7, 8]]

matrix_content_a = MatrixContent(data=a, index=[0, 1], columns=[0, 1])
matrix_content_b = MatrixContent(data=b, index=[0, 1], columns=[0, 1])
@@ -51,7 +51,7 @@ def test_bucket_lifecycle(self, tmp_path: Path):
with pytest.raises(FileNotFoundError):
repo.get(aid)

def test_dataset(self):
def test_dataset(self) -> None:
with db():
# sourcery skip: extract-duplicate-method, extract-method
repo = MatrixRepository()
@@ -66,15 +66,9 @@ def test_dataset(self):

dataset_repo = MatrixDataSetRepository()

m1 = Matrix(
id="hello",
created_at=datetime.now(),
)
m1 = Matrix(id="hello", created_at=datetime.now())
repo.save(m1)
m2 = Matrix(
id="world",
created_at=datetime.now(),
)
m2 = Matrix(id="world", created_at=datetime.now())
repo.save(m2)

dataset = MatrixDataSet(
@@ -94,7 +88,7 @@ def test_dataset(self):
dataset.matrices.append(matrix_relation)

dataset = dataset_repo.save(dataset)
dataset_query_result: Optional[MatrixDataSet] = dataset_repo.get(dataset.id)
dataset_query_result = dataset_repo.get(dataset.id)
assert dataset_query_result is not None
assert dataset_query_result.name == "some name"
assert len(dataset_query_result.matrices) == 2
@@ -106,12 +100,12 @@ def test_dataset(self):
updated_at=datetime.now(),
)
dataset_repo.save(dataset_update)
dataset_query_result: Optional[MatrixDataSet] = dataset_repo.get(dataset.id)
dataset_query_result = dataset_repo.get(dataset.id)
assert dataset_query_result is not None
assert dataset_query_result.name == "some name change"
assert dataset_query_result.owner_id == user.id

def test_datastore_query(self):
def test_datastore_query(self) -> None:
# sourcery skip: extract-duplicate-method
with db():
user_repo = UserRepository(Config(security=SecurityConfig()))
@@ -121,15 +115,9 @@ def test_datastore_query(self):
user2 = user_repo.save(User(name="hello", password=Password("world")))

repo = MatrixRepository()
m1 = Matrix(
id="hello",
created_at=datetime.now(),
)
m1 = Matrix(id="hello", created_at=datetime.now())
repo.save(m1)
m2 = Matrix(
id="world",
created_at=datetime.now(),
)
m2 = Matrix(id="world", created_at=datetime.now())
repo.save(m2)

dataset_repo = MatrixDataSetRepository()
@@ -176,14 +164,19 @@ def test_datastore_query(self):
assert repo.get(m1.id) is not None
assert (
len(
db.session.query(MatrixDataSetRelation).filter(MatrixDataSetRelation.dataset_id == dataset.id).all()
# fmt: off
db.session
.query(MatrixDataSetRelation)
.filter(MatrixDataSetRelation.dataset_id == dataset.id)
.all()
# fmt: on
)
== 0
)


class TestMatrixContentRepository:
def test_save(self, matrix_content_repo: MatrixContentRepository):
def test_save(self, matrix_content_repo: MatrixContentRepository) -> None:
"""
Saves the content of a matrix as a TSV file in the directory
and returns its SHA256 hash.
@@ -192,6 +185,7 @@ def test_save(self, matrix_content_repo: MatrixContentRepository):
bucket_dir = matrix_content_repo.bucket_dir

# when the data is saved in the repo
data: ArrayData
data = [[1, 2, 3], [4, 5, 6]]
matrix_hash = matrix_content_repo.save(data)
# then a TSV file is created in the repo directory
@@ -224,12 +218,37 @@ def test_save(self, matrix_content_repo: MatrixContentRepository):
other_matrix_file = bucket_dir.joinpath(f"{other_matrix_hash}.tsv")
assert set(matrix_files) == {matrix_file, other_matrix_file}

def test_get(self, matrix_content_repo):
def test_save_and_retrieve_empty_matrix(self, matrix_content_repo: MatrixContentRepository) -> None:
"""
Test saving and retrieving empty matrices as TSV files.
In all cases the file must be empty.
"""
bucket_dir = matrix_content_repo.bucket_dir

# Test with an empty matrix
empty_array: ArrayData = []
matrix_hash = matrix_content_repo.save(empty_array)
matrix_file = bucket_dir.joinpath(f"{matrix_hash}.tsv")
retrieved_matrix = matrix_content_repo.get(matrix_hash)

assert not matrix_file.read_bytes()
assert retrieved_matrix.data == [[]]

# Test with an empty 2D array
empty_2d_array: ArrayData = [[]]
matrix_hash = matrix_content_repo.save(empty_2d_array)
matrix_file = bucket_dir.joinpath(f"{matrix_hash}.tsv")
retrieved_matrix = matrix_content_repo.get(matrix_hash)

assert not matrix_file.read_bytes()
assert retrieved_matrix.data == [[]]

def test_get(self, matrix_content_repo: MatrixContentRepository) -> None:
"""
Retrieves the content of a matrix with a given SHA256 hash.
"""
# when the data is saved in the repo
data = [[1, 2, 3], [4, 5, 6]]
data: ArrayData = [[1, 2, 3], [4, 5, 6]]
matrix_hash = matrix_content_repo.save(data)
# then the saved matrix object can be retrieved
content = matrix_content_repo.get(matrix_hash)
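
Note that test_save_and_retrieve_empty_matrix above expects the same result for [] and [[]]: both are size-0 arrays, both are stored as zero-byte files, and both read back as data == [[]], so the distinction between a 0 x 0 and a 1 x 0 matrix is deliberately not preserved by the round trip. A compact restatement of that equivalence:

import numpy as np

for empty_input in ([], [[]]):
    arr = np.array(empty_input, dtype=np.float64)
    # Both forms are size 0, so both take the empty-file branch on save
    # and come back from MatrixContentRepository.get() as data == [[]].
    assert arr.size == 0
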
@@ -243,12 +262,12 @@ def test_get(self, matrix_content_repo):
missing_hash = "8b1a9953c4611296a827abf8c47804d7e6c49c6b"
matrix_content_repo.get(missing_hash)

def test_exists(self, matrix_content_repo):
def test_exists(self, matrix_content_repo: MatrixContentRepository) -> None:
"""
Checks if a matrix with a given SHA256 hash exists in the directory.
"""
# when the data is saved in the repo
data = [[1, 2, 3], [4, 5, 6]]
data: ArrayData = [[1, 2, 3], [4, 5, 6]]
matrix_hash = matrix_content_repo.save(data)
# then the saved matrix object exists
assert matrix_content_repo.exists(matrix_hash)
@@ -258,12 +277,12 @@ def test_exists(self, matrix_content_repo):
missing_hash = "8b1a9953c4611296a827abf8c47804d7e6c49c6b"
assert not matrix_content_repo.exists(missing_hash)

def test_delete(self, matrix_content_repo):
def test_delete(self, matrix_content_repo: MatrixContentRepository) -> None:
"""
Deletes the tsv file containing the content of a matrix with a given SHA256 hash.
"""
# when the data is saved in the repo
data = [[1, 2, 3], [4, 5, 6]]
data: ArrayData = [[1, 2, 3], [4, 5, 6]]
matrix_hash = matrix_content_repo.save(data)
# then the saved matrix object can be deleted
matrix_content_repo.delete(matrix_hash)
