From 973251adb9a01c83ff2b7bea70127e1d539b2a97 Mon Sep 17 00:00:00 2001 From: Adrian D'Alessandro Date: Wed, 2 Aug 2023 18:41:15 +0100 Subject: [PATCH 1/2] Use orjson to encode numpy array in DSR GET method --- datahub/data.py | 3 ++- datahub/dsr.py | 27 +++++++++++++++++++++++++-- datahub/main.py | 15 ++++++--------- pyproject.toml | 1 + requirements-dev.txt | 2 ++ requirements.txt | 2 ++ tests/conftest.py | 8 +++----- tests/test_dsr_api.py | 31 +++++++++++++++++++------------ 8 files changed, 60 insertions(+), 29 deletions(-) diff --git a/datahub/data.py b/datahub/data.py index 68fa715..ea0b2b2 100644 --- a/datahub/data.py +++ b/datahub/data.py @@ -1,8 +1,9 @@ """This module defines the data structures for each of the models.""" import pandas as pd +from numpy.typing import NDArray from .opal import create_opal_frame opal_df: pd.DataFrame = create_opal_frame() -dsr_data: list[dict[str, str | list]] = [] # type: ignore[type-arg] +dsr_data: list[dict[str, NDArray | str]] = [] # type: ignore[type-arg] wesim_data: dict[str, dict] = {} # type: ignore[type-arg] diff --git a/datahub/dsr.py b/datahub/dsr.py index e8a615a..eebc691 100644 --- a/datahub/dsr.py +++ b/datahub/dsr.py @@ -1,4 +1,7 @@ """This module defines the data structures for the MEDUSA Demand Simulator model.""" +from typing import BinaryIO + +import h5py # type: ignore import numpy as np from fastapi import HTTPException from numpy.typing import NDArray @@ -35,7 +38,7 @@ class Config: allow_population_by_field_name = True -def validate_dsr_data(data: dict[str, NDArray]) -> None: +def validate_dsr_data(data: dict[str, NDArray | str]) -> None: """Validate the shapes of the arrays in the DSR data. Args: @@ -62,7 +65,7 @@ def validate_dsr_data(data: dict[str, NDArray]) -> None: if field: aliases.append(alias) continue - if field["type"] == "array": + if field["type"] == "array" and not isinstance(array, str): if array.shape != field["shape"] or not np.issubdtype( array.dtype, np.number ): @@ -72,3 +75,23 @@ def validate_dsr_data(data: dict[str, NDArray]) -> None: status_code=422, detail=f"Invalid size for: {', '.join(aliases)}.", ) + + +def read_dsr_file(file: BinaryIO) -> dict[str, NDArray | str]: + """Reads the HDF5 file that contains the DSR data into an in-memory dictionary. + + Args: + file (BinaryIO): A binary file-like object referencing the HDF5 file + + Returns: + The dictionary representation of the DSR Data. + """ + with h5py.File(file, "r") as h5file: + data = { + key: ( + value[...] if key not in ["Name", "Warn"] else str(value.asstr()[...]) + ) + for key, value in h5file.items() + } + + return data diff --git a/datahub/main.py b/datahub/main.py index 0993d37..9318def 100644 --- a/datahub/main.py +++ b/datahub/main.py @@ -1,10 +1,10 @@ """Script for running Datahub API.""" -import h5py # type: ignore from fastapi import FastAPI, HTTPException, UploadFile +from fastapi.responses import ORJSONResponse from . import data as dt from . import log -from .dsr import validate_dsr_data +from .dsr import read_dsr_file, validate_dsr_data from .opal import OpalArrayData, OpalModel from .wesim import get_wesim @@ -142,8 +142,7 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]: dict[str, str]: dictionary with the filename """ # noqa: D301 log.info("Recieved Opal data.") - with h5py.File(file.file, "r") as h5file: - data = {key: value[...] for key, value in h5file.items()} + data = read_dsr_file(file.file) validate_dsr_data(data) @@ -155,10 +154,8 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]: return {"filename": file.filename} -@app.get("/dsr") -def get_dsr_data( - start: int = 0, end: int | None = None -) -> dict[str, list]: # type: ignore[type-arg] +@app.get("/dsr", response_class=ORJSONResponse) +def get_dsr_data(start: int = 0, end: int | None = None) -> ORJSONResponse: """GET method function for getting DSR data as JSON. It takes optional query parameters of: @@ -193,7 +190,7 @@ def get_dsr_data( filtered_data = dt.dsr_data[start : end + 1 if end else end] log.debug(f"Filtered DSR data length:\n\n{len(dt.dsr_data)}") - return {"data": filtered_data} + return ORJSONResponse({"data": filtered_data}) @app.get("/wesim") diff --git a/pyproject.toml b/pyproject.toml index 0936a13..16b4a11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "uvicorn", "python-multipart", "h5py", + "orjson", ] [project.optional-dependencies] diff --git a/requirements-dev.txt b/requirements-dev.txt index 7667fb2..ec602ec 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -80,6 +80,8 @@ odfpy==1.4.1 # via pandas openpyxl==3.1.2 # via pandas +orjson==3.9.2 + # via datahub (pyproject.toml) packaging==23.1 # via # black diff --git a/requirements.txt b/requirements.txt index 882fdaa..00e139b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,6 +30,8 @@ odfpy==1.4.1 # via pandas openpyxl==3.1.2 # via pandas +orjson==3.9.2 + # via datahub (pyproject.toml) pandas[excel]==2.0.3 # via datahub (pyproject.toml) pydantic==1.10.10 diff --git a/tests/conftest.py b/tests/conftest.py index 0d41290..7a7ff4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import pytest from fastapi.testclient import TestClient -from datahub.dsr import DSRModel +from datahub.dsr import DSRModel, read_dsr_file from datahub.main import app from datahub.opal import opal_headers @@ -41,9 +41,7 @@ def opal_data_array(): @pytest.fixture def dsr_data(dsr_data_path): """Pytest Fixture for DSR data as a dictionary.""" - with h5py.File(dsr_data_path, "r") as h5file: - data = {key: value[...] for key, value in h5file.items()} - return data + return read_dsr_file(dsr_data_path) @pytest.fixture @@ -65,7 +63,7 @@ def dsr_data_path(tmp_path): else: h5file[field.alias] = np.random.rand( *field.field_info.extra["shape"] - ).astype("float16") + ).astype("float32") # Return the path to the file return file_path diff --git a/tests/test_dsr_api.py b/tests/test_dsr_api.py index 1b2ca87..2874532 100644 --- a/tests/test_dsr_api.py +++ b/tests/test_dsr_api.py @@ -54,18 +54,25 @@ def test_post_dsr_api_invalid(dsr_data_path): assert len(dt.dsr_data) == 0 -def test_get_dsr_api(): +def test_get_dsr_api(dsr_data): """Tests DSR data GET method.""" - dt.dsr_data = [0, 1, 2, 3, 4, 5] + dt.dsr_data.append(dsr_data) response = client.get("/dsr") - assert response.json()["data"] == dt.dsr_data - - response = client.get("/dsr?start=2") - assert response.json()["data"] == dt.dsr_data[2:] - - response = client.get("/dsr?end=2") - assert response.json()["data"] == dt.dsr_data[:3] - - response = client.get("/dsr?start=1&end=2") - assert response.json()["data"] == dt.dsr_data[1:3] + response_data = response.json()["data"][0] + assert response_data.keys() == dsr_data.keys() + for (key, received), expected in zip( + response.json()["data"][0].items(), dt.dsr_data[0].values() + ): + if key in ["Name", "Warn"]: + assert received == expected + continue + assert np.allclose(received, expected) + + # Add another entry with changed data + new_data = dsr_data.copy() + new_data["Name"] = "A new entry" + dt.dsr_data.append(new_data) + + response = client.get("/dsr?start=1") + assert response.json()["data"][0]["Name"] == dt.dsr_data[1]["Name"] From e6701fe2d715411d48b6d68d8592c3d06635d4b4 Mon Sep 17 00:00:00 2001 From: Adrian D'Alessandro Date: Thu, 3 Aug 2023 14:33:21 +0100 Subject: [PATCH 2/2] Upgrade pip-tools for compatibility with new versions of pip --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index ec602ec..cb94635 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -93,7 +93,7 @@ pandas-stubs==2.0.2.230605 # via datahub (pyproject.toml) pathspec==0.11.2 # via black -pip-tools==6.14.0 +pip-tools==7.2.0 # via datahub (pyproject.toml) platformdirs==3.8.0 # via