diff --git a/python/kserve/kserve/model_server.py b/python/kserve/kserve/model_server.py index fbbfd5c9a2..4670c1e6a2 100644 --- a/python/kserve/kserve/model_server.py +++ b/python/kserve/kserve/model_server.py @@ -163,7 +163,7 @@ def __init__( workers: int = args.workers, max_threads: int = args.max_threads, max_asyncio_workers: int = args.max_asyncio_workers, - registered_models: ModelRepository = ModelRepository(), + registered_models: ModelRepository = None, enable_grpc: bool = args.enable_grpc, enable_docs_url: bool = args.enable_docs_url, enable_latency_logging: bool = args.enable_latency_logging, @@ -188,7 +188,9 @@ def __init__( (please refer to this Uvicorn [github issue](https://github.com/encode/uvicorn/issues/527) for more info). """ - self.registered_models = registered_models + self.registered_models = ( + ModelRepository() if registered_models is None else registered_models + ) self.http_port = http_port self.grpc_port = grpc_port self.workers = workers @@ -197,7 +199,7 @@ def __init__( self.enable_grpc = enable_grpc self.enable_docs_url = enable_docs_url self.enable_latency_logging = enable_latency_logging - self.dataplane = DataPlane(model_registry=registered_models) + self.dataplane = DataPlane(model_registry=self.registered_models) self.model_repository_extension = ModelRepositoryExtension( model_registry=self.registered_models ) diff --git a/python/kserve/kserve/protocol/grpc/servicer.py b/python/kserve/kserve/protocol/grpc/servicer.py index 25cc64fdb9..c2455803f2 100644 --- a/python/kserve/kserve/protocol/grpc/servicer.py +++ b/python/kserve/kserve/protocol/grpc/servicer.py @@ -22,6 +22,8 @@ from grpc import ServicerContext +from ...errors import InvalidInput + class InferenceServicer(grpc_predict_v2_pb2_grpc.GRPCInferenceServiceServicer): @@ -34,6 +36,21 @@ def __init__( self._data_plane = data_plane self._mode_repository_extension = model_repository_extension + @classmethod + def validate_grpc_request(cls, request: pb.ModelInferRequest): + raw_inputs_length = len(request.raw_input_contents) + if raw_inputs_length != 0 and len(request.inputs) != raw_inputs_length: + raise InvalidInput( + f"the number of inputs ({len(request.inputs)}) does not match the expected number of " + f"raw input contents ({raw_inputs_length}) for model '{request.model_name}'." 
+ ) + if raw_inputs_length != 0: + for input_ in request.inputs: + if input_.HasField("contents"): + raise InvalidInput( + f"contents field must not be specified when using raw_input_contents for input '{input_.name}' for model '{request.model_name}'" + ) + async def ServerMetadata(self, request: pb.ServerMetadataRequest, context): metadata = self._data_plane.metadata() return pb.ServerMetadataResponse( @@ -96,6 +113,7 @@ async def ModelInfer( self, request: pb.ModelInferRequest, context: ServicerContext ) -> pb.ModelInferResponse: headers = to_headers(context) + self.validate_grpc_request(request) infer_request = InferRequest.from_grpc(request) response_body, _ = await self._data_plane.infer( request=infer_request, headers=headers, model_name=request.model_name diff --git a/python/kserve/kserve/protocol/infer_type.py b/python/kserve/kserve/protocol/infer_type.py index 7d716f3839..4fa79216a9 100644 --- a/python/kserve/kserve/protocol/infer_type.py +++ b/python/kserve/kserve/protocol/infer_type.py @@ -15,7 +15,6 @@ import struct from typing import Optional, List, Dict, Union -import numpy import numpy as np import pandas as pd import uuid @@ -33,42 +32,81 @@ from ..utils.numpy_codec import to_np_dtype, from_np_dtype -def serialize_byte_tensor(input_tensor: numpy.ndarray): +def serialize_byte_tensor(input_tensor: np.ndarray) -> np.ndarray: """ Serializes a bytes tensor into a flat numpy array of length prepended - bytes. The numpy array should use dtype of np.object_. For np.bytes_, + bytes. The numpy array should use dtype of np.object. For np.bytes, numpy will remove trailing zeros at the end of byte sequence and because of this it should be avoided. + Args: - input_tensor : np.array of the bytes tensor to serialize. + input_tensor : np.array + The bytes tensor to serialize. Returns: - serialized_bytes_tensor : The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order. + serialized_bytes_tensor : np.array + The 1-D numpy array of type uint8 containing the serialized bytes in row-major form. + Raises: + InferenceError If unable to serialize the given tensor. """ if input_tensor.size == 0: - return () - - # If the input is a tensor of string/bytes objects, then must flatten those - # into a 1-dimensional array containing the 4-byte byte size followed by the - # actual element bytes. All elements are concatenated together in "C" order. - if (input_tensor.dtype == np.object_) or (input_tensor.dtype.type == np.bytes_): - flattened_ls = [] - for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"): - # If directly passing bytes to BYTES type, - # don't convert it to str as Python will encode the - # bytes which may distort the meaning - if input_tensor.dtype == np.object_: - if type(obj.item()) == bytes: - s = obj.item() - else: - s = str(obj.item()).encode("utf-8") - else: + return np.empty([0], dtype=np.object_) + + # If the input is a tensor of string/bytes objects, then must flatten those into + # a 1-dimensional array containing the 4-byte byte size followed by the + # actual element bytes. All elements are concatenated together in row-major + # order. + + if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_): + raise InferenceError("cannot serialize bytes tensor: invalid datatype") + + flattened_ls = [] + # 'C' order is row-major. 
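+    # refs_ok lets np.nditer walk an object-dtype array, whose elements are
+    # Python bytes/str references rather than fixed-size numeric scalars.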
+ for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"): + # If directly passing bytes to BYTES type, + # don't convert it to str as Python will encode the + # bytes which may distort the meaning + if input_tensor.dtype == np.object_: + if type(obj.item()) == bytes: s = obj.item() - flattened_ls.append(struct.pack(" np.ndarray: + """ + Deserializes an encoded bytes tensor into a + numpy array of dtype of python objects + + Args: + encoded_tensor : bytes + The encoded bytes tensor where each element + has its length in first 4 bytes followed by + the content + Returns: + string_tensor : np.array + The 1-D numpy array of type object containing the + deserialized bytes in row-major form. + """ + strs = list() + offset = 0 + val_buf = encoded_tensor + while offset < len(val_buf): + length = struct.unpack_from(" np.ndarray: if dtype is None: raise InvalidInput(f"invalid datatype {dtype} in the input") if self._raw_data is not None: - np_array = np.frombuffer(self._raw_data, dtype=dtype) + if self.datatype == "BYTES": + # String results contain a 4-byte string length + # followed by the actual string characters. Hence, + # need to decode the raw bytes to convert into + # array elements. + np_array = deserialize_bytes_tensor(self._raw_data) + else: + np_array = np.frombuffer(self._raw_data, dtype=dtype) return np_array.reshape(self._shape) else: np_array = np.array(self._data, dtype=dtype) @@ -286,6 +331,9 @@ def get_content(datatype: str, data: InferTensorContents): return list(data.int_contents) elif datatype == "INT64": return list(data.int64_contents) + elif datatype == "FP16": + # FP16 data should be present in raw_input_content, so return an empty list. + return list() elif datatype == "FP32": return list(data.fp32_contents) elif datatype == "FP64": @@ -328,10 +376,20 @@ def __init__( self.inputs = infer_inputs self.parameters = parameters self.from_grpc = from_grpc + self._use_raw_outputs = False if raw_inputs: + self._use_raw_outputs = True for i, raw_input in enumerate(raw_inputs): self.inputs[i]._raw_data = raw_input + @property + def use_binary_outputs(self) -> bool: + """Whether to use binary raw outputs + Returns: + a boolean indicating whether to use binary raw outputs + """ + return self._use_raw_outputs + @classmethod def from_grpc(cls, request: ModelInferRequest): """The class method to construct the InferRequest from a ModelInferRequest""" @@ -363,7 +421,7 @@ def to_rest(self) -> Dict: infer_inputs = [] for infer_input in self.inputs: datatype = infer_input.datatype - if isinstance(infer_input.datatype, numpy.dtype): + if isinstance(infer_input.datatype, np.dtype): datatype = from_np_dtype(infer_input.datatype) infer_input_dict = { "name": infer_input.name, @@ -374,7 +432,7 @@ def to_rest(self) -> Dict: infer_input_dict["parameters"] = to_http_parameters( infer_input.parameters ) - if isinstance(infer_input.data, numpy.ndarray): + if isinstance(infer_input.data, np.ndarray): infer_input.set_data_from_numpy(infer_input.data, binary_data=False) infer_input_dict["data"] = infer_input.data else: @@ -397,7 +455,7 @@ def to_grpc(self) -> ModelInferRequest: infer_inputs = [] raw_input_contents = [] for infer_input in self.inputs: - if isinstance(infer_input.data, numpy.ndarray): + if isinstance(infer_input.data, np.ndarray): infer_input.set_data_from_numpy(infer_input.data, binary_data=True) infer_input_dict = { "name": infer_input.name, @@ -452,6 +510,21 @@ def as_dataframe(self) -> pd.DataFrame: dfs.append(pd.DataFrame(input_data, columns=[input.name])) return 
pd.concat(dfs, axis=1) + def get_input_by_name(self, name: str) -> Optional[InferInput]: + """Find an input Tensor in the InferenceRequest that has the given name + Args: + name : str + name of the input Tensor object + Returns: + InferInput + The InferInput with the specified name, or None if no + input with this name exists + """ + for infer_input in self.inputs: + if name == infer_input.name: + return infer_input + return None + def __eq__(self, other): if not isinstance(other, InferRequest): return False @@ -522,6 +595,15 @@ def data(self) -> Union[List, np.ndarray, InferTensorContents]: """ return self._data + @data.setter + def data(self, data: Union[List, np.ndarray, InferTensorContents]): + """Set the data of inference output associated with this object. + + Args: + data: inference output data. + """ + self._data = data + @property def shape(self) -> List[int]: """Get the shape of inference output associated with this object. @@ -552,7 +634,7 @@ def set_shape(self, shape: List[int]): """ self._shape = shape - def as_numpy(self) -> numpy.ndarray: + def as_numpy(self) -> np.ndarray: """Decode the tensor output data as numpy array. Returns: @@ -562,7 +644,14 @@ def as_numpy(self) -> numpy.ndarray: if dtype is None: raise InvalidInput("invalid datatype in the input") if self._raw_data is not None: - np_array = np.frombuffer(self._raw_data, dtype=dtype) + if self.datatype == "BYTES": + # String results contain a 4-byte string length + # followed by the actual string characters. Hence, + # need to decode the raw bytes to convert into + # array elements. + np_array = deserialize_bytes_tensor(self._raw_data) + else: + np_array = np.frombuffer(self._raw_data, dtype=dtype) return np_array.reshape(self._shape) else: np_array = np.array(self._data, dtype=dtype) @@ -765,7 +854,7 @@ def to_rest(self) -> Dict: infer_output_dict["parameters"] = to_http_parameters( infer_output.parameters ) - if isinstance(infer_output.data, numpy.ndarray): + if isinstance(infer_output.data, np.ndarray): infer_output.set_data_from_numpy(infer_output.data, binary_data=False) infer_output_dict["data"] = infer_output.data elif isinstance(infer_output._raw_data, bytes): @@ -791,8 +880,18 @@ def to_grpc(self) -> ModelInferResponse: """ infer_outputs = [] raw_output_contents = [] + use_raw_outputs = False + # If FP16 datatype is present in the outputs use raw outputs. 
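+        # InferTensorContents has no FP16 field, so FP16 tensors can only be
+        # carried as raw bytes in raw_output_contents.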
+ if _contains_fp16_datatype(self): + use_raw_outputs = True for infer_output in self.outputs: - if isinstance(infer_output.data, numpy.ndarray): + if ( + use_raw_outputs + and infer_output.data + and isinstance(infer_output.data, list) + ): + infer_output.data = infer_output.as_numpy() + if isinstance(infer_output.data, np.ndarray): infer_output.set_data_from_numpy(infer_output.data, binary_data=True) infer_output_dict = { "name": infer_output.name, @@ -831,6 +930,22 @@ def to_grpc(self) -> ModelInferResponse: parameters=to_grpc_parameters(self.parameters) if self.parameters else None, ) + def get_output_by_name(self, name: str) -> Optional[InferOutput]: + """Find an output Tensor in the InferResponse that has the given name + + Args: + name : str + name of the output Tensor object + Returns: + InferOutput + The InferOutput with the specified name, or None if no + output with this name exists + """ + for infer_output in self.outputs: + if name == infer_output.name: + return infer_output + return None + def __eq__(self, other): if not isinstance(other, InferResponse): return False @@ -895,3 +1010,16 @@ def to_http_parameters( else: http_params[key] = val return http_params + + +def _contains_fp16_datatype(infer_response: InferResponse) -> bool: + """ + Checks whether the InferResponse outputs contains FP16 datatype. + + :param infer_response: An InferResponse object containing model inference results. + :return: A boolean indicating whether any output in the InferResponse uses the FP16 datatype. + """ + for infer_output in infer_response.outputs: + if infer_output.datatype == "FP16": + return True + return False diff --git a/python/kserve/kserve/utils/utils.py b/python/kserve/kserve/utils/utils.py index e22ae2cb17..f8090b1336 100644 --- a/python/kserve/kserve/utils/utils.py +++ b/python/kserve/kserve/utils/utils.py @@ -212,7 +212,9 @@ def get_predict_response( name=col, shape=list(result[col].shape), datatype=from_np_dtype(result[col].dtype), - data=result[col].tolist(), + ) + infer_output.set_data_from_numpy( + result[col].to_numpy(), binary_data=payload.use_binary_outputs ) infer_outputs.append(infer_output) elif ( @@ -222,7 +224,10 @@ def get_predict_response( name="output-0", shape=[len(result)], datatype="BYTES", - data=result, + ) + infer_output.set_data_from_numpy( + np.array(result, dtype=np.object_), + binary_data=payload.use_binary_outputs, ) infer_outputs.append(infer_output) else: @@ -232,7 +237,9 @@ def get_predict_response( name="output-0", shape=list(result.shape), datatype=from_np_dtype(result.dtype), - data=result.flatten().tolist(), + ) + infer_output.set_data_from_numpy( + result, binary_data=payload.use_binary_outputs ) infer_outputs.append(infer_output) return InferResponse( diff --git a/python/kserve/poetry.lock b/python/kserve/poetry.lock index 31b937cd55..ceaeb9fd96 100644 --- a/python/kserve/poetry.lock +++ b/python/kserve/poetry.lock @@ -1202,6 +1202,21 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.2)"] +[[package]] +name = "grpcio-testing" +version = "1.48.2" +description = "Testing utilities for gRPC Python" +optional = false +python-versions = "*" +files = [ + {file = "grpcio-testing-1.48.2.tar.gz", hash = "sha256:ae60f49f1a92a149edcde6cc73ce30ed4c2ed972e9c7afb9780348b19cd767b4"}, + {file = "grpcio_testing-1.48.2-py3-none-any.whl", hash = "sha256:5b41abcb48937ba400fdf63a2a0ec712a3c284aea264effb1c4d21570dea2fb2"}, +] + +[package.dependencies] +grpcio = ">=1.48.2" +protobuf = ">=3.12.0" + [[package]] name = "h11" version = "0.14.0" @@ -3471,4 
+3486,4 @@ storage = ["azure-identity", "azure-storage-blob", "azure-storage-file-share", " [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "e3551b23a1833dc3b733f3882ce2ec0e4589d825f9ac74e19abc1e19b54a3395" +content-hash = "dd3b41f53321123e4769eec48fc1223c9b23b92125c98a50537697f71775b50a" diff --git a/python/kserve/pyproject.toml b/python/kserve/pyproject.toml index 5b92e5439a..9cb606b2ff 100644 --- a/python/kserve/pyproject.toml +++ b/python/kserve/pyproject.toml @@ -91,7 +91,7 @@ portforward = "^0.4.3" avro = "^1.11.0" tomlkit = "^0.11.6" jinja2 = "^3.1.2" - +grpcio-testing = "^1.45.0" [tool.poetry.group.dev] optional = true diff --git a/python/kserve/test/test_grpc_server.py b/python/kserve/test/test_grpc_server.py new file mode 100644 index 0000000000..989f996e69 --- /dev/null +++ b/python/kserve/test/test_grpc_server.py @@ -0,0 +1,680 @@ +# Copyright 2024 The KServe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +import grpc_testing +import numpy as np +import pandas as pd +import pytest +from google.protobuf.json_format import MessageToDict +from unittest.mock import patch + +from kserve import Model, ModelServer +from kserve.errors import InvalidInput +from kserve.protocol.grpc import grpc_predict_v2_pb2, servicer +from kserve.protocol.infer_type import serialize_byte_tensor, InferResponse +from kserve.utils.utils import get_predict_response + + +class DummyModel(Model): + def __init__(self, name): + super().__init__(name) + self.name = name + self.ready = False + + def load(self): + self.ready = True + + async def predict(self, request, headers=None): + outputs = pd.DataFrame( + { + "fp32_output": request.get_input_by_name("fp32_input") + .as_numpy() + .flatten(), + "int32_output": request.get_input_by_name("int32_input") + .as_numpy() + .flatten(), + "string_output": request.get_input_by_name("string_input") + .as_numpy() + .flatten(), + "uint8_output": request.get_input_by_name("uint8_input") + .as_numpy() + .flatten(), + "bool_input": request.get_input_by_name("bool_input") + .as_numpy() + .flatten(), + } + ) + # Fixme: Gets only the 1st element of the input + # inputs = get_predict_input(request) + infer_response = get_predict_response(request, outputs, self.name) + if request.parameters: + infer_response.parameters = request.parameters + if request.inputs[0].parameters: + infer_response.outputs[0].parameters = request.inputs[0].parameters + return infer_response + + +class DummyFP16OutputModel(Model): + def __init__(self, name): + super().__init__(name) + self.name = name + self.ready = False + + def load(self): + self.ready = True + + async def predict(self, request, headers=None): + outputs = pd.DataFrame( + { + "fp16_output": request.get_input_by_name("fp32_input") + .as_numpy() + .astype(np.float16) + .flatten(), + "fp32_output": request.get_input_by_name("fp32_input") + .as_numpy() + .flatten(), + } + ) + # Fixme: Gets only the 1st element of the input + # inputs = get_predict_input(request) + infer_response = 
get_predict_response(request, outputs, self.name) + if request.parameters: + infer_response.parameters = request.parameters + if request.inputs[0].parameters: + infer_response.outputs[0].parameters = request.inputs[0].parameters + return infer_response + + +class DummyFP16InputModel(Model): + def __init__(self, name): + super().__init__(name) + self.name = name + self.ready = False + + def load(self): + self.ready = True + + async def predict(self, request, headers=None): + outputs = pd.DataFrame( + { + "int32_output": np.array([1, 2, 3, 4, 5, 6, 7, 8]), + "fp16_output": request.get_input_by_name("fp16_input") + .as_numpy() + .flatten(), + } + ) + # Fixme: Gets only the 1st element of the input + # inputs = get_predict_input(request) + infer_response = get_predict_response(request, outputs, self.name) + if request.parameters: + infer_response.parameters = request.parameters + if request.inputs[0].parameters: + infer_response.outputs[0].parameters = request.inputs[0].parameters + return infer_response + + +@pytest.fixture(scope="class") +def server(): + server = ModelServer() + model = DummyModel("TestModel") + model.load() + server.register_model(model) + fp16_output_model = DummyFP16OutputModel("FP16OutputModel") + fp16_output_model.load() + server.register_model(fp16_output_model) + fp16_input_model = DummyFP16InputModel("FP16InputModel") + fp16_input_model.load() + server.register_model(fp16_input_model) + servicers = { + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ]: servicer.InferenceServicer( + server.dataplane, server.model_repository_extension + ) + } + test_server = grpc_testing.server_from_dictionary( + servicers, + grpc_testing.strict_real_time(), + ) + return test_server + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_inputs(mock_to_headers, server): + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="TestModel", + id="123", + inputs=[ + { + "name": "fp32_input", + "shape": [2, 4], + "datatype": "FP32", + "contents": {"fp32_contents": [6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6]}, + }, + { + "name": "int32_input", + "shape": [2, 4], + "datatype": "INT32", + "contents": {"int_contents": [6, 2, 4, 1, 6, 3, 4, 1]}, + }, + { + "name": "string_input", + "shape": [8], + "datatype": "BYTES", + "contents": { + "bytes_contents": [ + b"Cat", + b"Dog", + b"Wolf", + b"Cat", + b"Dog", + b"Wolf", + b"Dog", + b"Wolf", + ] + }, + }, + { + "name": "uint8_input", + "shape": [2, 4], + "datatype": "UINT8", + "contents": {"uint_contents": [6, 2, 4, 1, 6, 3, 4, 1]}, + }, + { + "name": "bool_input", + "shape": [8], + "datatype": "BOOL", + "contents": { + "bool_contents": [ + True, + False, + True, + False, + True, + False, + True, + False, + ] + }, + }, + ], + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + response, _, code, _ = model_infer_method.termination() + response = await response + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + assert code == grpc.StatusCode.OK + assert response_dict == { + "model_name": "TestModel", + "id": "123", + "outputs": [ + { + "name": "fp32_output", + "datatype": "FP32", + "shape": ["8"], + "contents": {"fp32_contents": [6.8, 2.8, 4.8, 1.4, 6.0, 
3.4, 4.5, 1.6]}, + }, + { + "name": "int32_output", + "datatype": "INT32", + "shape": ["8"], + "contents": {"int_contents": [6, 2, 4, 1, 6, 3, 4, 1]}, + }, + { + "name": "string_output", + "datatype": "BYTES", + "shape": ["8"], + "contents": { + "bytes_contents": [ + "Q2F0", + "RG9n", + "V29sZg==", + "Q2F0", + "RG9n", + "V29sZg==", + "RG9n", + "V29sZg==", + ] + }, + }, + { + "name": "uint8_output", + "datatype": "UINT8", + "shape": ["8"], + "contents": {"uint_contents": [6, 2, 4, 1, 6, 3, 4, 1]}, + }, + { + "name": "bool_input", + "datatype": "BOOL", + "shape": ["8"], + "contents": { + "bool_contents": [ + True, + False, + True, + False, + True, + False, + True, + False, + ] + }, + }, + ], + } + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_raw_inputs(mock_to_headers, server): + """ + If we receive raw inputs then, the response also should be in raw output format. + """ + fp32_data = np.array([6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6], dtype=np.float32) + int32_data = np.array([6, 2, 4, 1, 6, 3, 4, 1], dtype=np.int32) + str_data = np.array( + [b"Cat", b"Dog", b"Wolf", b"Cat", b"Dog", b"Wolf", b"Dog", b"Wolf"], + dtype=np.object_, + ) + uint8_data = np.array([6, 2, 4, 1, 6, 3, 4, 1], dtype=np.uint8) + bool_data = np.array( + [True, False, True, False, True, False, True, False], dtype=np.bool_ + ) + raw_input_contents = [ + fp32_data.tobytes(), + int32_data.tobytes(), + serialize_byte_tensor(str_data).item(), + uint8_data.tobytes(), + bool_data.tobytes(), + ] + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="TestModel", + id="123", + inputs=[ + { + "name": "fp32_input", + "shape": [2, 4], + "datatype": "FP32", + }, + { + "name": "int32_input", + "shape": [2, 4], + "datatype": "INT32", + }, + { + "name": "string_input", + "shape": [8], + "datatype": "BYTES", + }, + { + "name": "uint8_input", + "shape": [2, 4], + "datatype": "UINT8", + }, + { + "name": "bool_input", + "shape": [8], + "datatype": "BOOL", + }, + ], + raw_input_contents=raw_input_contents, + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + response, _, code, _ = model_infer_method.termination() + response = await response + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + assert code == grpc.StatusCode.OK + assert response_dict == { + "model_name": "TestModel", + "id": "123", + "outputs": [ + { + "name": "fp32_output", + "datatype": "FP32", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "32"}}, + }, + { + "name": "int32_output", + "datatype": "INT32", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "32"}}, + }, + { + "name": "string_output", + "datatype": "BYTES", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "59"}}, + }, + { + "name": "uint8_output", + "datatype": "UINT8", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "8"}}, + }, + { + "name": "bool_input", + "datatype": "BOOL", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "8"}}, + }, + ], + "raw_output_contents": [ + "mpnZQDMzM0CamZlAMzOzPwAAwECamVlAAACQQM3MzD8=", + "BgAAAAIAAAAEAAAAAQAAAAYAAAADAAAABAAAAAEAAAA=", + 
"AwAAAENhdAMAAABEb2cEAAAAV29sZgMAAABDYXQDAAAARG9nBAAAAFdvbGYDAAAARG9nBAAAAFdvbGY=", + "BgIEAQYDBAE=", + "AQABAAEAAQA=", + ], + } + infer_response = InferResponse.from_grpc(response) + assert np.array_equal(infer_response.outputs[0].as_numpy(), fp32_data) + assert np.array_equal(infer_response.outputs[1].as_numpy(), int32_data) + assert np.array_equal(infer_response.outputs[2].as_numpy(), str_data) + assert np.array_equal(infer_response.outputs[3].as_numpy(), uint8_data) + assert np.array_equal(infer_response.outputs[4].as_numpy(), bool_data) + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_fp16_output(mock_to_headers, server): + """ + If the output contains FP16 datatype, then the outputs should be returned as raw outputs. + """ + fp32_data = [6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6] + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="FP16OutputModel", + id="123", + inputs=[ + { + "name": "fp32_input", + "shape": [2, 4], + "datatype": "FP32", + "contents": {"fp32_contents": fp32_data}, + }, + ], + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + response, _, code, _ = model_infer_method.termination() + response = await response + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + assert code == grpc.StatusCode.OK + assert response_dict == { + "model_name": "FP16OutputModel", + "id": "123", + "outputs": [ + { + "name": "fp16_output", + "datatype": "FP16", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "16"}}, + }, + { + "name": "fp32_output", + "datatype": "FP32", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "32"}}, + }, + ], + "raw_output_contents": [ + "zUaaQc1Emj0ARs1CgERmPg==", + "mpnZQDMzM0CamZlAMzOzPwAAwECamVlAAACQQM3MzD8=", + ], + } + infer_response = InferResponse.from_grpc(response) + assert np.array_equal( + infer_response.outputs[0].as_numpy(), np.array(fp32_data, dtype=np.float16) + ) + assert np.array_equal( + infer_response.outputs[1].as_numpy(), np.array(fp32_data, dtype=np.float32) + ) + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_fp16_input(mock_to_headers, server): + fp16_data = np.array([6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6], dtype=np.float16) + raw_input_contents = [fp16_data.tobytes()] + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="FP16InputModel", + id="123", + inputs=[ + { + "name": "fp16_input", + "shape": [2, 4], + "datatype": "FP16", + }, + ], + raw_input_contents=raw_input_contents, + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + response, _, code, _ = model_infer_method.termination() + response = await response + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + assert code == grpc.StatusCode.OK + assert response_dict == { + "model_name": "FP16InputModel", + "id": "123", + "outputs": [ + { + "name": "int32_output", + "datatype": 
"INT64", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "64"}}, + }, + { + "name": "fp16_output", + "datatype": "FP16", + "shape": ["8"], + "parameters": {"binary_data_size": {"int64_param": "16"}}, + }, + ], + "raw_output_contents": [ + "AQAAAAAAAAACAAAAAAAAAAMAAAAAAAAABAAAAAAAAAAFAAAAAAAAAAYAAAAAAAAABwAAAAAAAAAIAAAAAAAAAA==", + "zUaaQc1Emj0ARs1CgERmPg==", + ], + } + infer_response = InferResponse.from_grpc(response) + assert np.array_equal(infer_response.outputs[1].as_numpy(), fp16_data) + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_raw_inputs_with_missing_input_data(mock_to_headers, server): + """ + Server should raise InvalidInput if raw_input_contents missing some input data. + """ + raw_input_contents = [ + np.array([6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6], dtype=np.float32).tobytes(), + np.array([6, 2, 4, 1, 6, 3, 4, 1], dtype=np.int32).tobytes(), + ] + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="TestModel", + id="123", + inputs=[ + { + "name": "fp32_input", + "shape": [2, 4], + "datatype": "FP32", + }, + { + "name": "int32_input", + "shape": [2, 4], + "datatype": "INT32", + }, + { + "name": "string_input", + "shape": [8], + "datatype": "BYTES", + }, + ], + raw_input_contents=raw_input_contents, + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + with pytest.raises(InvalidInput): + response, _, _, _ = model_infer_method.termination() + _ = await response + + +@pytest.mark.asyncio +@patch( + "kserve.protocol.grpc.servicer.to_headers", return_value=[] +) # To avoid NotImplementedError from trailing_metadata function +async def test_grpc_raw_inputs_with_contents_specified(mock_to_headers, server): + """ + Server should raise InvalidInput if both contents and raw_input_contents specified. + """ + raw_input_contents = [ + np.array([6.8, 2.8, 4.8, 1.4, 6.0, 3.4, 4.5, 1.6], dtype=np.float32).tobytes(), + np.array([6, 2, 4, 1, 6, 3, 4, 1], dtype=np.int32).tobytes(), + serialize_byte_tensor( + np.array( + [b"Cat", b"Dog", b"Wolf", b"Cat", b"Dog", b"Wolf", b"Dog", b"Wolf"], + dtype=np.object_, + ) + ).item(), + np.array([6, 2, 4, 1, 6, 3, 4, 1], dtype=np.uint8).tobytes(), + np.array( + [True, False, True, False, True, False, True, False], dtype=np.bool_ + ).tobytes(), + ] + request = grpc_predict_v2_pb2.ModelInferRequest( + model_name="TestModel", + id="123", + inputs=[ + { + "name": "fp32_input", + "shape": [2, 4], + "datatype": "FP32", + }, + { + "name": "int32_input", + "shape": [2, 4], + "datatype": "INT32", + }, + { + "name": "string_input", + "shape": [8], + "datatype": "BYTES", + }, + { + "name": "uint8_input", + "shape": [2, 4], + "datatype": "UINT8", + "contents": { + "uint_contents": [6, 2, 4, 1, 6, 3, 4, 1], + }, + }, + { + "name": "bool_input", + "shape": [8], + "datatype": "BOOL", + }, + ], + raw_input_contents=raw_input_contents, + ) + + model_infer_method = server.invoke_unary_unary( + method_descriptor=( + grpc_predict_v2_pb2.DESCRIPTOR.services_by_name[ + "GRPCInferenceService" + ].methods_by_name["ModelInfer"] + ), + invocation_metadata={}, + request=request, + timeout=20, + ) + + with pytest.raises(InvalidInput): + response, _, _, _ = model_infer_method.termination() + _ = await response