From 9155566ec6b2d52c97c4d9fc5bbefbada20a91bf Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 14 Sep 2023 18:44:20 +0200 Subject: [PATCH] fix(matrix): create empty file if the matrix is empty --- antarest/matrixstore/repository.py | 29 +++++---- antarest/matrixstore/service.py | 19 ++++-- antarest/study/service.py | 1 + antarest/tools/lib.py | 1 + tests/matrixstore/test_repository.py | 93 +++++++++++++++++----------- 5 files changed, 90 insertions(+), 53 deletions(-) diff --git a/antarest/matrixstore/repository.py b/antarest/matrixstore/repository.py index f96e80de67..6301e39c7f 100644 --- a/antarest/matrixstore/repository.py +++ b/antarest/matrixstore/repository.py @@ -1,7 +1,7 @@ import hashlib import logging +import typing as t from pathlib import Path -from typing import List, Optional, Union import numpy as np from filelock import FileLock @@ -31,19 +31,19 @@ def save(self, matrix_user_metadata: MatrixDataSet) -> MatrixDataSet: logger.debug(f"Matrix dataset {matrix_user_metadata.id} for user {matrix_user_metadata.owner_id} saved") return matrix_user_metadata - def get(self, id: str) -> Optional[MatrixDataSet]: + def get(self, id: str) -> t.Optional[MatrixDataSet]: matrix: MatrixDataSet = db.session.query(MatrixDataSet).get(id) return matrix - def get_all_datasets(self) -> List[MatrixDataSet]: - matrix_datasets: List[MatrixDataSet] = db.session.query(MatrixDataSet).all() + def get_all_datasets(self) -> t.List[MatrixDataSet]: + matrix_datasets: t.List[MatrixDataSet] = db.session.query(MatrixDataSet).all() return matrix_datasets def query( self, - name: Optional[str], - owner: Optional[int] = None, - ) -> List[MatrixDataSet]: + name: t.Optional[str], + owner: t.Optional[int] = None, + ) -> t.List[MatrixDataSet]: """ Query a list of MatrixUserMetadata by searching for each one separately if a set of filter match @@ -59,7 +59,7 @@ def query( query = query.filter(MatrixDataSet.name.ilike(f"%{name}%")) # type: ignore if owner is not None: query = query.filter(MatrixDataSet.owner_id == owner) - datasets: List[MatrixDataSet] = query.distinct().all() + datasets: t.List[MatrixDataSet] = query.distinct().all() return datasets def delete(self, dataset_id: str) -> None: @@ -83,7 +83,7 @@ def save(self, matrix: Matrix) -> Matrix: logger.debug(f"Matrix {matrix.id} saved") return matrix - def get(self, matrix_hash: str) -> Optional[Matrix]: + def get(self, matrix_hash: str) -> t.Optional[Matrix]: matrix: Matrix = db.session.query(Matrix).get(matrix_hash) return matrix @@ -130,6 +130,7 @@ def get(self, matrix_hash: str) -> MatrixContent: matrix_file = self.bucket_dir.joinpath(f"{matrix_hash}.tsv") matrix = np.loadtxt(matrix_file, delimiter="\t", dtype=np.float64, ndmin=2) + matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix data = matrix.tolist() index = list(range(matrix.shape[0])) columns = list(range(matrix.shape[1])) @@ -148,7 +149,7 @@ def exists(self, matrix_hash: str) -> bool: matrix_file = self.bucket_dir.joinpath(f"{matrix_hash}.tsv") return matrix_file.exists() - def save(self, content: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: + def save(self, content: t.Union[t.List[t.List[MatrixData]], npt.NDArray[np.float64]]) -> str: """ Saves the content of a matrix as a TSV file in the bucket directory and returns its SHA256 hash. @@ -188,8 +189,12 @@ def save(self, content: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) # Ensure exclusive access to the matrix file between multiple processes (or threads). lock_file = matrix_file.with_suffix(".tsv.lock") with FileLock(lock_file, timeout=15): - # noinspection PyTypeChecker - np.savetxt(matrix_file, matrix, delimiter="\t", fmt="%.18f") + if matrix.size == 0: + # If the array or dataframe is empty, create an empty file instead of + # traditional saving to avoid unwanted line breaks. + open(matrix_file, mode="wb").close() + else: + np.savetxt(matrix_file, matrix, delimiter="\t", fmt="%.18f") # IMPORTANT: Deleting the lock file under Linux can make locking unreliable. # See https://github.com/tox-dev/py-filelock/issues/31 diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index c20b0197dc..d10b92ff50 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -187,6 +187,7 @@ def _file_importation(self, file: bytes, is_json: bool = False) -> str: return self.create(MatrixContent.parse_raw(file).data) # noinspection PyTypeChecker matrix = np.loadtxt(BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2) + matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix return self.create(matrix) def get_dataset( @@ -380,8 +381,13 @@ def create_matrix_files(self, matrix_ids: Sequence[str], export_path: Path) -> s name = f"matrix-{mtx.id}.txt" filepath = f"{tmpdir}/{name}" array = np.array(mtx.data, dtype=np.float64) - # noinspection PyTypeChecker - np.savetxt(filepath, array, delimiter="\t", fmt="%.18f") + if array.size == 0: + # If the array or dataframe is empty, create an empty file instead of + # traditional saving to avoid unwanted line breaks. + open(filepath, mode="wb").close() + else: + # noinspection PyTypeChecker + np.savetxt(filepath, array, delimiter="\t", fmt="%.18f") zip_dir(Path(tmpdir), export_path) stopwatch.log_elapsed(lambda x: logger.info(f"Matrix dataset exported (zipped mode) in {x}s")) return str(export_path) @@ -467,5 +473,10 @@ def download_matrix( raise UserHasNotPermissionError() if matrix := self.get(matrix_id): array = np.array(matrix.data, dtype=np.float64) - # noinspection PyTypeChecker - np.savetxt(filepath, array, delimiter="\t", fmt="%.18f") + if array.size == 0: + # If the array or dataframe is empty, create an empty file instead of + # traditional saving to avoid unwanted line breaks. + open(filepath, mode="wb").close() + else: + # noinspection PyTypeChecker + np.savetxt(filepath, array, delimiter="\t", fmt="%.18f") diff --git a/antarest/study/service.py b/antarest/study/service.py index 332fc384da..b9ec491c18 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -1378,6 +1378,7 @@ def _create_edit_study_command( if isinstance(data, bytes): # noinspection PyTypeChecker matrix = np.loadtxt(io.BytesIO(data), delimiter="\t", dtype=np.float64, ndmin=2) + matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix return ReplaceMatrix( target=url, matrix=matrix.tolist(), diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index b27e4dcee4..2d2953e3f5 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -77,6 +77,7 @@ def apply_commands( matrix_dataset: List[str] = [] for matrix_file in matrices_dir.iterdir(): matrix = np.loadtxt(matrix_file, delimiter="\t", dtype=np.float64, ndmin=2) + matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix matrix_data = matrix.tolist() res = self.session.post(self.build_url("/v1/matrix"), json=matrix_data) res.raise_for_status() diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 9d1254953a..3973a18d39 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -1,9 +1,10 @@ +import typing as t from datetime import datetime from pathlib import Path -from typing import Optional import numpy as np import pytest +from numpy import typing as npt from antarest.core.config import Config, SecurityConfig from antarest.core.utils.fastapi_sqlalchemy import db @@ -12,16 +13,15 @@ from antarest.matrixstore.model import Matrix, MatrixContent, MatrixDataSet, MatrixDataSetRelation from antarest.matrixstore.repository import MatrixContentRepository, MatrixDataSetRepository, MatrixRepository +ArrayData = t.Union[t.List[t.List[float]], npt.NDArray[np.float64]] + class TestMatrixRepository: - def test_db_lifecycle(self): + def test_db_lifecycle(self) -> None: with db(): # sourcery skip: extract-method repo = MatrixRepository() - m = Matrix( - id="hello", - created_at=datetime.now(), - ) + m = Matrix(id="hello", created_at=datetime.now()) repo.save(m) assert m.id assert m == repo.get(m.id) @@ -29,11 +29,11 @@ def test_db_lifecycle(self): repo.delete(m.id) assert repo.get(m.id) is None - def test_bucket_lifecycle(self, tmp_path: Path): + def test_bucket_lifecycle(self, tmp_path: Path) -> None: repo = MatrixContentRepository(tmp_path) - a = [[1, 2], [3, 4]] - b = [[5, 6], [7, 8]] + a: ArrayData = [[1, 2], [3, 4]] + b: ArrayData = [[5, 6], [7, 8]] matrix_content_a = MatrixContent(data=a, index=[0, 1], columns=[0, 1]) matrix_content_b = MatrixContent(data=b, index=[0, 1], columns=[0, 1]) @@ -51,7 +51,7 @@ def test_bucket_lifecycle(self, tmp_path: Path): with pytest.raises(FileNotFoundError): repo.get(aid) - def test_dataset(self): + def test_dataset(self) -> None: with db(): # sourcery skip: extract-duplicate-method, extract-method repo = MatrixRepository() @@ -66,15 +66,9 @@ def test_dataset(self): dataset_repo = MatrixDataSetRepository() - m1 = Matrix( - id="hello", - created_at=datetime.now(), - ) + m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) - m2 = Matrix( - id="world", - created_at=datetime.now(), - ) + m2 = Matrix(id="world", created_at=datetime.now()) repo.save(m2) dataset = MatrixDataSet( @@ -94,7 +88,7 @@ def test_dataset(self): dataset.matrices.append(matrix_relation) dataset = dataset_repo.save(dataset) - dataset_query_result: Optional[MatrixDataSet] = dataset_repo.get(dataset.id) + dataset_query_result = dataset_repo.get(dataset.id) assert dataset_query_result is not None assert dataset_query_result.name == "some name" assert len(dataset_query_result.matrices) == 2 @@ -106,12 +100,12 @@ def test_dataset(self): updated_at=datetime.now(), ) dataset_repo.save(dataset_update) - dataset_query_result: Optional[MatrixDataSet] = dataset_repo.get(dataset.id) + dataset_query_result = dataset_repo.get(dataset.id) assert dataset_query_result is not None assert dataset_query_result.name == "some name change" assert dataset_query_result.owner_id == user.id - def test_datastore_query(self): + def test_datastore_query(self) -> None: # sourcery skip: extract-duplicate-method with db(): user_repo = UserRepository(Config(security=SecurityConfig())) @@ -121,15 +115,9 @@ def test_datastore_query(self): user2 = user_repo.save(User(name="hello", password=Password("world"))) repo = MatrixRepository() - m1 = Matrix( - id="hello", - created_at=datetime.now(), - ) + m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) - m2 = Matrix( - id="world", - created_at=datetime.now(), - ) + m2 = Matrix(id="world", created_at=datetime.now()) repo.save(m2) dataset_repo = MatrixDataSetRepository() @@ -176,14 +164,19 @@ def test_datastore_query(self): assert repo.get(m1.id) is not None assert ( len( - db.session.query(MatrixDataSetRelation).filter(MatrixDataSetRelation.dataset_id == dataset.id).all() + # fmt: off + db.session + .query(MatrixDataSetRelation) + .filter(MatrixDataSetRelation.dataset_id == dataset.id) + .all() + # fmt: on ) == 0 ) class TestMatrixContentRepository: - def test_save(self, matrix_content_repo: MatrixContentRepository): + def test_save(self, matrix_content_repo: MatrixContentRepository) -> None: """ Saves the content of a matrix as a TSV file in the directory and returns its SHA256 hash. @@ -192,6 +185,7 @@ def test_save(self, matrix_content_repo: MatrixContentRepository): bucket_dir = matrix_content_repo.bucket_dir # when the data is saved in the repo + data: ArrayData data = [[1, 2, 3], [4, 5, 6]] matrix_hash = matrix_content_repo.save(data) # then a TSV file is created in the repo directory @@ -224,12 +218,37 @@ def test_save(self, matrix_content_repo: MatrixContentRepository): other_matrix_file = bucket_dir.joinpath(f"{other_matrix_hash}.tsv") assert set(matrix_files) == {matrix_file, other_matrix_file} - def test_get(self, matrix_content_repo): + def test_save_and_retrieve_empty_matrix(self, matrix_content_repo: MatrixContentRepository) -> None: + """ + Test saving and retrieving empty matrices as TSV files. + Il all cases the file must be empty. + """ + bucket_dir = matrix_content_repo.bucket_dir + + # Test with an empty matrix + empty_array: ArrayData = [] + matrix_hash = matrix_content_repo.save(empty_array) + matrix_file = bucket_dir.joinpath(f"{matrix_hash}.tsv") + retrieved_matrix = matrix_content_repo.get(matrix_hash) + + assert not matrix_file.read_bytes() + assert retrieved_matrix.data == [[]] + + # Test with an empty 2D array + empty_2d_array: ArrayData = [[]] + matrix_hash = matrix_content_repo.save(empty_2d_array) + matrix_file = bucket_dir.joinpath(f"{matrix_hash}.tsv") + retrieved_matrix = matrix_content_repo.get(matrix_hash) + + assert not matrix_file.read_bytes() + assert retrieved_matrix.data == [[]] + + def test_get(self, matrix_content_repo: MatrixContentRepository) -> None: """ Retrieves the content of a matrix with a given SHA256 hash. """ # when the data is saved in the repo - data = [[1, 2, 3], [4, 5, 6]] + data: ArrayData = [[1, 2, 3], [4, 5, 6]] matrix_hash = matrix_content_repo.save(data) # then the saved matrix object can be retrieved content = matrix_content_repo.get(matrix_hash) @@ -243,12 +262,12 @@ def test_get(self, matrix_content_repo): missing_hash = "8b1a9953c4611296a827abf8c47804d7e6c49c6b" matrix_content_repo.get(missing_hash) - def test_exists(self, matrix_content_repo): + def test_exists(self, matrix_content_repo: MatrixContentRepository) -> None: """ Checks if a matrix with a given SHA256 hash exists in the directory. """ # when the data is saved in the repo - data = [[1, 2, 3], [4, 5, 6]] + data: ArrayData = [[1, 2, 3], [4, 5, 6]] matrix_hash = matrix_content_repo.save(data) # then the saved matrix object exists assert matrix_content_repo.exists(matrix_hash) @@ -258,12 +277,12 @@ def test_exists(self, matrix_content_repo): missing_hash = "8b1a9953c4611296a827abf8c47804d7e6c49c6b" assert not matrix_content_repo.exists(missing_hash) - def test_delete(self, matrix_content_repo): + def test_delete(self, matrix_content_repo: MatrixContentRepository) -> None: """ Deletes the tsv file containing the content of a matrix with a given SHA256 hash. """ # when the data is saved in the repo - data = [[1, 2, 3], [4, 5, 6]] + data: ArrayData = [[1, 2, 3], [4, 5, 6]] matrix_hash = matrix_content_repo.save(data) # then the saved matrix object can be deleted matrix_content_repo.delete(matrix_hash)