diff --git a/src/deepsparse/server/helpers.py b/src/deepsparse/server/helpers.py index 0a86748af7..2163365500 100644 --- a/src/deepsparse/server/helpers.py +++ b/src/deepsparse/server/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from http import HTTPStatus -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy from pydantic import BaseModel @@ -75,23 +75,35 @@ def server_logger_from_config(config: ServerConfig) -> BaseLogger: ) -def prep_outputs_for_serialization(pipeline_outputs: Any): +def prep_outputs_for_serialization( + pipeline_outputs: Union[BaseModel, numpy.ndarray, list] +) -> Union[BaseModel, list]: """ Prepares a pipeline output for JSON serialization by converting any numpy array field to a list. For large numpy arrays, this operation will take a while to run. - :param pipeline_outputs: output data to clean - :return: cleaned pipeline_outputs + :param pipeline_outputs: output data that is to be processed before + serialization. Nested objects are supported.
+ :return: pipeline_outputs with potential numpy arrays + converted to lists """ if isinstance(pipeline_outputs, BaseModel): for field_name in pipeline_outputs.__fields__.keys(): field_value = getattr(pipeline_outputs, field_name) - if isinstance(field_value, numpy.ndarray): - # numpy arrays aren't JSON serializable - setattr(pipeline_outputs, field_name, field_value.tolist()) + if isinstance(field_value, (numpy.ndarray, BaseModel, list)): + setattr( + pipeline_outputs, + field_name, + prep_outputs_for_serialization(field_value), + ) + elif isinstance(pipeline_outputs, numpy.ndarray): pipeline_outputs = pipeline_outputs.tolist() + elif isinstance(pipeline_outputs, list): + for i, value in enumerate(pipeline_outputs): + pipeline_outputs[i] = prep_outputs_for_serialization(value) + return pipeline_outputs diff --git a/tests/server/test_helpers.py b/tests/server/test_helpers.py index 460c75fc83..289e8ba5a2 100644 --- a/tests/server/test_helpers.py +++ b/tests/server/test_helpers.py @@ -12,16 +12,67 @@ # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # # See the License for the specific language governing permissions and # # limitations under the License.
+from typing import Any + +import numpy as np import yaml +from pydantic import BaseModel import pytest from deepsparse.loggers import AsyncLogger, MultiLogger, PythonLogger from deepsparse.server.config import ServerConfig -from deepsparse.server.helpers import server_logger_from_config +from deepsparse.server.helpers import ( + prep_outputs_for_serialization, + server_logger_from_config, +) from tests.deepsparse.loggers.helpers import fetch_leaf_logger from tests.helpers import find_free_port +class DummyOutputSchema(BaseModel): + field_1: Any + field_2: Any + field_3: Any + + +@pytest.mark.parametrize( + "unserialized_output, target_serialized_output", + [ + ( + DummyOutputSchema( + field_1=[np.array([[1, 2, 3]])], + field_2={"key_1": np.array([[[1, 2, 3]]])}, + field_3=DummyOutputSchema(field_1=np.array([0])), + ), + DummyOutputSchema( + field_1=[[[1, 2, 3]]], + field_2={"key_1": [[[1, 2, 3]]]}, + field_3=DummyOutputSchema(field_1=[0]), + ), + ) + ], +) +def test_prep_outputs_for_serialization(unserialized_output, target_serialized_output): + def check_dict_equality(dict_1, dict_2): + for key, value in dict_1.items(): + if isinstance(value, BaseModel): + value = value.dict() + check_dict_equality(value, dict_2[key].dict()) + elif isinstance(value, dict): + check_dict_equality(value, dict_2[key]) + elif isinstance(value, list): + equal = value == dict_2[key] + equal = equal if isinstance(equal, bool) else equal.all() + assert equal + else: + assert value == dict_2[key] + + serialized_output = prep_outputs_for_serialization(unserialized_output) + serialized_output = serialized_output.dict() + target_serialized_output = target_serialized_output.dict() + check_dict_equality(target_serialized_output, serialized_output) + + yaml_config_1 = """ loggers: python: