Skip to content

Commit

Permalink
refactor(matrix-service): improve implementation of `create_by_import…
Browse files Browse the repository at this point in the history
…ation`
  • Loading branch information
laurent-laporte-pro committed Sep 28, 2023
1 parent c156c3b commit b054b8b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 27 deletions.
63 changes: 44 additions & 19 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import contextlib
import io
import json
import logging
import tempfile
import zipfile
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
from zipfile import ZipFile

import numpy as np
from fastapi import UploadFile
Expand Down Expand Up @@ -37,6 +37,18 @@
)
from antarest.matrixstore.repository import MatrixContentRepository, MatrixDataSetRepository, MatrixRepository

# List of files to exclude from ZIP archives
EXCLUDED_FILES = {
"__MACOSX",
".DS_Store",
"._.DS_Store",
"Thumbs.db",
"desktop.ini",
"$RECYCLE.BIN",
"System Volume Information",
"RECYCLER",
}

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -151,29 +163,42 @@ def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -
self.repo.save(matrix)
return matrix_id

def create_by_importation(self, file: UploadFile, json: bool = False) -> List[MatrixInfoDTO]:
def create_by_importation(self, file: UploadFile, is_json: bool = False) -> List[MatrixInfoDTO]:
"""
Imports a matrix from a TSV or JSON file or a collection of matrices from a ZIP file.
TSV-formatted files are expected to contain only matrix data without any header.
JSON-formatted files are expected to contain the following attributes:
- `index`: The list of row labels.
- `columns`: The list of column labels.
- `data`: The matrix data as a nested list of floats.
Args:
file: The file to import (TSV, JSON or ZIP).
is_json: Flag indicating if the file is JSON-encoded.
Returns:
A list of `MatrixInfoDTO` objects containing the SHA256 hash of the imported matrices.
"""
with file.file as f:
if file.content_type == "application/zip":
input_zip = ZipFile(BytesIO(f.read()))
files = {
info.filename: input_zip.read(info.filename) for info in input_zip.infolist() if not info.is_dir()
}
with contextlib.closing(f):
buffer = io.BytesIO(f.read())
matrix_info: List[MatrixInfoDTO] = []
for name in files:
if all(
[
not name.startswith("__MACOSX/"),
not name.startswith(".DS_Store"),
]
):
matrix_id = self._file_importation(files[name], json)
matrix_info.append(MatrixInfoDTO(id=matrix_id, name=name))
with zipfile.ZipFile(buffer) as zf:
for info in zf.infolist():
if info.is_dir() or info.filename in EXCLUDED_FILES:
continue
matrix_id = self._file_importation(zf.read(info.filename), is_json=is_json)
matrix_info.append(MatrixInfoDTO(id=matrix_id, name=info.filename))
return matrix_info
else:
matrix_id = self._file_importation(f.read(), json)
matrix_id = self._file_importation(f.read(), is_json=is_json)
return [MatrixInfoDTO(id=matrix_id, name=file.filename)]

def _file_importation(self, file: bytes, is_json: bool = False) -> str:
def _file_importation(self, file: bytes, *, is_json: bool = False) -> str:
"""
Imports a matrix from a TSV or JSON file in bytes format.
Expand All @@ -189,7 +214,7 @@ def _file_importation(self, file: bytes, is_json: bool = False) -> str:
content = MatrixContent(**obj)
return self.create(content.data)
# noinspection PyTypeChecker
matrix = np.loadtxt(BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2)
matrix = np.loadtxt(io.BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
return self.create(matrix)

Expand Down
2 changes: 1 addition & 1 deletion antarest/matrixstore/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def create_by_importation(
) -> Any:
logger.info("Importing new matrix dataset", extra={"user": current_user.id})
if current_user.id is not None:
return service.create_by_importation(file, json)
return service.create_by_importation(file, is_json=json)
raise UserHasNotPermissionError()

@bp.get("/matrix/{id}", tags=[APITag.matrix], response_model=MatrixDTO)
Expand Down
14 changes: 7 additions & 7 deletions tests/matrixstore/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import json
import time
import typing as t
from unittest.mock import ANY, Mock
import zipfile
from unittest.mock import ANY, Mock

import numpy as np
import pytest
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_create_by_importation__nominal_case(
) -> None:
"""
Create a new matrix by importing a file.
The file is either a JSON file or a CSV file.
The file is either a JSON file or a TSV file.
"""
# Prepare the matrix data to import
matrix = np.array(data, dtype=np.float64)
Expand All @@ -199,7 +199,7 @@ def test_create_by_importation__nominal_case(
filename = "matrix.json"
json_format = True
else:
# CSV format of the array (without header)
# TSV format of the array (without header)
buffer = io.BytesIO()
np.savetxt(buffer, matrix, delimiter="\t")
buffer.seek(0)
Expand All @@ -210,7 +210,7 @@ def test_create_by_importation__nominal_case(
upload_file = _create_upload_file(filename=filename, file=buffer, content_type=content_type)

# when a matrix is created (inserted) in the service
info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format)
info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, is_json=json_format)

# Then, check the list of created matrices
assert len(info_list) == 1
Expand All @@ -237,7 +237,7 @@ def test_create_by_importation__nominal_case(
@pytest.mark.parametrize("content_type", ["application/json", "text/plain"])
def test_create_by_importation__zip_file(self, matrix_service: MatrixService, content_type: str) -> None:
"""
Create a ZIP file with several matrices, using either a JSON format or a CSV format.
Create a ZIP file with several matrices, using either a JSON format or a TSV format.
All matrices of the ZIP file use the same format.
Check that the matrices are correctly imported.
"""
Expand All @@ -259,7 +259,7 @@ def test_create_by_importation__zip_file(self, matrix_service: MatrixService, co
]
json_format = True
else:
# CSV format of the array (without header)
# TSV format of the array (without header)
content_list = []
for matrix in matrix_list:
buffer = io.BytesIO()
Expand All @@ -278,7 +278,7 @@ def test_create_by_importation__zip_file(self, matrix_service: MatrixService, co
upload_file = _create_upload_file(filename="matrices.zip", file=buffer, content_type="application/zip")

# When matrices are created (inserted) in the service
info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format)
info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, is_json=json_format)

# Then, check the list of created matrices
assert len(info_list) == len(data_list)
Expand Down

0 comments on commit b054b8b

Please sign in to comment.