Update python model serving runtime API docstring (kserve#3338)
* Update kserve python runtime API docstring

Signed-off-by: Dan Sun <[email protected]>

* Use InferenceError

Signed-off-by: Dan Sun <[email protected]>

---------

Signed-off-by: Dan Sun <[email protected]>
yuzisun authored Jan 1, 2024
1 parent 99d355e commit fff1802
Showing 3 changed files with 262 additions and 229 deletions.
75 changes: 45 additions & 30 deletions python/kserve/kserve/model.py
@@ -63,6 +63,14 @@ def __init__(self, predictor_host: str,
predictor_protocol: str = PredictorProtocol.REST_V1.value,
predictor_use_ssl: bool = False,
predictor_request_timeout_seconds: int = 600):
+ """ The configuration for the http call to the predictor.
+ Args:
+     predictor_host: The host name of the predictor.
+     predictor_protocol: The inference protocol used for the predictor http call.
+     predictor_use_ssl: Enable ssl for the http connection to the predictor.
+     predictor_request_timeout_seconds: The request timeout in seconds for the predictor http call.
+ """
self.predictor_host = predictor_host
self.predictor_protocol = predictor_protocol
self.predictor_use_ssl = predictor_use_ssl
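For context, a minimal sketch of constructing this configuration for a transformer that forwards requests to a separate predictor. The host name and the "v2" protocol value are assumptions for illustration, and `PredictorConfig` is assumed importable from `kserve.model`:

```python
from kserve.model import PredictorConfig

# Hypothetical values: in a real deployment the predictor host is typically
# injected by the KServe controller through the --predictor_host argument.
config = PredictorConfig(
    predictor_host="my-model-predictor.default.svc.cluster.local",
    predictor_protocol="v2",
    predictor_use_ssl=False,
    predictor_request_timeout_seconds=600,
)
```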
@@ -73,10 +81,11 @@ class Model:
def __init__(self, name: str, predictor_config: Optional[PredictorConfig] = None):
"""KServe Model Public Interface
- Model is intended to be subclassed by various components within KServe.
+ Model is intended to be subclassed to implement the model handlers.
Args:
- name (str): Name of the model.
+ name: The name of the model.
+ predictor_config: The configuration for the http call to the predictor.
"""
self.name = name
self.ready = False
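A minimal sketch of the subclassing pattern the docstring describes; the class name is illustrative:

```python
from typing import Optional

from kserve import Model
from kserve.model import PredictorConfig


class MyModel(Model):
    def __init__(self, name: str, predictor_config: Optional[PredictorConfig] = None):
        super().__init__(name, predictor_config)
        # Handlers (load/preprocess/predict/postprocess) are overridden as needed;
        # mark the model ready once initialization succeeds.
        self.ready = True
```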
@@ -99,9 +108,9 @@ async def __call__(self, body: Union[Dict, CloudEvent, InferRequest],
"""Method to call predictor or explainer with the given input.
Args:
- body (Dict|CloudEvent|InferRequest): Request payload body.
- model_type (ModelType): Model type enum. Can be either predictor or explainer.
- headers (Dict): Request headers.
+ body: Request body.
+ model_type: ModelType enum, either `ModelType.PREDICTOR` or `ModelType.EXPLAINER`.
+ headers: Request headers.
Returns:
Dict: Response output from preprocess -> predictor/explainer -> postprocess
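Roughly, the server awaits the model instance itself to run that chain; a sketch under the v1 protocol, where the payload shape and the `handle` wrapper are hypothetical:

```python
from kserve.model import ModelType


async def handle(model, body: dict) -> dict:
    # Runs preprocess -> predict (or explain) -> postprocess on the model.
    return await model(body, model_type=ModelType.PREDICTOR,
                       headers={"content-type": "application/json"})

# e.g. asyncio.run(handle(MyModel("my-model"), {"instances": [[1.0, 2.0, 3.0]]}))
```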
@@ -184,8 +193,8 @@ def validate(self, payload):
return payload

def load(self) -> bool:
- """Load handler can be overridden to load the model from storage
- ``self.ready`` flag is used for model health check
+ """Load handler can be overridden to load the model from storage.
+ The `self.ready` flag should be set to True after the model is loaded; it is used for the model health check.
Returns:
bool: True if model is ready, False otherwise
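A sketch of a `load` override along those lines; the artifact path and the joblib dependency are illustrative assumptions, not part of the KServe API:

```python
import joblib

from kserve import Model


class SklearnStyleModel(Model):
    def load(self) -> bool:
        # Load a locally mounted artifact, then flip the readiness flag so the
        # health check reports the model as available.
        self._model = joblib.load("/mnt/models/model.joblib")
        self.ready = True
        return self.ready
```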
@@ -211,32 +220,33 @@ def get_output_types(self) -> List[Dict]:

async def preprocess(self, payload: Union[Dict, InferRequest],
headers: Dict[str, str] = None) -> Union[Dict, InferRequest]:
- """`preprocess` handler can be overridden for data or feature transformation.
- The default implementation decodes to Dict if it is a binary CloudEvent
- or gets the data field from a structured CloudEvent.
+ """ `preprocess` handler can be overridden for data or feature transformation.
+ The model decodes the request body to `Dict` for v1 endpoints and `InferRequest` for v2 endpoints.
Args:
- payload (Dict|InferRequest): Body of the request, v2 endpoints pass InferRequest.
- headers (Dict): Request headers.
+ payload: Payload of the request.
+ headers: Request headers.
Returns:
- Dict|InferRequest: Transformed inputs to ``predict`` handler or return InferRequest for predictor call.
+ A Dict or InferRequest in KServe Model Transformer mode, which is transmitted on the wire to the predictor.
+ Tensors in KServe Predictor mode, which are passed to the `predict` handler for performing the inference.
"""

return payload

- async def postprocess(self, response: Union[Dict, InferResponse], headers: Dict[str, str] = None) \
+ async def postprocess(self, result: Union[Dict, InferResponse], headers: Dict[str, str] = None) \
-> Union[Dict, InferResponse]:
- """The postprocess handler can be overridden for inference response transformation.
+ """ The `postprocess` handler can be overridden for inference result or response transformation.
+ The predictor sends back the inference result in `Dict` for v1 endpoints and `InferResponse` for v2 endpoints.
Args:
- response (Dict|InferResponse|ModelInferResponse): The response passed from ``predict`` handler.
- headers (Dict): Request headers.
+ result: The inference result passed from the `predict` handler or the HTTP response from the predictor.
+ headers: Request headers.
Returns:
- Dict: post-processed response.
+ A Dict or InferResponse after post-processing to return to the client.
"""
- return response
+ return result

async def _http_predict(self, payload: Union[Dict, InferRequest], headers: Dict[str, str] = None) -> Dict:
protocol = "https" if self.use_ssl else "http"
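A sketch of the transformer pattern these two handlers support, assuming v1 `Dict` payloads; the scaling and labeling logic is illustrative:

```python
from typing import Dict

from kserve import Model


class ScalingTransformer(Model):
    async def preprocess(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        # Normalize raw inputs before they are sent over the wire to the predictor.
        scaled = [[v / 255.0 for v in row] for row in payload["instances"]]
        return {"instances": scaled}

    async def postprocess(self, result: Dict, headers: Dict[str, str] = None) -> Dict:
        # Map raw scores returned by the predictor to labels for the client.
        return {"labels": ["positive" if p >= 0.5 else "negative"
                           for p in result["predictions"]]}
```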
@@ -289,15 +299,18 @@ async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], h

async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],
headers: Dict[str, str] = None) -> Union[Dict, InferResponse]:
- """
+ """ The `predict` handler can be overridden for performing the inference.
+ By default, the predict handler makes a call to the predictor for the inference step.
Args:
- payload (Dict|InferRequest|ModelInferRequest): Prediction inputs passed from ``preprocess`` handler.
- headers (Dict): Request headers.
+ payload: Model inputs passed from the `preprocess` handler.
+ headers: Request headers.
Returns:
- Dict|InferResponse|ModelInferResponse: Return InferResponse for serializing the prediction result or
- return the response from the predictor call.
+ The inference result or a response from the predictor.
+ Raises:
+     HTTPStatusError when getting back an error response from the predictor.
"""
if not self.predictor_host:
raise NotImplementedError("Could not find predictor_host.")
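When the model runs in-process rather than delegating to `predictor_host`, `predict` is typically overridden; a minimal sketch with stand-in inference logic:

```python
from typing import Dict

from kserve import Model


class InProcessModel(Model):
    async def predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        # Stand-in for real inference: sum each instance's features.
        return {"predictions": [sum(row) for row in payload["instances"]]}
```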
@@ -311,22 +324,24 @@ async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],

async def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
"""`explain` handler can be overridden to implement the model explanation.
- The default implementation makes call to the explainer if ``explainer_host`` is specified
+ The default implementation makes a call to the explainer if ``explainer_host`` is specified.
Args:
- payload (Dict): Dict passed from preprocess handler.
- headers (Dict): Request headers.
+ payload: Explainer model inputs passed from the preprocess handler.
+ headers: Request headers.
Returns:
- Dict: Response from the explainer.
+ An Explanation for the inference result.
+ Raises:
+     HTTPStatusError when getting back an error response from the explainer.
"""
if self.explainer_host is None:
raise NotImplementedError("Could not find explainer_host.")

protocol = "https" if self.use_ssl else "http"
# Currently explainer only supports the kserve v1 endpoints
explain_url = EXPLAINER_URL_FORMAT.format(protocol, self.explainer_host, self.name)
if self.protocol == PredictorProtocol.REST_V2.value:
explain_url = EXPLAINER_V2_URL_FORMAT.format(protocol, self.explainer_host, self.name)
response = await self._http_client.post(
url=explain_url,
timeout=self.timeout,
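Overriding `explain` is only needed when not delegating to an `explainer_host`; a hypothetical sketch that reports the most influential input index per instance:

```python
from typing import Dict

from kserve import Model


class ArgmaxExplainer(Model):
    async def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        # Toy "explanation": the index of the largest feature in each instance.
        explanations = [max(range(len(row)), key=row.__getitem__)
                        for row in payload["instances"]]
        return {"explanations": explanations}
```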
40 changes: 23 additions & 17 deletions python/kserve/kserve/model_server.py
@@ -84,23 +84,6 @@


class ModelServer:
- """KServe ModelServer
- Args:
-     http_port (int): HTTP port. Default: ``8080``.
-     grpc_port (int): GRPC port. Default: ``8081``.
-     workers (int): Number of workers for uvicorn. Default: ``1``.
-     max_threads (int): Max number of processing threads. Default: ``4``
-     max_asyncio_workers (int): Max number of AsyncIO threads. Default: ``None``
-     registered_models (ModelRepository): Model repository with registered models.
-     enable_grpc (bool): Whether to turn on grpc server. Default: ``True``
-     enable_docs_url (bool): Whether to turn on ``/docs`` Swagger UI. Default: ``False``.
-     enable_latency_logging (bool): Whether to log latency metric. Default: ``True``.
-     configure_logging (bool): Whether to configure KServe and Uvicorn logging. Default: ``True``.
-     log_config (dict or str): File path or dict containing log config. Default: ``None``.
-     access_log_format (string): Format to set for the access log (provided by asgi-logger). Default: ``None``
- """

def __init__(self, http_port: int = args.http_port,
grpc_port: int = args.grpc_port,
workers: int = args.workers,
@@ -114,6 +97,22 @@ def __init__(self, http_port: int = args.http_port,
log_config: Optional[Union[Dict, str]] = args.log_config_file,
access_log_format: str = args.access_log_format,
):
+ """KServe ModelServer Constructor
+ Args:
+     http_port: HTTP port. Default: ``8080``.
+     grpc_port: GRPC port. Default: ``8081``.
+     workers: Number of uvicorn workers. Default: ``1``.
+     max_threads: Max number of gRPC processing threads. Default: ``4``.
+     max_asyncio_workers: Max number of AsyncIO threads. Default: ``None``.
+     registered_models: Model repository with registered models.
+     enable_grpc: Whether to turn on the grpc server. Default: ``True``.
+     enable_docs_url: Whether to turn on the ``/docs`` Swagger UI. Default: ``False``.
+     enable_latency_logging: Whether to log the latency metric. Default: ``True``.
+     configure_logging: Whether to configure KServe and Uvicorn logging. Default: ``True``.
+     log_config: File path or dict containing log config. Default: ``None``.
+     access_log_format: Format to set for the access log (provided by asgi-logger). Default: ``None``.
+ """
self.registered_models = registered_models
self.http_port = http_port
self.grpc_port = grpc_port
@@ -143,6 +142,11 @@ def __init__(self, http_port: int = args.http_port,
self.access_log_format = access_log_format

def start(self, models: Union[List[Model], Dict[str, Deployment]]) -> None:
+ """ Start the model server with a set of registered models.
+ Args:
+     models: A list of models to register with the model server.
+ """
if isinstance(models, list):
for model in models:
if isinstance(model, Model):
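Putting the pieces together, a sketch of constructing the server and starting it with one registered model; the values shown mirror the documented defaults, and the model class is a stand-in:

```python
from kserve import Model, ModelServer


class ReadyModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.ready = True


if __name__ == "__main__":
    server = ModelServer(http_port=8080, grpc_port=8081, workers=1,
                         enable_docs_url=True)  # serves Swagger UI at /docs
    server.start([ReadyModel("my-model")])
```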
@@ -217,6 +221,8 @@ async def servers_task():
asyncio.run(servers_task())

async def stop(self, sig: Optional[int] = None):
+ """ Stop the instances of the REST and gRPC model servers.
+ """
logger.info("Stopping the model server")
if self._rest_server:
logger.info("Stopping the rest server")
