Skip to content

Commit

Permalink
Merge pull request #118 from ImperialCollegeLondon/fix-dsr-get
Browse files Browse the repository at this point in the history
Use orjson to encode numpy array in DSR GET method
  • Loading branch information
AdrianDAlessandro authored Aug 4, 2023
2 parents 5fce417 + e6701fe commit 81b2724
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 30 deletions.
3 changes: 2 additions & 1 deletion datahub/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""This module defines the data structures for each of the models."""
import pandas as pd
from numpy.typing import NDArray

from .opal import create_opal_frame

opal_df: pd.DataFrame = create_opal_frame()
dsr_data: list[dict[str, str | list]] = [] # type: ignore[type-arg]
dsr_data: list[dict[str, NDArray | str]] = [] # type: ignore[type-arg]
wesim_data: dict[str, dict] = {} # type: ignore[type-arg]
27 changes: 25 additions & 2 deletions datahub/dsr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""This module defines the data structures for the MEDUSA Demand Simulator model."""
from typing import BinaryIO

import h5py # type: ignore
import numpy as np
from fastapi import HTTPException
from numpy.typing import NDArray
Expand Down Expand Up @@ -35,7 +38,7 @@ class Config:
allow_population_by_field_name = True


def validate_dsr_data(data: dict[str, NDArray]) -> None:
def validate_dsr_data(data: dict[str, NDArray | str]) -> None:
"""Validate the shapes of the arrays in the DSR data.
Args:
Expand All @@ -62,7 +65,7 @@ def validate_dsr_data(data: dict[str, NDArray]) -> None:
if field:
aliases.append(alias)
continue
if field["type"] == "array":
if field["type"] == "array" and not isinstance(array, str):
if array.shape != field["shape"] or not np.issubdtype(
array.dtype, np.number
):
Expand All @@ -72,3 +75,23 @@ def validate_dsr_data(data: dict[str, NDArray]) -> None:
status_code=422,
detail=f"Invalid size for: {', '.join(aliases)}.",
)


def read_dsr_file(file: BinaryIO) -> dict[str, NDArray | str]:
    """Load the DSR data stored in an HDF5 file into an in-memory dictionary.

    Args:
        file (BinaryIO): A binary file-like object referencing the HDF5 file

    Returns:
        A dictionary keyed by the names of the top-level items in the file.
        The "Name" and "Warn" entries are converted to Python strings; every
        other entry is read in full with ``value[...]``.
    """
    # Keys whose datasets are read through ``asstr()`` and wrapped in ``str``
    # rather than being returned as raw array data.
    text_keys = ("Name", "Warn")

    contents: dict[str, NDArray | str] = {}
    with h5py.File(file, "r") as h5file:
        for name, dataset in h5file.items():
            if name in text_keys:
                contents[name] = str(dataset.asstr()[...])
            else:
                # Read inside the ``with`` block so the data is materialised
                # before the file handle is closed.
                contents[name] = dataset[...]

    return contents
15 changes: 6 additions & 9 deletions datahub/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Script for running Datahub API."""
import h5py # type: ignore
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse

from . import data as dt
from . import log
from .dsr import validate_dsr_data
from .dsr import read_dsr_file, validate_dsr_data
from .opal import OpalArrayData, OpalModel
from .wesim import get_wesim

Expand Down Expand Up @@ -142,8 +142,7 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]:
dict[str, str]: dictionary with the filename
""" # noqa: D301
log.info("Recieved Opal data.")
with h5py.File(file.file, "r") as h5file:
data = {key: value[...] for key, value in h5file.items()}
data = read_dsr_file(file.file)

validate_dsr_data(data)

Expand All @@ -155,10 +154,8 @@ def upload_dsr(file: UploadFile) -> dict[str, str | None]:
return {"filename": file.filename}


@app.get("/dsr")
def get_dsr_data(
start: int = 0, end: int | None = None
) -> dict[str, list]: # type: ignore[type-arg]
@app.get("/dsr", response_class=ORJSONResponse)
def get_dsr_data(start: int = 0, end: int | None = None) -> ORJSONResponse:
"""GET method function for getting DSR data as JSON.
It takes optional query parameters of:
Expand Down Expand Up @@ -193,7 +190,7 @@ def get_dsr_data(
filtered_data = dt.dsr_data[start : end + 1 if end else end]
log.debug(f"Filtered DSR data length:\n\n{len(dt.dsr_data)}")

return {"data": filtered_data}
return ORJSONResponse({"data": filtered_data})


@app.get("/wesim")
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"uvicorn",
"python-multipart",
"h5py",
"orjson",
]

[project.optional-dependencies]
Expand Down
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ odfpy==1.4.1
# via pandas
openpyxl==3.1.2
# via pandas
orjson==3.9.2
# via datahub (pyproject.toml)
packaging==23.1
# via
# black
Expand All @@ -91,7 +93,7 @@ pandas-stubs==2.0.2.230605
# via datahub (pyproject.toml)
pathspec==0.11.2
# via black
pip-tools==6.14.0
pip-tools==7.2.0
# via datahub (pyproject.toml)
platformdirs==3.8.0
# via
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ odfpy==1.4.1
# via pandas
openpyxl==3.1.2
# via pandas
orjson==3.9.2
# via datahub (pyproject.toml)
pandas[excel]==2.0.3
# via datahub (pyproject.toml)
pydantic==1.10.10
Expand Down
8 changes: 3 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from fastapi.testclient import TestClient

from datahub.dsr import DSRModel
from datahub.dsr import DSRModel, read_dsr_file
from datahub.main import app
from datahub.opal import opal_headers

Expand Down Expand Up @@ -41,9 +41,7 @@ def opal_data_array():
@pytest.fixture
def dsr_data(dsr_data_path):
"""Pytest Fixture for DSR data as a dictionary."""
with h5py.File(dsr_data_path, "r") as h5file:
data = {key: value[...] for key, value in h5file.items()}
return data
return read_dsr_file(dsr_data_path)


@pytest.fixture
Expand All @@ -65,7 +63,7 @@ def dsr_data_path(tmp_path):
else:
h5file[field.alias] = np.random.rand(
*field.field_info.extra["shape"]
).astype("float16")
).astype("float32")

# Return the path to the file
return file_path
Expand Down
31 changes: 19 additions & 12 deletions tests/test_dsr_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,25 @@ def test_post_dsr_api_invalid(dsr_data_path):
assert len(dt.dsr_data) == 0


def test_get_dsr_api():
def test_get_dsr_api(dsr_data):
"""Tests DSR data GET method."""
dt.dsr_data = [0, 1, 2, 3, 4, 5]
dt.dsr_data.append(dsr_data)

response = client.get("/dsr")
assert response.json()["data"] == dt.dsr_data

response = client.get("/dsr?start=2")
assert response.json()["data"] == dt.dsr_data[2:]

response = client.get("/dsr?end=2")
assert response.json()["data"] == dt.dsr_data[:3]

response = client.get("/dsr?start=1&end=2")
assert response.json()["data"] == dt.dsr_data[1:3]
response_data = response.json()["data"][0]
assert response_data.keys() == dsr_data.keys()
for (key, received), expected in zip(
response.json()["data"][0].items(), dt.dsr_data[0].values()
):
if key in ["Name", "Warn"]:
assert received == expected
continue
assert np.allclose(received, expected)

# Add another entry with changed data
new_data = dsr_data.copy()
new_data["Name"] = "A new entry"
dt.dsr_data.append(new_data)

response = client.get("/dsr?start=1")
assert response.json()["data"][0]["Name"] == dt.dsr_data[1]["Name"]

0 comments on commit 81b2724

Please sign in to comment.