Add FP16 datatype support for OIP grpc (kserve#3695)
* Add FP16 datatype support for OIP grpc
Add grpc server tests

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Add grpcio-testing as test dependency

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Fix model repository initialization default value

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Remove fp16 global map

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Resolve comments

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

---------

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>
sivanantha321 authored May 28, 2024
1 parent 04c41c2 commit c660972
Showing 7 changed files with 893 additions and 43 deletions.
8 changes: 5 additions & 3 deletions python/kserve/kserve/model_server.py
@@ -163,7 +163,7 @@ def __init__(
         workers: int = args.workers,
         max_threads: int = args.max_threads,
         max_asyncio_workers: int = args.max_asyncio_workers,
-        registered_models: ModelRepository = ModelRepository(),
+        registered_models: ModelRepository = None,
         enable_grpc: bool = args.enable_grpc,
         enable_docs_url: bool = args.enable_docs_url,
         enable_latency_logging: bool = args.enable_latency_logging,
@@ -188,7 +188,9 @@ def __init__(
             (please refer to this Uvicorn
             [github issue](https://github.com/encode/uvicorn/issues/527) for more info).
         """
-        self.registered_models = registered_models
+        self.registered_models = (
+            ModelRepository() if registered_models is None else registered_models
+        )
         self.http_port = http_port
         self.grpc_port = grpc_port
         self.workers = workers
@@ -197,7 +199,7 @@ def __init__(
         self.enable_grpc = enable_grpc
         self.enable_docs_url = enable_docs_url
         self.enable_latency_logging = enable_latency_logging
-        self.dataplane = DataPlane(model_registry=registered_models)
+        self.dataplane = DataPlane(model_registry=self.registered_models)
         self.model_repository_extension = ModelRepositoryExtension(
             model_registry=self.registered_models
         )
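
The default-argument change above fixes Python's shared-mutable-default pitfall: a `ModelRepository()` default is constructed once when `__init__` is defined, so every `ModelServer` instance would silently share the same repository. A minimal standalone sketch of the pitfall and the `None`-sentinel fix (illustrative names, not kserve code):

class Registry:
    def __init__(self):
        self.models = {}

def make_server_shared(registry=Registry()):  # default built once, at def time
    return registry

def make_server_fresh(registry=None):  # None sentinel: new object per call
    return Registry() if registry is None else registry

assert make_server_shared() is make_server_shared()    # same Registry leaks across calls
assert make_server_fresh() is not make_server_fresh()  # independent registries
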
18 changes: 18 additions & 0 deletions python/kserve/kserve/protocol/grpc/servicer.py
@@ -22,6 +22,8 @@
 
 from grpc import ServicerContext
 
+from ...errors import InvalidInput
+
 
 class InferenceServicer(grpc_predict_v2_pb2_grpc.GRPCInferenceServiceServicer):

@@ -34,6 +36,21 @@ def __init__(
         self._data_plane = data_plane
         self._mode_repository_extension = model_repository_extension
 
+    @classmethod
+    def validate_grpc_request(cls, request: pb.ModelInferRequest):
+        raw_inputs_length = len(request.raw_input_contents)
+        if raw_inputs_length != 0 and len(request.inputs) != raw_inputs_length:
+            raise InvalidInput(
+                f"the number of inputs ({len(request.inputs)}) does not match the expected number of "
+                f"raw input contents ({raw_inputs_length}) for model '{request.model_name}'."
+            )
+        if raw_inputs_length != 0:
+            for input_ in request.inputs:
+                if input_.HasField("contents"):
+                    raise InvalidInput(
+                        f"contents field must not be specified when using raw_input_contents for input '{input_.name}' for model '{request.model_name}'"
+                    )
+
     async def ServerMetadata(self, request: pb.ServerMetadataRequest, context):
         metadata = self._data_plane.metadata()
         return pb.ServerMetadataResponse(
@@ -96,6 +113,7 @@ async def ModelInfer(
         self, request: pb.ModelInferRequest, context: ServicerContext
     ) -> pb.ModelInferResponse:
         headers = to_headers(context)
+        self.validate_grpc_request(request)
         infer_request = InferRequest.from_grpc(request)
         response_body, _ = await self._data_plane.infer(
             request=infer_request, headers=headers, model_name=request.model_name
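
The new `validate_grpc_request` classmethod runs before `InferRequest.from_grpc`, rejecting requests that declare a mismatched number of inputs versus raw buffers, or that mix `raw_input_contents` with typed `contents`. A hedged sketch of triggering the first check; the module paths follow this diff's layout and may differ across kserve versions:

from kserve.errors import InvalidInput
from kserve.protocol.grpc import grpc_predict_v2_pb2 as pb
from kserve.protocol.grpc.servicer import InferenceServicer

request = pb.ModelInferRequest(model_name="my-model")
request.inputs.add(name="input-0", datatype="FP16", shape=[2])
# Two raw buffers for a single declared input: the count check fires.
request.raw_input_contents.extend([b"\x00\x3c\x00\x40", b"\x00\x3c\x00\x40"])

try:
    InferenceServicer.validate_grpc_request(request)
except InvalidInput as err:
    print(err)  # the number of inputs (1) does not match raw input contents (2)
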
198 changes: 163 additions & 35 deletions python/kserve/kserve/protocol/infer_type.py
@@ -15,7 +15,6 @@
 import struct
 from typing import Optional, List, Dict, Union
 
-import numpy
 import numpy as np
 import pandas as pd
 import uuid
@@ -33,42 +32,81 @@
 from ..utils.numpy_codec import to_np_dtype, from_np_dtype
 
 
-def serialize_byte_tensor(input_tensor: numpy.ndarray):
+def serialize_byte_tensor(input_tensor: np.ndarray) -> np.ndarray:
     """
     Serializes a bytes tensor into a flat numpy array of length prepended
     bytes. The numpy array should use dtype of np.object_. For np.bytes_,
     numpy will remove trailing zeros at the end of byte sequence and because
     of this it should be avoided.
     Args:
-        input_tensor : np.array of the bytes tensor to serialize.
+        input_tensor : np.array
+            The bytes tensor to serialize.
     Returns:
-        serialized_bytes_tensor : The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order.
+        serialized_bytes_tensor : np.array
+            The 1-D numpy array of type uint8 containing the serialized bytes in row-major form.
     Raises:
         InferenceError If unable to serialize the given tensor.
     """
 
     if input_tensor.size == 0:
-        return ()
-
-    # If the input is a tensor of string/bytes objects, then must flatten those
-    # into a 1-dimensional array containing the 4-byte byte size followed by the
-    # actual element bytes. All elements are concatenated together in "C" order.
-    if (input_tensor.dtype == np.object_) or (input_tensor.dtype.type == np.bytes_):
-        flattened_ls = []
-        for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
-            # If directly passing bytes to BYTES type,
-            # don't convert it to str as Python will encode the
-            # bytes which may distort the meaning
-            if input_tensor.dtype == np.object_:
-                if type(obj.item()) == bytes:
-                    s = obj.item()
-                else:
-                    s = str(obj.item()).encode("utf-8")
-            else:
-                s = obj.item()
-            flattened_ls.append(struct.pack("<I", len(s)))
-            flattened_ls.append(s)
-        flattened = b"".join(flattened_ls)
-        return flattened
-    return None
+        return np.empty([0], dtype=np.object_)
+
+    # If the input is a tensor of string/bytes objects, then must flatten those into
+    # a 1-dimensional array containing the 4-byte byte size followed by the
+    # actual element bytes. All elements are concatenated together in row-major
+    # order.
+
+    if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_):
+        raise InferenceError("cannot serialize bytes tensor: invalid datatype")
+
+    flattened_ls = []
+    # 'C' order is row-major.
+    for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
+        # If directly passing bytes to BYTES type,
+        # don't convert it to str as Python will encode the
+        # bytes which may distort the meaning
+        if input_tensor.dtype == np.object_:
+            if type(obj.item()) == bytes:
+                s = obj.item()
+            else:
+                s = str(obj.item()).encode("utf-8")
+        else:
+            s = obj.item()
+        flattened_ls.append(struct.pack("<I", len(s)))
+        flattened_ls.append(s)
+    flattened = b"".join(flattened_ls)
+    flattened_array = np.asarray(flattened, dtype=np.object_)
+    if not flattened_array.flags["C_CONTIGUOUS"]:
+        flattened_array = np.ascontiguousarray(flattened_array, dtype=np.object_)
+    return flattened_array
+
+
+def deserialize_bytes_tensor(encoded_tensor: bytes) -> np.ndarray:
+    """
+    Deserializes an encoded bytes tensor into a
+    numpy array of dtype of python objects
+    Args:
+        encoded_tensor : bytes
+            The encoded bytes tensor where each element
+            has its length in first 4 bytes followed by
+            the content
+    Returns:
+        string_tensor : np.array
+            The 1-D numpy array of type object containing the
+            deserialized bytes in row-major form.
+    """
+    strs = list()
+    offset = 0
+    val_buf = encoded_tensor
+    while offset < len(val_buf):
+        length = struct.unpack_from("<I", val_buf, offset)[0]
+        offset += 4
+        sb = struct.unpack_from("<{}s".format(length), val_buf, offset)[0]
+        offset += length
+        strs.append(sb)
+    return np.array(strs, dtype=np.object_)
 
 
 class InferInput:
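
`serialize_byte_tensor` and the new `deserialize_bytes_tensor` implement the v2 BYTES wire format: each element is encoded as a 4-byte little-endian length followed by the payload bytes, concatenated in row-major order. A self-contained round trip of that framing:

import struct

elements = [b"hello", b"world!"]
encoded = b"".join(struct.pack("<I", len(e)) + e for e in elements)
# encoded == b"\x05\x00\x00\x00hello\x06\x00\x00\x00world!"

decoded, offset = [], 0
while offset < len(encoded):
    (length,) = struct.unpack_from("<I", encoded, offset)  # 4-byte length prefix
    offset += 4
    decoded.append(encoded[offset : offset + length])      # payload of that length
    offset += length

assert decoded == elements
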
@@ -173,7 +211,14 @@ def as_numpy(self) -> np.ndarray:
         if dtype is None:
             raise InvalidInput(f"invalid datatype {dtype} in the input")
         if self._raw_data is not None:
-            np_array = np.frombuffer(self._raw_data, dtype=dtype)
+            if self.datatype == "BYTES":
+                # String results contain a 4-byte string length
+                # followed by the actual string characters. Hence,
+                # need to decode the raw bytes to convert into
+                # array elements.
+                np_array = deserialize_bytes_tensor(self._raw_data)
+            else:
+                np_array = np.frombuffer(self._raw_data, dtype=dtype)
             return np_array.reshape(self._shape)
         else:
             np_array = np.array(self._data, dtype=dtype)
@@ -286,6 +331,9 @@ def get_content(datatype: str, data: InferTensorContents):
         return list(data.int_contents)
     elif datatype == "INT64":
         return list(data.int64_contents)
+    elif datatype == "FP16":
+        # FP16 data should be present in raw_input_content, so return an empty list.
+        return list()
    elif datatype == "FP32":
         return list(data.fp32_contents)
     elif datatype == "FP64":
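
`InferTensorContents` has no half-precision field, which is why the new FP16 branch returns an empty list: FP16 tensors must travel as raw bytes in `raw_input_contents` / `raw_output_contents` instead of the typed contents fields. A standalone numpy round trip of that encoding:

import numpy as np

fp16_tensor = np.array([[1.5, -2.0], [0.25, 3.0]], dtype=np.float16)
raw = fp16_tensor.tobytes()  # what would be carried in raw_input_contents
restored = np.frombuffer(raw, dtype=np.float16).reshape(fp16_tensor.shape)
assert np.array_equal(restored, fp16_tensor)
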
@@ -328,10 +376,20 @@ def __init__(
         self.inputs = infer_inputs
         self.parameters = parameters
         self.from_grpc = from_grpc
+        self._use_raw_outputs = False
         if raw_inputs:
+            self._use_raw_outputs = True
             for i, raw_input in enumerate(raw_inputs):
                 self.inputs[i]._raw_data = raw_input
 
+    @property
+    def use_binary_outputs(self) -> bool:
+        """Whether to use binary raw outputs
+        Returns:
+            a boolean indicating whether to use binary raw outputs
+        """
+        return self._use_raw_outputs
+
     @classmethod
     def from_grpc(cls, request: ModelInferRequest):
         """The class method to construct the InferRequest from a ModelInferRequest"""
@@ -363,7 +421,7 @@ def to_rest(self) -> Dict:
         infer_inputs = []
         for infer_input in self.inputs:
             datatype = infer_input.datatype
-            if isinstance(infer_input.datatype, numpy.dtype):
+            if isinstance(infer_input.datatype, np.dtype):
                 datatype = from_np_dtype(infer_input.datatype)
             infer_input_dict = {
                 "name": infer_input.name,
@@ -374,7 +432,7 @@ def to_rest(self) -> Dict:
                 infer_input_dict["parameters"] = to_http_parameters(
                     infer_input.parameters
                 )
-            if isinstance(infer_input.data, numpy.ndarray):
+            if isinstance(infer_input.data, np.ndarray):
                 infer_input.set_data_from_numpy(infer_input.data, binary_data=False)
                 infer_input_dict["data"] = infer_input.data
             else:
@@ -397,7 +455,7 @@ def to_grpc(self) -> ModelInferRequest:
         infer_inputs = []
         raw_input_contents = []
         for infer_input in self.inputs:
-            if isinstance(infer_input.data, numpy.ndarray):
+            if isinstance(infer_input.data, np.ndarray):
                 infer_input.set_data_from_numpy(infer_input.data, binary_data=True)
             infer_input_dict = {
                 "name": infer_input.name,
@@ -452,6 +510,21 @@ def as_dataframe(self) -> pd.DataFrame:
             dfs.append(pd.DataFrame(input_data, columns=[input.name]))
         return pd.concat(dfs, axis=1)
 
+    def get_input_by_name(self, name: str) -> Optional[InferInput]:
+        """Find an input Tensor in the InferenceRequest that has the given name
+        Args:
+            name : str
+                name of the input Tensor object
+        Returns:
+            InferInput
+                The InferInput with the specified name, or None if no
+                input with this name exists
+        """
+        for infer_input in self.inputs:
+            if name == infer_input.name:
+                return infer_input
+        return None
+
     def __eq__(self, other):
         if not isinstance(other, InferRequest):
             return False
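
A small usage sketch for the new lookup helper, assuming `InferRequest` and `InferInput` are importable from the top-level `kserve` package as in current releases:

from kserve import InferInput, InferRequest

request = InferRequest(
    model_name="my-model",
    infer_inputs=[InferInput(name="input-0", shape=[1], datatype="FP16")],
)
assert request.get_input_by_name("input-0") is not None
assert request.get_input_by_name("no-such-input") is None
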
@@ -522,6 +595,15 @@ def data(self) -> Union[List, np.ndarray, InferTensorContents]:
         """
         return self._data
 
+    @data.setter
+    def data(self, data: Union[List, np.ndarray, InferTensorContents]):
+        """Set the data of inference output associated with this object.
+        Args:
+            data: inference output data.
+        """
+        self._data = data
+
     @property
     def shape(self) -> List[int]:
         """Get the shape of inference output associated with this object.
Expand Down Expand Up @@ -552,7 +634,7 @@ def set_shape(self, shape: List[int]):
"""
self._shape = shape

def as_numpy(self) -> numpy.ndarray:
def as_numpy(self) -> np.ndarray:
"""Decode the tensor output data as numpy array.
Returns:
@@ -562,7 +644,14 @@ def as_numpy(self) -> numpy.ndarray:
         if dtype is None:
             raise InvalidInput("invalid datatype in the input")
         if self._raw_data is not None:
-            np_array = np.frombuffer(self._raw_data, dtype=dtype)
+            if self.datatype == "BYTES":
+                # String results contain a 4-byte string length
+                # followed by the actual string characters. Hence,
+                # need to decode the raw bytes to convert into
+                # array elements.
+                np_array = deserialize_bytes_tensor(self._raw_data)
+            else:
+                np_array = np.frombuffer(self._raw_data, dtype=dtype)
             return np_array.reshape(self._shape)
         else:
             np_array = np.array(self._data, dtype=dtype)
@@ -765,7 +854,7 @@ def to_rest(self) -> Dict:
                 infer_output_dict["parameters"] = to_http_parameters(
                     infer_output.parameters
                 )
-            if isinstance(infer_output.data, numpy.ndarray):
+            if isinstance(infer_output.data, np.ndarray):
                 infer_output.set_data_from_numpy(infer_output.data, binary_data=False)
                 infer_output_dict["data"] = infer_output.data
             elif isinstance(infer_output._raw_data, bytes):
@@ -791,8 +880,18 @@ def to_grpc(self) -> ModelInferResponse:
         """
         infer_outputs = []
         raw_output_contents = []
+        use_raw_outputs = False
+        # If FP16 datatype is present in the outputs use raw outputs.
+        if _contains_fp16_datatype(self):
+            use_raw_outputs = True
         for infer_output in self.outputs:
-            if isinstance(infer_output.data, numpy.ndarray):
+            if (
+                use_raw_outputs
+                and infer_output.data
+                and isinstance(infer_output.data, list)
+            ):
+                infer_output.data = infer_output.as_numpy()
+            if isinstance(infer_output.data, np.ndarray):
                 infer_output.set_data_from_numpy(infer_output.data, binary_data=True)
             infer_output_dict = {
                 "name": infer_output.name,
@@ -831,6 +930,22 @@ def to_grpc(self) -> ModelInferResponse:
             parameters=to_grpc_parameters(self.parameters) if self.parameters else None,
         )
 
+    def get_output_by_name(self, name: str) -> Optional[InferOutput]:
+        """Find an output Tensor in the InferResponse that has the given name
+        Args:
+            name : str
+                name of the output Tensor object
+        Returns:
+            InferOutput
+                The InferOutput with the specified name, or None if no
+                output with this name exists
+        """
+        for infer_output in self.outputs:
+            if name == infer_output.name:
+                return infer_output
+        return None
+
     def __eq__(self, other):
         if not isinstance(other, InferResponse):
             return False
@@ -895,3 +1010,16 @@ def to_http_parameters(
     else:
         http_params[key] = val
     return http_params
+
+
+def _contains_fp16_datatype(infer_response: InferResponse) -> bool:
+    """
+    Checks whether the InferResponse outputs contains FP16 datatype.
+    :param infer_response: An InferResponse object containing model inference results.
+    :return: A boolean indicating whether any output in the InferResponse uses the FP16 datatype.
+    """
+    for infer_output in infer_response.outputs:
+        if infer_output.datatype == "FP16":
+            return True
+    return False
(The remaining changed files are not shown here.)
