Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Oct 31, 2023
1 parent 8a0c099 commit 78dc97f
Showing 1 changed file with 52 additions and 1 deletion.
53 changes: 52 additions & 1 deletion tests/server/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,67 @@
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
from typing import Any

import numpy as np
import yaml
from pydantic import BaseModel

import pytest
from deepsparse.loggers import AsyncLogger, MultiLogger, PythonLogger
from deepsparse.server.config import ServerConfig
from deepsparse.server.helpers import server_logger_from_config
from deepsparse.server.helpers import (
prep_outputs_for_serialization,
server_logger_from_config,
)
from tests.deepsparse.loggers.helpers import fetch_leaf_logger
from tests.helpers import find_free_port


class DummyOutputSchema(BaseModel):
field_1: Any
field_2: Any
field_3: Any


@pytest.mark.parametrize(
"unserialized_output, target_serialized_output",
[
(
DummyOutputSchema(
field_1=[np.array([[1, 2, 3]])],
field_2={"key_1": np.array([[[1, 2, 3]]])},
field_3=DummyOutputSchema(field_1=np.array([0])),
),
DummyOutputSchema(
field_1=[[[1, 2, 3]]],
field_2={"key_1": [[[1, 2, 3]]]},
field_3=DummyOutputSchema(field_1=[0]),
),
)
],
)
def test_prep_outputs_for_serialization(unserialized_output, target_serialized_output):
def check_dict_equality(dict_1, dict_2):
for key, value in dict_1.items():
if isinstance(value, BaseModel):
value = value.dict()
check_dict_equality(value, dict_2[key].dict())
elif isinstance(value, dict):
check_dict_equality(value, dict_2[key])
elif isinstance(value, list):
equal = value == dict_2[key]
equal = equal if isinstance(equal, bool) else equal.all()
assert equal
else:
assert value == dict_2[key]

serialized_output = prep_outputs_for_serialization(unserialized_output)
serialized_output = serialized_output.dict()
target_serialized_output = target_serialized_output.dict()
check_dict_equality(target_serialized_output, serialized_output)


yaml_config_1 = """
loggers:
python:
Expand Down

0 comments on commit 78dc97f

Please sign in to comment.