diff --git a/datahub/main.py b/datahub/main.py index 2364fa6..c6f263c 100644 --- a/datahub/main.py +++ b/datahub/main.py @@ -1,4 +1,5 @@ """Script for running Datahub API.""" +import numpy as np from fastapi import FastAPI, HTTPException, UploadFile from fastapi.responses import ORJSONResponse @@ -208,19 +209,25 @@ def get_dsr_data( message = "One or more of the specified columns are invalid." log.error(message) raise HTTPException(status_code=400, detail=message) - - log.info("Filtering data by column...") - filtered_data = [] - for frame in filtered_index_data: - filtered_keys = {} - for key in frame.keys(): - if dsr_headers[key.title()] in columns: - filtered_keys[key] = frame[key] - filtered_data.append(filtered_keys) - - return ORJSONResponse({"data": filtered_data}) - - return ORJSONResponse({"data": filtered_index_data}) + else: + columns = list(dsr_headers.values()) + + log.info("Filtering data by column...") + filtered_data = [] + for frame in filtered_index_data: + filtered_keys = {} + for key, value in frame.items(): + if dsr_headers[key.title()] not in columns: + continue + elif not isinstance(value, str) and np.issubdtype( + value.dtype, np.character + ): + filtered_keys[key] = value.astype(str).tolist() + else: + filtered_keys[key] = value + filtered_data.append(filtered_keys) + + return ORJSONResponse({"data": filtered_data}) @app.get("/wesim") diff --git a/tests/conftest.py b/tests/conftest.py index 00f2075..daba1c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,7 +61,8 @@ def dsr_data_path(tmp_path): shape = field.field_info.extra["shape"] if shape[0] is None: shape = (10, shape[1]) - h5file[field.alias] = np.random.rand(*shape).astype("float32") + dtype = "|S13" if field.alias == "Activity Types" else "float32" + h5file[field.alias] = np.random.rand(*shape).astype(dtype) # Return the path to the file return file_path diff --git a/tests/test_dsr_api.py b/tests/test_dsr_api.py index 768f1d4..71a71d1 100644 --- a/tests/test_dsr_api.py +++ b/tests/test_dsr_api.py @@ -69,6 +69,9 @@ def test_get_dsr_api(dsr_data): if key in ["Name", "Warn"]: assert received == expected continue + if key in ["Activity Types"]: + assert received == expected.astype(str).tolist() + continue assert np.allclose(received, expected) # Add another entry with changed data