
Commit

minor code refactor
Signed-off-by: Sachin Varghese <[email protected]>
SachinVarghese committed Oct 19, 2023
1 parent 43dfc1c commit b30c41f
Showing 2 changed files with 12 additions and 15 deletions.
21 changes: 9 additions & 12 deletions mlserver/handlers/dataplane.py
@@ -30,7 +30,8 @@ def __init__(self, settings: Settings, model_registry: MultiModelRegistry):
         self._settings = settings
         self._model_registry = model_registry

-        self._response_cache = self._create_response_cache()
+        if settings.cache_enabled:
+            self._response_cache = self._create_response_cache()
         self._inference_middleware = InferenceMiddlewares(
             CloudEventsMiddleware(settings)
         )
@@ -91,8 +92,10 @@ async def infer(
             model=name, version=version
         ).count_exceptions()

-        with infer_duration, infer_errors:
+        if self._response_cache is not None:
+            cache_key = payload.json()

+        with infer_duration, infer_errors:
             if payload.id is None:
                 payload.id = generate_uuid()

@@ -105,17 +108,16 @@ async def infer(
             # TODO: Make await optional for sync methods
             with model_context(model.settings):
                 if (
-                    self._settings.cache_enabled
+                    self._response_cache is not None
                     and model.settings.cache_enabled is not False
-                    and self._response_cache is not None
                 ):
                     cache_value = await self._response_cache.lookup(cache_key)
                     if cache_value != "":
                         prediction = InferenceResponse.parse_raw(cache_value)
                     else:
                         prediction = await model.predict(payload)
                         # ignore cache insertion error if any
-                        self._response_cache.insert(cache_key, prediction.json())
+                        await self._response_cache.insert(cache_key, prediction.json())
                 else:
                     prediction = await model.predict(payload)

@@ -128,10 +130,5 @@ async def infer(

         return prediction

-    def _create_response_cache(self) -> Optional[ResponseCache]:
-        if self._settings.cache_enabled:
-            if self._settings.cache_size is None:
-                # Default cache size if caching is enabled
-                self._settings.cache_size = 100
-            return LocalCache(size=self._settings.cache_size)
-        return None
+    def _create_response_cache(self) -> ResponseCache:
+        return LocalCache(size=self._settings.cache_size)
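
Taken together, the dataplane changes make the cache strictly optional: the cache object is only created when server-level caching is enabled, the inference path keys the cache on the serialised request, and cache insertions are now awaited. The snippet below is a minimal, self-contained sketch of that lookup/insert flow, not MLServer code: InMemoryCache and DummyModel are illustrative stand-ins that only mimic the async lookup/insert interface used above (with lookup returning "" on a miss).

import asyncio
import json


class InMemoryCache:
    """Stand-in response cache with the same async lookup/insert shape as above."""

    def __init__(self, size: int = 100):
        self._store: dict = {}
        self._size = size

    async def lookup(self, key: str) -> str:
        # Return "" on a miss, matching the `cache_value != ""` check in infer()
        return self._store.get(key, "")

    async def insert(self, key: str, value: str) -> None:
        if len(self._store) >= self._size:
            # naive eviction: drop the oldest entry
            self._store.pop(next(iter(self._store)))
        self._store[key] = value


class DummyModel:
    """Stand-in model with an async predict, for illustration only."""

    async def predict(self, payload: dict) -> dict:
        return {"outputs": [sum(payload["inputs"])]}


async def cached_predict(cache: InMemoryCache, model: DummyModel, payload: dict) -> dict:
    # Mirror of the refactored infer() path: key the cache on the serialised
    # request, return the cached response on a hit, otherwise predict and insert.
    cache_key = json.dumps(payload, sort_keys=True)
    cache_value = await cache.lookup(cache_key)
    if cache_value != "":
        return json.loads(cache_value)
    prediction = await model.predict(payload)
    await cache.insert(cache_key, json.dumps(prediction))
    return prediction


if __name__ == "__main__":
    cache, model = InMemoryCache(size=2), DummyModel()
    payload = {"inputs": [1, 2, 3]}
    print(asyncio.run(cached_predict(cache, model, payload)))  # computed
    print(asyncio.run(cached_predict(cache, model, payload)))  # served from the cache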
6 changes: 3 additions & 3 deletions mlserver/settings.py
@@ -241,10 +241,10 @@ class Config:
     _custom_metrics_server_settings: Optional[dict] = None
     _custom_grpc_server_settings: Optional[dict] = None

-    cache_enabled: Optional[bool] = None
+    cache_enabled: bool = False
     """Enable caching for the model predictions."""

-    cache_size: Optional[int] = None
+    cache_size: int = 100
     """Cache size to be used if caching is enabled."""


@@ -400,7 +400,7 @@ def version(self) -> Optional[str]:
     parameters: Optional[ModelParameters] = None
     """Extra parameters for each instance of this model."""

-    cache_enabled: Optional[bool] = None
+    cache_enabled: bool = False
     """Enable caching for a specific model. This parameter can be used to disable
     cache for a specific model, if the server level caching is enabled. If the
     server level caching is disabled, this parameter value will have no effect."""
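
For context, the new defaults mean a plain Settings object now reports caching as disabled with a 100-entry cache size, which is why the dataplane no longer needs the None checks removed above. A small sketch, assuming an mlserver installation that includes this change; the custom size of 500 is only an example value.

from mlserver.settings import Settings

# Caching is now off by default, with a default cache size of 100.
defaults = Settings()
assert defaults.cache_enabled is False
assert defaults.cache_size == 100

# Enabling the server-level cache with a custom size.
settings = Settings(cache_enabled=True, cache_size=500)

# A specific model can still opt out by setting cache_enabled to False in its
# model settings; per the docstring above, that flag has no effect when the
# server-level cache is disabled.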
