diff --git a/python/kserve/kserve/model.py b/python/kserve/kserve/model.py
index aa8f1cd5e8f..c0ef160ff88 100644
--- a/python/kserve/kserve/model.py
+++ b/python/kserve/kserve/model.py
@@ -63,6 +63,14 @@ def __init__(self, predictor_host: str,
                  predictor_protocol: str = PredictorProtocol.REST_V1.value,
                  predictor_use_ssl: bool = False,
                  predictor_request_timeout_seconds: int = 600):
+        """The configuration for the HTTP call to the predictor.
+
+        Args:
+            predictor_host: The host name of the predictor.
+            predictor_protocol: The inference protocol used for the HTTP call to the predictor.
+            predictor_use_ssl: Whether to use SSL for the HTTP connection to the predictor.
+            predictor_request_timeout_seconds: The request timeout in seconds for the HTTP call to the predictor.
+        """
         self.predictor_host = predictor_host
         self.predictor_protocol = predictor_protocol
         self.predictor_use_ssl = predictor_use_ssl
@@ -73,10 +81,11 @@ class Model:
     def __init__(self, name: str, predictor_config: Optional[PredictorConfig] = None):
         """KServe Model Public Interface

-        Model is intended to be subclassed by various components within KServe.
+        Model is intended to be subclassed to implement the model handlers.

         Args:
-            name (str): Name of the model.
+            name: The name of the model.
+            predictor_config: The configuration for the HTTP call to the predictor.
         """
         self.name = name
         self.ready = False
@@ -99,9 +108,9 @@ async def __call__(self, body: Union[Dict, CloudEvent, InferRequest],
         """Method to call predictor or explainer with the given input.

         Args:
-            body (Dict|CloudEvent|InferRequest): Request payload body.
-            model_type (ModelType): Model type enum. Can be either predictor or explainer.
-            headers (Dict): Request headers.
+            body: Request body.
+            model_type: ModelType enum: `ModelType.PREDICTOR` or `ModelType.EXPLAINER`.
+            headers: Request headers.

         Returns:
             Dict: Response output from preprocess -> predictor/explainer -> postprocess
@@ -184,8 +193,8 @@ def validate(self, payload):
         return payload

     def load(self) -> bool:
-        """Load handler can be overridden to load the model from storage
-        ``self.ready`` flag is used for model health check
+        """Load handler can be overridden to load the model from storage.
+        The `self.ready` flag should be set to True after the model is loaded; it is used for the model health check.

         Returns:
             bool: True if model is ready, False otherwise
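Reviewer note: the `load` contract documented above looks like this in practice. A minimal sketch, assuming a scikit-learn model serialized with joblib under the conventional `/mnt/models` mount; the class name and model path are illustrative, not part of this diff:

```python
import joblib  # illustrative dependency, not part of this diff
from kserve import Model


class SklearnIrisModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.model = None
        self.load()

    def load(self) -> bool:
        # Load the model from storage, then flip the flag the health check reads.
        self.model = joblib.load("/mnt/models/model.joblib")
        self.ready = True
        return self.ready
```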
""" return payload - async def postprocess(self, response: Union[Dict, InferResponse], headers: Dict[str, str] = None) \ + async def postprocess(self, result: Union[Dict, InferResponse], headers: Dict[str, str] = None) \ -> Union[Dict, InferResponse]: - """The postprocess handler can be overridden for inference response transformation. + """ The `postprocess` handler can be overridden for inference result or response transformation. + The predictor sends back the inference result in `Dict` for v1 endpoints and `InferResponse` for v2 endpoints. Args: - response (Dict|InferResponse|ModelInferResponse): The response passed from ``predict`` handler. - headers (Dict): Request headers. + result: The inference result passed from `predict` handler or the HTTP response from predictor. + headers: Request headers. Returns: - Dict: post-processed response. + A Dict or InferResponse after post-process to return back to the client. """ - return response + return result async def _http_predict(self, payload: Union[Dict, InferRequest], headers: Dict[str, str] = None) -> Dict: protocol = "https" if self.use_ssl else "http" @@ -289,15 +299,18 @@ async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], h async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest], headers: Dict[str, str] = None) -> Union[Dict, InferResponse]: - """ + """ The `predict` handler can be overridden for performing the inference. + By default, the predict handler makes call to predictor for the inference step. Args: - payload (Dict|InferRequest|ModelInferRequest): Prediction inputs passed from ``preprocess`` handler. - headers (Dict): Request headers. + payload: Model inputs passed from `preprocess` handler. + headers: Request headers. Returns: - Dict|InferResponse|ModelInferResponse: Return InferResponse for serializing the prediction result or - return the response from the predictor call. + Inference result or a Response from the predictor. + + Raises: + HTTPStatusError when getting back an error response from the predictor. """ if not self.predictor_host: raise NotImplementedError("Could not find predictor_host.") @@ -311,22 +324,24 @@ async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest], async def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict: """`explain` handler can be overridden to implement the model explanation. - The default implementation makes call to the explainer if ``explainer_host`` is specified + The default implementation makes call to the explainer if ``explainer_host`` is specified. Args: - payload (Dict): Dict passed from preprocess handler. - headers (Dict): Request headers. + payload: Explainer model inputs passed from preprocess handler. + headers: Request headers. Returns: - Dict: Response from the explainer. + An Explanation for the inference result. + + Raises: + HTTPStatusError when getting back an error response from the explainer. 
""" if self.explainer_host is None: raise NotImplementedError("Could not find explainer_host.") protocol = "https" if self.use_ssl else "http" + # Currently explainer only supports the kserve v1 endpoints explain_url = EXPLAINER_URL_FORMAT.format(protocol, self.explainer_host, self.name) - if self.protocol == PredictorProtocol.REST_V2.value: - explain_url = EXPLAINER_V2_URL_FORMAT.format(protocol, self.explainer_host, self.name) response = await self._http_client.post( url=explain_url, timeout=self.timeout, diff --git a/python/kserve/kserve/model_server.py b/python/kserve/kserve/model_server.py index 2a1cdff9b46..31844bea70b 100644 --- a/python/kserve/kserve/model_server.py +++ b/python/kserve/kserve/model_server.py @@ -84,23 +84,6 @@ class ModelServer: - """KServe ModelServer - - Args: - http_port (int): HTTP port. Default: ``8080``. - grpc_port (int): GRPC port. Default: ``8081``. - workers (int): Number of workers for uvicorn. Default: ``1``. - max_threads (int): Max number of processing threads. Default: ``4`` - max_asyncio_workers (int): Max number of AsyncIO threads. Default: ``None`` - registered_models (ModelRepository): Model repository with registered models. - enable_grpc (bool): Whether to turn on grpc server. Default: ``True`` - enable_docs_url (bool): Whether to turn on ``/docs`` Swagger UI. Default: ``False``. - enable_latency_logging (bool): Whether to log latency metric. Default: ``True``. - configure_logging (bool): Whether to configure KServe and Uvicorn logging. Default: ``True``. - log_config (dict or str): File path or dict containing log config. Default: ``None``. - access_log_format (string): Format to set for the access log (provided by asgi-logger). Default: ``None`` - """ - def __init__(self, http_port: int = args.http_port, grpc_port: int = args.grpc_port, workers: int = args.workers, @@ -114,6 +97,22 @@ def __init__(self, http_port: int = args.http_port, log_config: Optional[Union[Dict, str]] = args.log_config_file, access_log_format: str = args.access_log_format, ): + """KServe ModelServer Constructor + + Args: + http_port: HTTP port. Default: ``8080``. + grpc_port: GRPC port. Default: ``8081``. + workers: Number of uvicorn workers. Default: ``1``. + max_threads: Max number of gRPC processing threads. Default: ``4`` + max_asyncio_workers: Max number of AsyncIO threads. Default: ``None`` + registered_models: Model repository with registered models. + enable_grpc: Whether to turn on grpc server. Default: ``True`` + enable_docs_url: Whether to turn on ``/docs`` Swagger UI. Default: ``False``. + enable_latency_logging: Whether to log latency metric. Default: ``True``. + configure_logging: Whether to configure KServe and Uvicorn logging. Default: ``True``. + log_config: File path or dict containing log config. Default: ``None``. + access_log_format: Format to set for the access log (provided by asgi-logger). Default: ``None`` + """ self.registered_models = registered_models self.http_port = http_port self.grpc_port = grpc_port @@ -143,6 +142,11 @@ def __init__(self, http_port: int = args.http_port, self.access_log_format = access_log_format def start(self, models: Union[List[Model], Dict[str, Deployment]]) -> None: + """ Start the model server with a set of registered models. + + Args: + models: a list of models to register to the model server. 
+ """ if isinstance(models, list): for model in models: if isinstance(model, Model): @@ -217,6 +221,8 @@ async def servers_task(): asyncio.run(servers_task()) async def stop(self, sig: Optional[int] = None): + """ Stop the instances of REST and gRPC model servers. + """ logger.info("Stopping the model server") if self._rest_server: logger.info("Stopping the rest server") diff --git a/python/kserve/kserve/protocol/infer_type.py b/python/kserve/kserve/protocol/infer_type.py index c2a9c6472f8..7250345ebc0 100644 --- a/python/kserve/kserve/protocol/infer_type.py +++ b/python/kserve/kserve/protocol/infer_type.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Union import numpy import numpy as np import pandas as pd import uuid -from tritonclient.utils import raise_error, serialize_byte_tensor + +from google.protobuf.internal.containers import MessageMap +from tritonclient.utils import serialize_byte_tensor from ..constants.constants import GRPC_CONTENT_DATATYPE_MAPPINGS -from ..errors import InvalidInput -from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse +from ..errors import InvalidInput, InferenceError +from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse, \ + InferParameter from ..utils.numpy_codec import to_np_dtype, from_np_dtype @@ -32,21 +35,19 @@ class InferInput: _datatype: str _parameters: Dict - def __init__(self, name, shape, datatype, data=None, parameters=None): - """An object of InferInput class is used to describe - input tensor for an inference request. - Parameters - ---------- - name : str - The name of input whose data will be described by this object - shape : list - The shape of the associated input. - datatype : str - The datatype of the associated input. - data : Union[List, InferTensorContents] - The data of the REST/gRPC input. When data is not set, raw_data is used for gRPC for numpy array bytes. - parameters : dict - The additional server-specific parameters. + def __init__(self, name: str, shape: List[int], datatype: str, + data: Union[List, np.ndarray, InferTensorContents] = None, + parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None): + """An object of InferInput class is used to describe the input tensor of an inference request. + + Args: + name: The name of the inference input whose data will be described by this object. + shape : The shape of the associated inference input. + datatype : The data type of the associated inference input. + data : The data of the inference input. + When data is not set, raw_data is used for gRPC to transmit with numpy array bytes + by using `set_data_from_numpy`. + parameters : The additional inference parameters. """ if parameters is None: parameters = {} @@ -58,65 +59,67 @@ def __init__(self, name, shape, datatype, data=None, parameters=None): self._raw_data = None @property - def name(self): - """Get the name of input associated with this object. - Returns - ------- - str - The name of input + def name(self) -> str: + """Get the name of inference input associated with this object. + + Returns: + The name of the inference input """ return self._name @property - def datatype(self): - """Get the datatype of input associated with this object. 
diff --git a/python/kserve/kserve/protocol/infer_type.py b/python/kserve/kserve/protocol/infer_type.py
index c2a9c6472f8..7250345ebc0 100644
--- a/python/kserve/kserve/protocol/infer_type.py
+++ b/python/kserve/kserve/protocol/infer_type.py
@@ -12,17 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, List, Dict
+from typing import Optional, List, Dict, Union

 import numpy
 import numpy as np
 import pandas as pd
 import uuid
-from tritonclient.utils import raise_error, serialize_byte_tensor
+
+from google.protobuf.internal.containers import MessageMap
+from tritonclient.utils import serialize_byte_tensor

 from ..constants.constants import GRPC_CONTENT_DATATYPE_MAPPINGS
-from ..errors import InvalidInput
-from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse
+from ..errors import InvalidInput, InferenceError
+from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse, \
+    InferParameter
 from ..utils.numpy_codec import to_np_dtype, from_np_dtype
@@ -32,21 +35,19 @@ class InferInput:
     _datatype: str
     _parameters: Dict

-    def __init__(self, name, shape, datatype, data=None, parameters=None):
-        """An object of InferInput class is used to describe
-        input tensor for an inference request.
-        Parameters
-        ----------
-        name : str
-            The name of input whose data will be described by this object
-        shape : list
-            The shape of the associated input.
-        datatype : str
-            The datatype of the associated input.
-        data : Union[List, InferTensorContents]
-            The data of the REST/gRPC input. When data is not set, raw_data is used for gRPC for numpy array bytes.
-        parameters : dict
-            The additional server-specific parameters.
+    def __init__(self, name: str, shape: List[int], datatype: str,
+                 data: Union[List, np.ndarray, InferTensorContents] = None,
+                 parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None):
+        """An object of InferInput class is used to describe the input tensor of an inference request.
+
+        Args:
+            name: The name of the inference input whose data will be described by this object.
+            shape: The shape of the associated inference input.
+            datatype: The data type of the associated inference input.
+            data: The data of the inference input. When data is not set, raw_data is used for gRPC
+                to transmit the input as numpy array bytes by using `set_data_from_numpy`.
+            parameters: The additional inference parameters.
         """
         if parameters is None:
             parameters = {}
@@ -58,65 +59,67 @@ def __init__(self, name: str, shape: List[int], datatype: str,
         self._data = data
         self._raw_data = None

     @property
-    def name(self):
-        """Get the name of input associated with this object.
-        Returns
-        -------
-        str
-            The name of input
+    def name(self) -> str:
+        """Get the name of the inference input associated with this object.
+
+        Returns:
+            The name of the inference input.
         """
         return self._name

     @property
-    def datatype(self):
-        """Get the datatype of input associated with this object.
-        Returns
-        -------
-        str
-            The datatype of input
+    def datatype(self) -> str:
+        """Get the data type of the inference input associated with this object.
+
+        Returns:
+            The data type of the inference input.
         """
         return self._datatype

     @property
-    def data(self):
-        """Get the data of InferInput
+    def data(self) -> Union[List, np.ndarray, InferTensorContents]:
+        """Get the data of the inference input associated with this object.
+
+        Returns:
+            The data of the inference input.
         """
         return self._data

     @property
-    def shape(self):
-        """Get the shape of input associated with this object.
-        Returns
-        -------
-        list
-            The shape of input
+    def shape(self) -> List[int]:
+        """Get the shape of the inference input associated with this object.
+
+        Returns:
+            The shape of the inference input.
         """
         return self._shape

     @property
-    def parameters(self):
-        """Get the parameters of input associated with this object.
-        Returns
-        -------
-        dict
-            The key, value pair of string and InferParameter
+    def parameters(self) -> Union[Dict, MessageMap[str, InferParameter]]:
+        """Get the parameters of the inference input associated with this object.
+
+        Returns:
+            The additional inference parameters.
         """
         return self._parameters

-    def set_shape(self, shape):
-        """Set the shape of input.
-        Parameters
-        ----------
-        shape : list
-            The shape of the associated input.
+    def set_shape(self, shape: List[int]):
+        """Set the shape of the inference input.
+
+        Args:
+            shape: The shape of the associated inference input.
         """
         self._shape = shape

     def as_numpy(self) -> np.ndarray:
+        """Decode the inference input data as a numpy array.
+
+        Returns:
+            A numpy array of the inference input data.
+        """
         dtype = to_np_dtype(self.datatype)
         if dtype is None:
-            raise InvalidInput("invalid datatype in the input")
+            raise InvalidInput(f"invalid datatype {self.datatype} in the input")
         if self._raw_data is not None:
             np_array = np.frombuffer(self._raw_data, dtype=dtype)
             return np_array.reshape(self._shape)
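Reviewer note: the retyped `InferInput` constructor and `as_numpy` read as below; the tensor name and values are illustrative:

```python
from kserve.protocol.infer_type import InferInput

infer_input = InferInput(name="input-0", shape=[2, 2], datatype="FP32",
                         data=[[1.0, 2.0], [3.0, 4.0]])
print(infer_input.as_numpy())  # 2x2 float32 array decoded from the explicit data
```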
""" if not isinstance(input_tensor, (np.ndarray,)): - raise_error("input_tensor must be a numpy array") + InferenceError("input_tensor must be a numpy array") dtype = from_np_dtype(input_tensor.dtype) if self._datatype != dtype: - raise_error( + InferenceError( "got unexpected datatype {} from numpy array, expected {}".format(dtype, self._datatype)) valid_shape = True if len(self._shape) != len(input_tensor.shape): @@ -156,7 +155,7 @@ def set_data_from_numpy(self, input_tensor, binary_data=True): if self._shape[i] != input_tensor.shape[i]: valid_shape = False if not valid_shape: - raise_error( + InferenceError( "got unexpected numpy array shape [{}], expected [{}]".format( str(input_tensor.shape)[1:-1], str(self._shape)[1:-1])) @@ -184,7 +183,7 @@ def set_data_from_numpy(self, input_tensor, binary_data=True): self._data.append( str(obj.item(), encoding='utf-8')) except UnicodeDecodeError: - raise_error( + InferenceError( f'Failed to encode "{obj.item()}" using UTF-8. Please use binary_data=True, if' ' you want to pass a byte array.') else: @@ -224,16 +223,6 @@ def get_content(datatype: str, data: InferTensorContents): class InferRequest: - """InferenceRequest Model - - $inference_request = - { - "id" : $string #optional, - "parameters" : $parameters #optional, - "inputs" : [ $request_input, ... ], - "outputs" : [ $request_output, ... ] #optional - } - """ id: Optional[str] model_name: str parameters: Optional[Dict] @@ -241,7 +230,20 @@ class InferRequest: from_grpc: bool def __init__(self, model_name: str, infer_inputs: List[InferInput], - request_id=None, raw_inputs=None, from_grpc=False, parameters=None): + request_id: Optional[str] = None, + raw_inputs=None, + from_grpc: Optional[bool] = False, + parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None): + """InferRequest Data Model. + + Args: + model_name: The model name. + infer_inputs: The inference inputs for the model. + request_id: The id for the inference request. + raw_inputs: The binary data for the inference inputs. + from_grpc: Indicate if the data model is constructed from gRPC request. + parameters: The additional inference parameters. + """ if parameters is None: parameters = {} self.id = request_id @@ -255,6 +257,9 @@ def __init__(self, model_name: str, infer_inputs: List[InferInput], @classmethod def from_grpc(cls, request: ModelInferRequest): + """ The class method to construct the InferRequest from a ModelInferRequest + + """ infer_inputs = [InferInput(name=input_tensor.name, shape=list(input_tensor.shape), datatype=input_tensor.datatype, data=get_content(input_tensor.datatype, input_tensor.contents), @@ -264,9 +269,11 @@ def from_grpc(cls, request: ModelInferRequest): raw_inputs=request.raw_input_contents, from_grpc=True, parameters=request.parameters) def to_rest(self) -> Dict: - """ Converts the InferRequest object to v2 REST InferenceRequest message + """ Converts the InferRequest object to v2 REST InferRequest Dict. - """ + Returns: + The InferRequest Dict converted from InferRequest object. + """ infer_inputs = [] for infer_input in self.inputs: infer_input_dict = { @@ -287,8 +294,10 @@ def to_rest(self) -> Dict: return infer_request def to_grpc(self) -> ModelInferRequest: - """ Converts the InferRequest object to gRPC ModelInferRequest message + """ Converts the InferRequest object to gRPC ModelInferRequest type. + Returns: + The ModelInferResponse gRPC type converted from InferRequest object. 
""" infer_inputs = [] raw_input_contents = [] @@ -320,8 +329,10 @@ def to_grpc(self) -> ModelInferRequest: raw_input_contents=raw_input_contents) def as_dataframe(self) -> pd.DataFrame: - """ - Decode the tensor inputs as pandas dataframe + """ Decode the tensor inputs as pandas dataframe. + + Returns: + The inference input data as pandas dataframe """ dfs = [] for input in self.inputs: @@ -334,21 +345,18 @@ def as_dataframe(self) -> pd.DataFrame: class InferOutput: - def __init__(self, name, shape, datatype, data=None, parameters=None): - """An object of InferOutput class is used to describe - input tensor for an inference request. - Parameters - ---------- - name : str - The name of input whose data will be described by this object - shape : list - The shape of the associated input. - datatype : str - The datatype of the associated input. - data : Union[List, InferTensorContents] - The data of the REST/gRPC input. When data is not set, raw_data is used for gRPC for numpy array bytes. - parameters : dict - The additional server-specific parameters. + def __init__(self, name: str, shape: List[int], datatype: str, + data: Union[List, np.ndarray, InferTensorContents] = None, + parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None): + """An object of InferOutput class is used to describe the output tensor for an inference response. + + Args: + name : The name of inference output whose data will be described by this object. + shape : The shape of the associated inference output. + datatype : The data type of the associated inference output. + data : The data of the inference output. When data is not set, + raw_data is used for gRPC with numpy array bytes by calling set_data_from_numpy. + parameters : The additional inference parameters. """ if parameters is None: parameters = {} @@ -360,64 +368,63 @@ def __init__(self, name, shape, datatype, data=None, parameters=None): self._raw_data = None @property - def name(self): - """Get the name of input associated with this object. - Returns - ------- - str - The name of input + def name(self) -> str: + """Get the name of inference output associated with this object. + + Returns: + The name of inference output. """ return self._name @property - def datatype(self): - """Get the datatype of input associated with this object. - Returns - ------- - str - The datatype of input + def datatype(self) -> str: + """Get the data type of inference output associated with this object. + + Returns: + The data type of inference output. """ return self._datatype @property - def data(self): - """Get the data of InferOutput + def data(self) -> Union[List, np.ndarray, InferTensorContents]: + """Get the data of inference output associated with this object. + Returns: + The data of inference output. """ return self._data @property - def shape(self): - """Get the shape of input associated with this object. - Returns - ------- - list - The shape of input + def shape(self) -> List[int]: + """Get the shape of inference output associated with this object. + + Returns: + The shape of inference output """ return self._shape @property - def parameters(self): - """Get the parameters of input associated with this object. - Returns - ------- - dict - The key, value pair of string and InferParameter + def parameters(self) -> Union[Dict, MessageMap[str, InferParameter]]: + """Get the parameters of inference output associated with this object. + + Returns: + The additional inference parameters associated with the inference output. 
""" return self._parameters - def set_shape(self, shape): - """Set the shape of input. - Parameters - ---------- - shape : list - The shape of the associated input. + def set_shape(self, shape: List[int]): + """Set the shape of inference output. + + Args: + shape: The shape of the associated inference output. """ self._shape = shape def as_numpy(self) -> numpy.ndarray: - """ - Decode the tensor data as numpy array + """ Decode the tensor output data as numpy array. + + Returns: + The numpy array of the associated inference output data. """ dtype = to_np_dtype(self.datatype) if dtype is None: @@ -429,41 +436,37 @@ def as_numpy(self) -> numpy.ndarray: np_array = np.array(self._data, dtype=dtype) return np_array.reshape(self._shape) - def set_data_from_numpy(self, input_tensor, binary_data=True): - """Set the tensor data from the specified numpy array for - input associated with this object. - Parameters - ---------- - input_tensor : numpy array - The tensor data in numpy array format - binary_data : bool - Indicates whether to set data for the input in binary format - or explicit tensor within JSON. The default value is True, - which means the data will be delivered as binary data in the - HTTP body after the JSON object. - Raises - ------ - InferenceServerException - If failed to set data for the tensor. + def set_data_from_numpy(self, output_tensor: np.ndarray, binary_data: bool = True): + """Set the tensor data from the specified numpy array for the inference output associated with this object. + + Args: + output_tensor : The tensor data in numpy array format. + binary_data : Indicates whether to set data for the input in binary format + or explicit tensor within JSON. The default value is True, + which means the data will be delivered as binary data with gRPC or in the + HTTP body after the JSON object for REST. + + Raises: + InferenceError if failed to set data for the output tensor. """ - if not isinstance(input_tensor, (np.ndarray,)): - raise_error("input_tensor must be a numpy array") + if not isinstance(output_tensor, (np.ndarray,)): + InferenceError("input_tensor must be a numpy array") - dtype = from_np_dtype(input_tensor.dtype) + dtype = from_np_dtype(output_tensor.dtype) if self._datatype != dtype: - raise_error( + InferenceError( "got unexpected datatype {} from numpy array, expected {}".format(dtype, self._datatype)) valid_shape = True - if len(self._shape) != len(input_tensor.shape): + if len(self._shape) != len(output_tensor.shape): valid_shape = False else: for i in range(len(self._shape)): - if self._shape[i] != input_tensor.shape[i]: + if self._shape[i] != output_tensor.shape[i]: valid_shape = False if not valid_shape: - raise_error( + InferenceError( "got unexpected numpy array shape [{}], expected [{}]".format( - str(input_tensor.shape)[1:-1], + str(output_tensor.shape)[1:-1], str(self._shape)[1:-1])) if not binary_data: @@ -472,14 +475,14 @@ def set_data_from_numpy(self, input_tensor, binary_data=True): if self._datatype == "BYTES": self._data = [] try: - if input_tensor.size > 0: - for obj in np.nditer(input_tensor, + if output_tensor.size > 0: + for obj in np.nditer(output_tensor, flags=["refs_ok"], order='C'): # We need to convert the object to string using utf-8, # if we want to use the binary_data=False. JSON requires # the input to be a UTF-8 string. 
-                            if input_tensor.dtype == np.object_:
+                            if output_tensor.dtype == np.object_:
                                 if type(obj.item()) == bytes:
                                     self._data.append(
                                         str(obj.item(), encoding='utf-8'))
@@ -489,36 +492,25 @@ def set_data_from_numpy(self, input_tensor, binary_data=True):
                                     self._data.append(
                                         str(obj.item(), encoding='utf-8'))
                 except UnicodeDecodeError:
-                    raise_error(
+                    raise InferenceError(
                         f'Failed to encode "{obj.item()}" using UTF-8. Please use binary_data=True, if'
                         ' you want to pass a byte array.')
             else:
-                self._data = [val.item() for val in input_tensor.flatten()]
+                self._data = [val.item() for val in output_tensor.flatten()]
         else:
             self._data = None
             if self._datatype == "BYTES":
-                serialized_output = serialize_byte_tensor(input_tensor)
+                serialized_output = serialize_byte_tensor(output_tensor)
                 if serialized_output.size > 0:
                     self._raw_data = serialized_output.item()
                 else:
                     self._raw_data = b''
             else:
-                self._raw_data = input_tensor.tobytes()
+                self._raw_data = output_tensor.tobytes()
             self._parameters['binary_data_size'] = len(self._raw_data)


 class InferResponse:
-    """InferenceResponse
-
-    $inference_response =
-    {
-      "model_name" : $string,
-      "model_version" : $string #optional,
-      "id" : $string,
-      "parameters" : $parameters #optional,
-      "outputs" : [ $response_output, ... ]
-    }
-    """
     id: str
     model_name: str
     parameters: Optional[Dict]
@@ -526,7 +518,18 @@ class InferResponse:
     from_grpc: bool

     def __init__(self, response_id: str, model_name: str, infer_outputs: List[InferOutput],
-                 raw_outputs=None, from_grpc=False, parameters=None):
+                 raw_outputs=None, from_grpc: Optional[bool] = False,
+                 parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None):
+        """The InferResponse Data Model.
+
+        Args:
+            response_id: The id of the inference response.
+            model_name: The name of the model.
+            infer_outputs: The inference outputs of the inference response.
+            raw_outputs: The raw binary data of the inference outputs.
+            from_grpc: Indicates whether the InferResponse is constructed from a gRPC response.
+            parameters: The additional inference parameters.
+        """
         if parameters is None:
             parameters = {}
         self.id = response_id
@@ -540,6 +543,8 @@ def __init__(self, response_id: str, model_name: str, infer_outputs: List[InferO

     @classmethod
     def from_grpc(cls, response: ModelInferResponse) -> 'InferResponse':
+        """The class method to construct the InferResponse object from a gRPC ModelInferResponse message.
+        """
         infer_outputs = [InferOutput(name=output.name, shape=list(output.shape),
                                      datatype=output.datatype,
                                      data=get_content(output.datatype, output.contents),
@@ -550,6 +555,9 @@ def from_grpc(cls, response: ModelInferResponse) -> 'InferResponse':

     @classmethod
     def from_rest(cls, model_name: str, response: Dict) -> 'InferResponse':
+        """The class method to construct the InferResponse object from a v2 REST response Dict.
+        """
         infer_outputs = [InferOutput(name=output['name'],
                                      shape=list(output['shape']),
                                      datatype=output['datatype'],
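Reviewer note: the documented `from_rest` constructor in use; the response body is a made-up v2 REST payload:

```python
from kserve.protocol.infer_type import InferResponse

rest_body = {
    "id": "42",
    "model_name": "my-model",
    "outputs": [
        {"name": "output-0", "shape": [1, 2], "datatype": "FP32", "data": [0.1, 0.9]},
    ],
}
response = InferResponse.from_rest(model_name="my-model", response=rest_body)
print(response.outputs[0].as_numpy())  # [[0.1 0.9]]
```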
""" infer_outputs = [] for i, infer_output in enumerate(self.outputs): @@ -588,8 +598,10 @@ def to_rest(self) -> Dict: return res def to_grpc(self) -> ModelInferResponse: - """ Converts the InferResponse object to gRPC ModelInferRequest message + """ Converts the InferResponse object to gRPC ModelInferResponse type. + Returns: + The ModelInferResponse gRPC message. """ infer_outputs = [] raw_output_contents = []