Skip to content

Commit

Permalink
refactor(matrix): improve implementation of dataframe saving
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Sep 28, 2023
1 parent 9155566 commit c156c3b
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 92 deletions.
5 changes: 4 additions & 1 deletion antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import json
import logging
import tempfile
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -184,7 +185,9 @@ def _file_importation(self, file: bytes, is_json: bool = False) -> str:
A SHA256 hash that identifies the imported matrix.
"""
if is_json:
return self.create(MatrixContent.parse_raw(file).data)
obj = json.loads(file)
content = MatrixContent(**obj)
return self.create(content.data)
# noinspection PyTypeChecker
matrix = np.loadtxt(BytesIO(file), delimiter="\t", dtype=np.float64, ndmin=2)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
Expand Down
24 changes: 11 additions & 13 deletions antarest/matrixstore/uri_resolver_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class UriResolverService:
def __init__(self, matrix_service: ISimpleMatrixService) -> None:
    """
    Initialize the URI resolver.

    Args:
        matrix_service: Matrix service stored on the instance as `matrix_service`.
    """
    self.matrix_service = matrix_service

def resolve(self, uri: str, formatted: bool = True) -> Optional[SUB_JSON]:
def resolve(self, uri: str, formatted: bool = True) -> SUB_JSON:
res = UriResolverService._extract_uri_components(uri)
if res:
protocol, uuid = res
Expand Down Expand Up @@ -52,19 +52,17 @@ def _resolve_matrix(self, id: str, formatted: bool = True) -> SUB_JSON:
index=data.index,
columns=data.columns,
)
if not df.empty:
return (
df.to_csv(
None,
sep="\t",
header=False,
index=False,
float_format="%.6f",
)
or ""
)
else:
if df.empty:
return ""
else:
csv = df.to_csv(
None,
sep="\t",
header=False,
index=False,
float_format="%.6f",
)
return csv or ""
raise ValueError(f"id matrix {id} not found")

def build_matrix_uri(self, id: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,16 @@ def _dump_json(self, data: JSON) -> None:
matrix = pd.concat([time, matrix], axis=1)

head = self.head_writer.build(var=df.columns.size, end=df.index.size)
self.config.path.write_text(head)

matrix.to_csv(
open(self.config.path, "a", newline="\n"),
sep="\t",
index=False,
header=False,
line_terminator="\n",
)
with self.config.path.open(mode="w", newline="\n") as fd:
fd.write(head)
if not matrix.empty:
matrix.to_csv(
fd,
sep="\t",
header=False,
index=False,
float_format="%.6f",
)

def check_errors(
self,
Expand Down
215 changes: 149 additions & 66 deletions tests/matrixstore/test_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime
import io
import json
import time
import typing as t
from unittest.mock import ANY, Mock
from zipfile import ZIP_DEFLATED, ZipFile
import zipfile

import numpy as np
import pytest
Expand All @@ -26,12 +27,14 @@
)
from antarest.matrixstore.service import MatrixService

MatrixType = t.List[t.List[float]]


class TestMatrixService:
def test_create__nominal_case(self, matrix_service: MatrixService):
def test_create__nominal_case(self, matrix_service: MatrixService) -> None:
"""Creates a new matrix object with the specified data."""
# when a matrix is created (inserted) in the service
data = [[1, 2, 3], [4, 5, 6]]
data: MatrixType = [[1, 2, 3], [4, 5, 6]]
matrix_id = matrix_service.create(data)

# A "real" hash value is calculated
Expand All @@ -52,7 +55,7 @@ def test_create__nominal_case(self, matrix_service: MatrixService):
now = datetime.datetime.utcnow()
assert now - datetime.timedelta(seconds=1) <= obj.created_at <= now

def test_create__from_numpy_array(self, matrix_service: MatrixService):
def test_create__from_numpy_array(self, matrix_service: MatrixService) -> None:
"""Creates a new matrix object with the specified data."""
# when a matrix is created (inserted) in the service
data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
Expand All @@ -76,13 +79,13 @@ def test_create__from_numpy_array(self, matrix_service: MatrixService):
now = datetime.datetime.utcnow()
assert now - datetime.timedelta(seconds=1) <= obj.created_at <= now

def test_create__side_effect(self, matrix_service: MatrixService):
def test_create__side_effect(self, matrix_service: MatrixService) -> None:
"""Creates a new matrix object with the specified data, but fail during saving."""
# if the matrix can't be created in the service
matrix_repo = matrix_service.repo
matrix_repo.save = Mock(side_effect=Exception("database error"))
with pytest.raises(Exception, match="database error"):
data = [[1, 2, 3], [4, 5, 6]]
data: MatrixType = [[1, 2, 3], [4, 5, 6]]
matrix_service.create(data)

# the associated matrix file must not be deleted
Expand All @@ -94,10 +97,10 @@ def test_create__side_effect(self, matrix_service: MatrixService):
with db():
assert not db.session.query(Matrix).count()

def test_get(self, matrix_service):
def test_get(self, matrix_service: MatrixService) -> None:
"""Get a matrix object from the database and the matrix content repository."""
# when a matrix is created (inserted) in the service
data = [[1, 2, 3], [4, 5, 6]]
data: MatrixType = [[1, 2, 3], [4, 5, 6]]
matrix_id = matrix_service.create(data)

# nominal_case: we can retrieve the matrix and its content
Expand All @@ -120,10 +123,10 @@ def test_get(self, matrix_service):
obj = matrix_service.get(missing_hash)
assert obj is None

def test_exists(self, matrix_service):
def test_exists(self, matrix_service: MatrixService) -> None:
"""Test the exists method."""
# when a matrix is created (inserted) in the service
data = [[1, 2, 3], [4, 5, 6]]
data: MatrixType = [[1, 2, 3], [4, 5, 6]]
matrix_id = matrix_service.create(data)

# nominal_case: we can retrieve the matrix and its content
Expand All @@ -132,10 +135,10 @@ def test_exists(self, matrix_service):
missing_hash = "8b1a9953c4611296a827abf8c47804d7e6c49c6b"
assert not matrix_service.exists(missing_hash)

def test_delete__nominal_case(self, matrix_service: MatrixService):
def test_delete__nominal_case(self, matrix_service: MatrixService) -> None:
"""Delete a matrix object from the matrix content repository and the database."""
# when a matrix is created (inserted) in the service
data = [[1, 2, 3], [4, 5, 6]]
data: MatrixType = [[1, 2, 3], [4, 5, 6]]
matrix_id = matrix_service.create(data)

# When the matrix is deleted
Expand All @@ -151,7 +154,7 @@ def test_delete__nominal_case(self, matrix_service: MatrixService):
with db():
assert not db.session.query(Matrix).count()

def test_delete__missing(self, matrix_service: MatrixService):
def test_delete__missing(self, matrix_service: MatrixService) -> None:
"""Delete a matrix object from the matrix content repository and the database."""
# When the matrix is deleted
with db():
Expand All @@ -167,8 +170,139 @@ def test_delete__missing(self, matrix_service: MatrixService):
with db():
assert not db.session.query(Matrix).count()

@pytest.mark.parametrize(
    "data",
    [
        pytest.param([[1, 2, 3], [4, 5, 6]], id="classic-array"),
        pytest.param([[]], id="2D-empty-array"),
    ],
)
@pytest.mark.parametrize("content_type", ["application/json", "text/plain"])
def test_create_by_importation__nominal_case(
    self,
    matrix_service: MatrixService,
    data: MatrixType,
    content_type: str,
) -> None:
    """
    Create a new matrix by importing a single file.

    The file is either a JSON file (dataframe-like `index`/`columns`/`data`
    document) or a tab-separated text file without header.
    The test checks that exactly one matrix is created, that its content is
    stored in the content repository and that its metadata are in the database.
    """
    # Prepare the matrix data to import
    matrix = np.array(data, dtype=np.float64)
    if content_type == "application/json":
        # JSON format of the array using the dataframe format
        index = list(range(matrix.shape[0]))
        columns = list(range(matrix.shape[1]))
        content = json.dumps({"index": index, "columns": columns, "data": matrix.tolist()})
        buffer = io.BytesIO(content.encode("utf-8"))
        filename = "matrix.json"
        json_format = True
    else:
        # CSV format of the array (without header)
        buffer = io.BytesIO()
        np.savetxt(buffer, matrix, delimiter="\t")
        buffer.seek(0)
        filename = "matrix.txt"
        json_format = False

    # Prepare a UploadFile object using the buffer
    upload_file = _create_upload_file(filename=filename, file=buffer, content_type=content_type)

    # when a matrix is created (inserted) in the service
    info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format)

    # Then, check the list of created matrices
    assert len(info_list) == 1
    info = info_list[0]

    # A "real" hash value is calculated
    assert info.id, "ID can't be empty"

    # The matrix is saved in the content repository as a TSV file
    bucket_dir = matrix_service.matrix_content_repository.bucket_dir
    content_path = bucket_dir.joinpath(f"{info.id}.tsv")
    # `ndmin=2` keeps a 2D shape, so the element-wise comparison below also checks the shape.
    actual = np.loadtxt(content_path, ndmin=2)
    if matrix.size:
        # Compare the values: `actual.all() == matrix.all()` would only compare two booleans.
        assert np.array_equal(actual, matrix)
    else:
        # An empty matrix has no value to compare: only check that nothing was stored.
        assert actual.size == 0

    # A matrix object is stored in the database
    with db():
        obj = matrix_service.repo.get(info.id)
    assert obj is not None, f"Missing Matrix object {info.id}"
    assert obj.width == matrix.shape[1]
    assert obj.height == matrix.shape[0]
    now = datetime.datetime.utcnow()
    assert now - datetime.timedelta(seconds=1) <= obj.created_at <= now

@pytest.mark.parametrize("content_type", ["application/json", "text/plain"])
def test_create_by_importation__zip_file(self, matrix_service: MatrixService, content_type: str) -> None:
    """
    Create a ZIP file with several matrices, using either a JSON format or a CSV format.
    All matrices of the ZIP file use the same format.
    Check that the matrices are correctly imported.
    """
    # Prepare the matrix data to import
    data_list: t.List[MatrixType] = [
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9, 10, 11], [17, 18, 19, 20, 21], [27, 28, 29, 30, 31]],
        [[]],
    ]
    matrix_list: t.List[np.ndarray] = [np.array(data, dtype=np.float64) for data in data_list]
    if content_type == "application/json":
        # JSON format of the array using the dataframe format
        index_list = [list(range(matrix.shape[0])) for matrix in matrix_list]
        columns_list = [list(range(matrix.shape[1])) for matrix in matrix_list]
        data_list = [matrix.tolist() for matrix in matrix_list]
        content_list = [
            json.dumps({"index": index, "columns": columns, "data": data}).encode("utf-8")
            for index, columns, data in zip(index_list, columns_list, data_list)
        ]
        json_format = True
    else:
        # CSV format of the array (without header)
        content_list = []
        for matrix in matrix_list:
            buffer = io.BytesIO()
            np.savetxt(buffer, matrix, delimiter="\t")
            content_list.append(buffer.getvalue())
        json_format = False

    # All members of the archive share the same extension
    suffix = "json" if json_format else "txt"
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        for i, content in enumerate(content_list):
            zf.writestr(f"matrix-{i:1d}.{suffix}", content)
    buffer.seek(0)

    # Prepare a UploadFile object using the buffer
    upload_file = _create_upload_file(filename="matrices.zip", file=buffer, content_type="application/zip")

    # When matrices are created (inserted) in the service
    info_list: t.Sequence[MatrixInfoDTO] = matrix_service.create_by_importation(upload_file, json=json_format)

    # Then, check the list of created matrices
    assert len(info_list) == len(data_list)
    for info, matrix in zip(info_list, matrix_list):
        # A "real" hash value is calculated
        assert info.id, "ID can't be empty"

        # The matrix is saved in the content repository as a TSV file
        bucket_dir = matrix_service.matrix_content_repository.bucket_dir
        content_path = bucket_dir.joinpath(f"{info.id}.tsv")
        # `ndmin=2` keeps a 2D shape, so the element-wise comparison below also checks the shape.
        actual = np.loadtxt(content_path, ndmin=2)
        if matrix.size:
            # Compare the values: `actual.all() == matrix.all()` would only compare two booleans.
            assert np.array_equal(actual, matrix)
        else:
            # An empty matrix has no value to compare: only check that nothing was stored.
            assert actual.size == 0

        # A matrix object is stored in the database
        with db():
            obj = matrix_service.repo.get(info.id)
        assert obj is not None, f"Missing Matrix object {info.id}"
        assert obj.width == (matrix.shape[1] if matrix.size else 0)
        assert obj.height == matrix.shape[0]
        now = datetime.datetime.utcnow()
        assert now - datetime.timedelta(seconds=1) <= obj.created_at <= now


def test_dataset_lifecycle() -> None:
content = Mock()
repo = Mock()
dataset_repo = Mock()
Expand Down Expand Up @@ -347,7 +481,7 @@ def test_dataset_lifecycle():
dataset_repo.delete.assert_called_once()


def _create_upload_file(filename: str, file: t.IO = None, content_type: str = "") -> UploadFile:
def _create_upload_file(filename: str, file: io.BytesIO, content_type: str = "") -> UploadFile:
if hasattr(UploadFile, "content_type"):
# `content_type` attribute was replaced by a read-only property in starlette-v0.24.
headers = Headers(headers={"content-type": content_type})
Expand All @@ -356,54 +490,3 @@ def _create_upload_file(filename: str, file: t.IO = None, content_type: str = ""
else:
# noinspection PyTypeChecker,PyArgumentList
return UploadFile(filename=filename, file=file, content_type=content_type)


def test_import() -> None:
    """
    Import matrices with `MatrixService.create_by_importation` using mocked repositories.

    Two uploads are checked: a plain tab-separated text file, then a ZIP archive
    containing that same file.
    """
    # Init Mock
    repo_content = Mock()
    repo = Mock()

    file_str = "1\t2\t3\t4\t5\n6\t7\t8\t9\t10"
    matrix_content = file_str.encode()

    # Expected
    matrix_id = "123"
    exp_matrix_info = [MatrixInfoDTO(id=matrix_id, name="matrix.txt")]
    exp_matrix = Matrix(id=matrix_id, width=5, height=2)

    # Test
    service = MatrixService(
        repo=repo,
        repo_dataset=Mock(),
        matrix_content_repository=repo_content,
        file_transfer_manager=Mock(),
        task_service=Mock(),
        config=Mock(),
        user_service=Mock(),
    )
    # The repositories are mocks: no matrix exists yet and every save yields `matrix_id`
    service.repo.get.return_value = None
    service.matrix_content_repository.save.return_value = matrix_id
    service.repo.save.return_value = exp_matrix

    # CSV importation
    matrix_file = _create_upload_file(
        filename="matrix.txt",
        file=io.BytesIO(matrix_content),
        content_type="text/plain",  # was misspelled "test/plain"
    )
    matrix = service.create_by_importation(matrix_file)
    assert matrix[0].name == exp_matrix_info[0].name
    assert matrix[0].id is not None

    # Zip importation
    zip_content = io.BytesIO()
    with ZipFile(zip_content, "w", ZIP_DEFLATED) as output_data:
        output_data.writestr("matrix.txt", file_str)

    # Rewind the buffer so that the upload is read from the beginning
    zip_content.seek(0)
    zip_file = _create_upload_file(
        filename="Matrix.zip",
        file=zip_content,
        content_type="application/zip",
    )
    matrix = service.create_by_importation(zip_file)
    assert matrix == exp_matrix_info
Loading

0 comments on commit c156c3b

Please sign in to comment.