
cache impl refactor
Signed-off-by: Sachin Varghese <[email protected]>
SachinVarghese committed Oct 18, 2023
1 parent 8a13e64 commit 43dfc1c
Showing 7 changed files with 27 additions and 38 deletions.
1 change: 1 addition & 0 deletions mlserver/cache/local/local.py
@@ -10,6 +10,7 @@ def __init__(self, size=100):
     async def insert(self, key: str, value: str):
         self.cache[key] = value
         if len(self.cache) > self.size_limit:
+            # The cache removes the first entry if it overflows (i.e. in FIFO order)
             self.cache.popitem(last=False)
         return None

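For context on the new comment: popitem(last=False) is an OrderedDict method that pops the oldest inserted key, which is what gives the FIFO eviction described. A minimal sketch of a cache along these lines — the __init__ body is not shown in this hunk, so the OrderedDict setup and the empty-string miss value (which mirrors the `cache_value != ""` check in the dataplane diff below) are assumptions:

from collections import OrderedDict


class LocalCache:
    def __init__(self, size=100):
        self.cache = OrderedDict()  # insertion-ordered mapping (assumed)
        self.size_limit = size

    async def insert(self, key: str, value: str):
        self.cache[key] = value
        if len(self.cache) > self.size_limit:
            # Evict the oldest entry once the limit is exceeded (FIFO)
            self.cache.popitem(last=False)
        return None

    async def lookup(self, key: str) -> str:
        # Return "" on a miss, matching the dataplane's `cache_value != ""` check
        return self.cache.get(key, "")
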
28 changes: 15 additions & 13 deletions mlserver/handlers/dataplane.py
@@ -4,7 +4,6 @@
 )
 from typing import Optional

-import asyncio
 from ..errors import ModelNotReady
 from ..context import model_context
 from ..settings import Settings
@@ -18,7 +17,7 @@
 from ..middleware import InferenceMiddlewares
 from ..cloudevents import CloudEventsMiddleware
 from ..utils import generate_uuid
-from ..cache import ResponseCache
+from ..cache import ResponseCache, LocalCache


 class DataPlane:
@@ -27,15 +26,11 @@ class DataPlane:
     servers.
     """

-    def __init__(
-        self,
-        settings: Settings,
-        model_registry: MultiModelRegistry,
-        response_cache: Optional[ResponseCache] = None,
-    ):
+    def __init__(self, settings: Settings, model_registry: MultiModelRegistry):
         self._settings = settings
         self._model_registry = model_registry
-        self._response_cache = response_cache
+
+        self._response_cache = self._create_response_cache()
         self._inference_middleware = InferenceMiddlewares(
             CloudEventsMiddleware(settings)
         )
@@ -111,17 +106,16 @@ async def infer(
         with model_context(model.settings):
             if (
                 self._settings.cache_enabled
-                and model.settings.cache_enabled
+                and model.settings.cache_enabled is not False
                 and self._response_cache is not None
             ):
                 cache_value = await self._response_cache.lookup(cache_key)
                 if cache_value != "":
                     prediction = InferenceResponse.parse_raw(cache_value)
                 else:
                     prediction = await model.predict(payload)
-                    asyncio.create_task(
-                        self._response_cache.insert(cache_key, prediction.json())
-                    )
+                    # ignore cache insertion error if any
+                    self._response_cache.insert(cache_key, prediction.json())
             else:
                 prediction = await model.predict(payload)

@@ -133,3 +127,11 @@ async def infer(
             self._ModelInferRequestSuccess.labels(model=name, version=version).inc()

         return prediction
+
+    def _create_response_cache(self) -> Optional[ResponseCache]:
+        if self._settings.cache_enabled:
+            if self._settings.cache_size is None:
+                # Default cache size if caching is enabled
+                self._settings.cache_size = 100
+            return LocalCache(size=self._settings.cache_size)
+        return None
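
The switch from `and model.settings.cache_enabled` to `is not False` makes the model-level flag tri-state: None (unset) inherits the server setting, while an explicit False opts the model out. A standalone sketch of that condition — the helper name is hypothetical, for illustration only:

from typing import Optional


def cache_usable(
    server_enabled: Optional[bool],
    model_enabled: Optional[bool],
    cache: Optional[object],
) -> bool:
    # Mirrors the condition in DataPlane.infer() above
    return bool(server_enabled) and model_enabled is not False and cache is not None


assert cache_usable(True, None, object())       # unset model flag inherits server's
assert not cache_usable(True, False, object())  # model explicitly opts out
assert not cache_usable(False, None, object())  # server-level caching disabled
assert not cache_usable(True, None, None)       # no cache instance available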
12 changes: 1 addition & 11 deletions mlserver/server.py
@@ -18,7 +18,6 @@
 from .metrics import MetricsServer
 from .kafka import KafkaServer
 from .utils import logger
-from .cache import ResponseCache, LocalCache

 HANDLED_SIGNALS = [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]

@@ -47,11 +46,8 @@ def __init__(self, settings: Settings):
         self._model_repository = ModelRepositoryFactory.resolve_model_repository(
             self._settings
         )
-        self._response_cache = self._create_response_cache()
         self._data_plane = DataPlane(
-            settings=self._settings,
-            model_registry=self._model_registry,
-            response_cache=self._response_cache,
+            settings=self._settings, model_registry=self._model_registry
         )
         self._model_repository_handlers = ModelRepositoryHandlers(
             repository=self._model_repository, model_registry=self._model_registry
@@ -60,12 +56,6 @@ def __init__(self, settings: Settings):
         self._configure_logger()
         self._create_servers()

-    def _create_response_cache(self) -> Optional[ResponseCache]:
-        if self._settings.cache_enabled:
-            return LocalCache(size=self._settings.cache_size)
-        else:
-            return None
-
     def _create_model_registry(self) -> MultiModelRegistry:
         on_model_load = [
             self.add_custom_handlers,
12 changes: 7 additions & 5 deletions mlserver/settings.py
@@ -241,10 +241,10 @@ class Config:
     _custom_metrics_server_settings: Optional[dict] = None
     _custom_grpc_server_settings: Optional[dict] = None

-    cache_enabled: Optional[bool] = False
-    """Enable Caching for the model predictions."""
+    cache_enabled: Optional[bool] = None
+    """Enable caching for the model predictions."""

-    cache_size: Optional[int] = 100
+    cache_size: Optional[int] = None
     """Cache size to be used if caching is enabled."""


@@ -400,5 +400,7 @@ def version(self) -> Optional[str]:
     parameters: Optional[ModelParameters] = None
     """Extra parameters for each instance of this model."""

-    cache_enabled: Optional[bool] = False
-    """Enable Caching for the model predictions."""
+    cache_enabled: Optional[bool] = None
+    """Enable caching for a specific model. This parameter can be used to disable
+    cache for a specific model, if the server level caching is enabled. If the
+    server level caching is disabled, this parameter value will have no effect."""
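
With both defaults now None, enabling the cache becomes a single server-level switch, and the size falls back to 100 inside DataPlane._create_response_cache() above. A small usage sketch — Settings is a pydantic settings class, so keyword overrides like this should work, but treat the exact construction as an assumption:

from mlserver.settings import Settings

settings = Settings(cache_enabled=True)  # cache_size intentionally left unset
assert settings.cache_size is None       # 100 is filled in only when the
                                         # DataPlane builds its LocalCache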
4 changes: 1 addition & 3 deletions runtimes/alibi-explain/tests/conftest.py
@@ -127,9 +127,7 @@ async def model_registry(

 @pytest.fixture
 def data_plane(settings: Settings, model_registry: MultiModelRegistry) -> DataPlane:
-    return DataPlane(
-        settings=settings, model_registry=model_registry, response_cache=None
-    )
+    return DataPlane(settings=settings, model_registry=model_registry)


 @pytest.fixture
4 changes: 1 addition & 3 deletions runtimes/mlflow/tests/rest/conftest.py
@@ -95,9 +95,7 @@ def settings() -> Settings:

 @pytest.fixture
 def data_plane(settings: Settings, model_registry: MultiModelRegistry) -> DataPlane:
-    return DataPlane(
-        settings=settings, model_registry=model_registry, response_cache=None
-    )
+    return DataPlane(settings=settings, model_registry=model_registry)


 @pytest.fixture
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -203,9 +203,7 @@ def data_plane(
     model_registry: MultiModelRegistry,
     prometheus_registry: CollectorRegistry,
 ) -> DataPlane:
-    return DataPlane(
-        settings=settings, model_registry=model_registry, response_cache=None
-    )
+    return DataPlane(settings=settings, model_registry=model_registry)


 @pytest.fixture
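
All three fixtures drop the response_cache=None argument, since the constructor no longer accepts it. A test that wanted an active cache would now go through settings instead; a hypothetical fixture sketch — the name cached_data_plane and the import paths (mlserver.handlers, mlserver.registry) are assumptions, not part of this commit:

import pytest

from mlserver.handlers import DataPlane
from mlserver.registry import MultiModelRegistry
from mlserver.settings import Settings


@pytest.fixture
def cached_data_plane(model_registry: MultiModelRegistry) -> DataPlane:
    # Caching is configured purely via Settings; the DataPlane builds its
    # own LocalCache in _create_response_cache().
    settings = Settings(cache_enabled=True, cache_size=10)
    return DataPlane(settings=settings, model_registry=model_registry)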
