diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 4f44a0a471..639084b587 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -1,13 +1,13 @@ import contextlib +import io import json import logging import tempfile +import zipfile from abc import ABC, abstractmethod from datetime import datetime, timezone -from io import BytesIO from pathlib import Path from typing import List, Optional, Sequence, Tuple, Union -from zipfile import ZipFile import numpy as np from fastapi import UploadFile @@ -37,6 +37,18 @@ ) from antarest.matrixstore.repository import MatrixContentRepository, MatrixDataSetRepository, MatrixRepository +# List of files to exclude from ZIP archives +EXCLUDED_FILES = { + "__MACOSX", + ".DS_Store", + "._.DS_Store", + "Thumbs.db", + "desktop.ini", + "$RECYCLE.BIN", + "System Volume Information", + "RECYCLER", +} + logger = logging.getLogger(__name__) @@ -151,29 +163,42 @@ def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) - self.repo.save(matrix) return matrix_id - def create_by_importation(self, file: UploadFile, json: bool = False) -> List[MatrixInfoDTO]: + def create_by_importation(self, file: UploadFile, is_json: bool = False) -> List[MatrixInfoDTO]: + """ + Imports a matrix from a TSV or JSON file or a collection of matrices from a ZIP file. + + TSV-formatted files are expected to contain only matrix data without any header. + + JSON-formatted files are expected to contain the following attributes: + + - `index`: The list of row labels. + - `columns`: The list of column labels. + - `data`: The matrix data as a nested list of floats. + + Args: + file: The file to import (TSV, JSON or ZIP). + is_json: Flag indicating if the file is JSON-encoded. + + Returns: + A list of `MatrixInfoDTO` objects containing the SHA256 hash of the imported matrices. + """ with file.file as f: if file.content_type == "application/zip": - input_zip = ZipFile(BytesIO(f.read())) - files = { - info.filename: input_zip.read(info.filename) for info in input_zip.infolist() if not info.is_dir() - } + with contextlib.closing(f): + buffer = io.BytesIO(f.read()) matrix_info: List[MatrixInfoDTO] = [] - for name in files: - if all( - [ - not name.startswith("__MACOSX/"), - not name.startswith(".DS_Store"), - ] - ): - matrix_id = self._file_importation(files[name], json) - matrix_info.append(MatrixInfoDTO(id=matrix_id, name=name)) + with zipfile.ZipFile(buffer) as zf: + for info in zf.infolist(): + if info.is_dir() or info.filename in EXCLUDED_FILES: + continue + matrix_id = self._file_importation(zf.read(info.filename), is_json=is_json) + matrix_info.append(MatrixInfoDTO(id=matrix_id, name=info.filename)) return matrix_info else: - matrix_id = self._file_importation(f.read(), json) + matrix_id = self._file_importation(f.read(), is_json=is_json) return [MatrixInfoDTO(id=matrix_id, name=file.filename)] - def _file_importation(self, file: bytes, is_json: bool = False) -> str: + def _file_importation(self, file: bytes, *, is_json: bool = False) -> str: """ Imports a matrix from a TSV or JSON file in bytes format. @@ -189,7 +214,7 @@ def _file_importation(self, file: bytes, is_json: bool = False) -> str: content = MatrixContent(**obj) return self.create(content.data) # noinspection PyTypeChecker - matrix = np.loadtxt(BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2) + matrix = np.loadtxt(io.BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2) matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix return self.create(matrix) diff --git a/antarest/matrixstore/web.py b/antarest/matrixstore/web.py index a97ae7d45b..4b47135b52 100644 --- a/antarest/matrixstore/web.py +++ b/antarest/matrixstore/web.py @@ -55,7 +55,7 @@ def create_by_importation( ) -> Any: logger.info("Importing new matrix dataset", extra={"user": current_user.id}) if current_user.id is not None: - return service.create_by_importation(file, json) + return service.create_by_importation(file, is_json=json) raise UserHasNotPermissionError() @bp.get("/matrix/{id}", tags=[APITag.matrix], response_model=MatrixDTO) diff --git a/tests/matrixstore/test_service.py b/tests/matrixstore/test_service.py index bcef2c96b6..db26e6403a 100644 --- a/tests/matrixstore/test_service.py +++ b/tests/matrixstore/test_service.py @@ -3,8 +3,8 @@ import json import time import typing as t -from unittest.mock import ANY, Mock import zipfile +from unittest.mock import ANY, Mock import numpy as np import pytest @@ -186,7 +186,7 @@ def test_create_by_importation__nominal_case( ) -> None: """ Create a new matrix by importing a file. - The file is either a JSON file or a CSV file. + The file is either a JSON file or a TSV file. """ # Prepare the matrix data to import matrix = np.array(data, dtype=np.float64) @@ -199,7 +199,7 @@ def test_create_by_importation__nominal_case( filename = "matrix.json" json_format = True else: - # CSV format of the array (without header) + # TSV format of the array (without header) buffer = io.BytesIO() np.savetxt(buffer, matrix, delimiter="\t") buffer.seek(0) @@ -210,7 +210,7 @@ def test_create_by_importation__nominal_case( upload_file = _create_upload_file(filename=filename, file=buffer, content_type=content_type) # when a matrix is created (inserted) in the service - info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format) + info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, is_json=json_format) # Then, check the list of created matrices assert len(info_list) == 1 @@ -237,7 +237,7 @@ def test_create_by_importation__nominal_case( @pytest.mark.parametrize("content_type", ["application/json", "text/plain"]) def test_create_by_importation__zip_file(self, matrix_service: MatrixService, content_type: str) -> None: """ - Create a ZIP file with several matrices, using either a JSON format or a CSV format. + Create a ZIP file with several matrices, using either a JSON format or a TSV format. All matrices of the ZIP file use the same format. Check that the matrices are correctly imported. """ @@ -259,7 +259,7 @@ def test_create_by_importation__zip_file(self, matrix_service: MatrixService, co ] json_format = True else: - # CSV format of the array (without header) + # TSV format of the array (without header) content_list = [] for matrix in matrix_list: buffer = io.BytesIO() @@ -278,7 +278,7 @@ def test_create_by_importation__zip_file(self, matrix_service: MatrixService, co upload_file = _create_upload_file(filename="matrices.zip", file=buffer, content_type="application/zip") # When matrices are created (inserted) in the service - info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format) + info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, is_json=json_format) # Then, check the list of created matrices assert len(info_list) == len(data_list)