Local response cache implementation (#1440)
* Basic cache implementation

Signed-off-by: Sachin Varghese <[email protected]>

* Setting max size to local cache implementation

Signed-off-by: Sachin Varghese <[email protected]>

* Adding lint fixes

Signed-off-by: Sachin Varghese <[email protected]>

* minor lint fix

Signed-off-by: Sachin Varghese <[email protected]>

* types update

Signed-off-by: Sachin Varghese <[email protected]>

* fixing tests

Signed-off-by: Sachin Varghese <[email protected]>

* optional typings updated

Signed-off-by: Sachin Varghese <[email protected]>

* cache impl refactor

Signed-off-by: Sachin Varghese <[email protected]>

* minor code refactor

Signed-off-by: Sachin Varghese <[email protected]>

* Adding basic handler cache tests

Signed-off-by: Sachin Varghese <[email protected]>

* Update cache tests

Signed-off-by: Sachin Varghese <[email protected]>

---------

Signed-off-by: Sachin Varghese <[email protected]>
SachinVarghese authored Oct 26, 2023
1 parent dbfa408 commit a94068b
Showing 13 changed files with 215 additions and 4 deletions.
4 changes: 4 additions & 0 deletions mlserver/cache/__init__.py
@@ -0,0 +1,4 @@
from .cache import ResponseCache
from .local import LocalCache

__all__ = ["ResponseCache", "LocalCache"]
27 changes: 27 additions & 0 deletions mlserver/cache/cache.py
@@ -0,0 +1,27 @@
class ResponseCache:
    async def insert(self, key: str, value: str):
        """
        Method responsible for inserting a value into the cache.
        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("insert() method not implemented")

    async def lookup(self, key: str) -> str:
        """
        Method responsible for returning the value stored under a key in the cache.
        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("lookup() method not implemented")

    async def size(self) -> int:
        """
        Method responsible for returning the size of the cache.
        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("size() method not implemented")
3 changes: 3 additions & 0 deletions mlserver/cache/local/__init__.py
@@ -0,0 +1,3 @@
from .local import LocalCache

__all__ = ["LocalCache"]
25 changes: 25 additions & 0 deletions mlserver/cache/local/local.py
@@ -0,0 +1,25 @@
from collections import OrderedDict
from ..cache import ResponseCache


class LocalCache(ResponseCache):
    def __init__(self, size=100):
        self.cache = OrderedDict()
        self.size_limit = size

    async def insert(self, key: str, value: str):
        self.cache[key] = value
        cache_size = await self.size()
        if cache_size > self.size_limit:
            # The cache removes the first entry if it overflows (i.e. in FIFO order)
            self.cache.popitem(last=False)
        return None

    async def lookup(self, key: str) -> str:
        if key in self.cache:
            return self.cache[key]
        else:
            return ""

    async def size(self) -> int:
        return len(self.cache)
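A minimal standalone sketch of how LocalCache behaves (assuming a build that includes this commit is installed); it shows the FIFO eviction kicking in once the size limit is exceeded:

# Quick sketch of LocalCache's FIFO eviction, runnable on its own.
import asyncio

from mlserver.cache.local import LocalCache


async def main():
    cache = LocalCache(size=2)

    await cache.insert("a", "1")
    await cache.insert("b", "2")
    await cache.insert("c", "3")  # overflows: "a" (the oldest entry) is evicted

    print(await cache.lookup("a"))  # "" -> cache miss
    print(await cache.lookup("c"))  # "3"
    print(await cache.size())       # 2


asyncio.run(main())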
28 changes: 26 additions & 2 deletions mlserver/handlers/dataplane.py
@@ -17,6 +17,7 @@
from ..middleware import InferenceMiddlewares
from ..cloudevents import CloudEventsMiddleware
from ..utils import generate_uuid
from ..cache import ResponseCache, LocalCache


class DataPlane:
@@ -28,7 +29,9 @@ class DataPlane:
    def __init__(self, settings: Settings, model_registry: MultiModelRegistry):
        self._settings = settings
        self._model_registry = model_registry

        self._response_cache = None
        if settings.cache_enabled:
            self._response_cache = self._create_response_cache()
        self._inference_middleware = InferenceMiddlewares(
            CloudEventsMiddleware(settings)
        )
@@ -89,6 +92,9 @@ async def infer(
            model=name, version=version
        ).count_exceptions()

        if self._response_cache is not None:
            cache_key = payload.json()

        with infer_duration, infer_errors:
            if payload.id is None:
                payload.id = generate_uuid()
@@ -101,7 +107,19 @@ async def infer(

            # TODO: Make await optional for sync methods
            with model_context(model.settings):
                prediction = await model.predict(payload)
                if (
                    self._response_cache is not None
                    and model.settings.cache_enabled is not False
                ):
                    cache_value = await self._response_cache.lookup(cache_key)
                    if cache_value != "":
                        prediction = InferenceResponse.parse_raw(cache_value)
                    else:
                        prediction = await model.predict(payload)
                        # ignore cache insertion error if any
                        await self._response_cache.insert(cache_key, prediction.json())
                else:
                    prediction = await model.predict(payload)

            # Ensure ID matches
            prediction.id = payload.id
@@ -111,3 +129,9 @@ async def infer(
        self._ModelInferRequestSuccess.labels(model=name, version=version).inc()

        return prediction

    def _create_response_cache(self) -> ResponseCache:
        return LocalCache(size=self._settings.cache_size)

    def _get_response_cache(self) -> Optional[ResponseCache]:
        return self._response_cache
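Condensed, the caching path added to DataPlane.infer works as follows: the serialized request (payload.json()) is the cache key, an empty lookup result counts as a miss, and the serialized InferenceResponse is what gets stored. Below is a simplified sketch of that flow, not the literal implementation; the per-model cache_enabled gating and the metrics handling are omitted:

# Simplified sketch of the lookup/insert flow above (not the actual method).
from mlserver.cache import ResponseCache
from mlserver.model import MLModel
from mlserver.types import InferenceRequest, InferenceResponse


async def cached_predict(
    cache: ResponseCache, model: MLModel, payload: InferenceRequest
) -> InferenceResponse:
    cache_key = payload.json()

    cached_value = await cache.lookup(cache_key)
    if cached_value != "":
        # Cache hit: rebuild the response from its serialized JSON form
        return InferenceResponse.parse_raw(cached_value)

    # Cache miss: run the model and store the serialized response
    prediction = await model.predict(payload)
    await cache.insert(cache_key, prediction.json())
    return prediction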
11 changes: 11 additions & 0 deletions mlserver/settings.py
@@ -241,6 +241,12 @@ class Config:
    _custom_metrics_server_settings: Optional[dict] = None
    _custom_grpc_server_settings: Optional[dict] = None

    cache_enabled: bool = False
    """Enable caching of model predictions."""

    cache_size: int = 100
    """Cache size to be used if caching is enabled."""


class ModelParameters(BaseSettings):
"""
@@ -393,3 +399,8 @@ def version(self) -> Optional[str]:
    # However, it's also possible to override them manually.
    parameters: Optional[ModelParameters] = None
    """Extra parameters for each instance of this model."""

    cache_enabled: bool = False
    """Enable caching for a specific model. This parameter can be used to
    disable caching for a specific model when server-level caching is enabled.
    If server-level caching is disabled, this parameter has no effect."""
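For reference, a purely illustrative snippet (not part of the commit) showing the new server-level fields being set programmatically; the test data below does the equivalent via settings-cache.json and model-settings.json:

# Illustrative only: enable the response cache at the server level and cap it
# at 50 entries. The per-model flag is set in each model's own settings file.
from mlserver.settings import Settings

settings = Settings(cache_enabled=True, cache_size=50)

assert settings.cache_enabled is True
assert settings.cache_size == 50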
Empty file added tests/cache/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions tests/cache/conftest.py
@@ -0,0 +1,11 @@
import pytest

from mlserver.cache.local import LocalCache
from mlserver.cache import ResponseCache

CACHE_SIZE = 10


@pytest.fixture
def local_cache() -> ResponseCache:
    return LocalCache(size=CACHE_SIZE)
39 changes: 39 additions & 0 deletions tests/cache/test_local.py
@@ -0,0 +1,39 @@
from string import ascii_lowercase

from .conftest import CACHE_SIZE


async def test_local_cache_lookup(local_cache):
    assert await local_cache.size() == 0
    assert await local_cache.lookup("unknown key") == ""
    assert await local_cache.size() == 0


async def test_local_cache_insert(local_cache):
    assert await local_cache.size() == 0

    await local_cache.insert("key", "value")
    assert await local_cache.lookup("key") == "value"

    assert await local_cache.size() == 1

    await local_cache.insert("new key", "new value")
    assert await local_cache.lookup("key") == "value"
    assert await local_cache.lookup("new key") == "new value"

    assert await local_cache.size() == 2


async def test_local_cache_rotate(local_cache):
    # Insert the letters of the alphabet in a loop
    for key, symbol in enumerate(ascii_lowercase):
        await local_cache.insert(str(key), symbol)

        if key < CACHE_SIZE:
            assert await local_cache.size() == key + 1
            assert await local_cache.lookup(str(key)) == symbol

        else:
            assert await local_cache.size() == CACHE_SIZE
            assert await local_cache.lookup(str(key)) == symbol
            assert await local_cache.lookup(str(key - CACHE_SIZE)) == ""
21 changes: 21 additions & 0 deletions tests/handlers/conftest.py
@@ -1,10 +1,31 @@
import pytest
import os

from mlserver.settings import Settings
from mlserver.handlers import DataPlane
from mlserver.handlers.custom import CustomHandler
from mlserver.registry import MultiModelRegistry
from prometheus_client.registry import CollectorRegistry

from ..fixtures import SumModel
from ..conftest import TESTDATA_PATH


@pytest.fixture
def custom_handler(sum_model: SumModel) -> CustomHandler:
    return CustomHandler(rest_path="/my-custom-endpoint")


@pytest.fixture
def cached_settings() -> Settings:
    settings_path = os.path.join(TESTDATA_PATH, "settings-cache.json")
    return Settings.parse_file(settings_path)


@pytest.fixture
def cached_data_plane(
    cached_settings: Settings,
    model_registry: MultiModelRegistry,
    prometheus_registry: CollectorRegistry,
) -> DataPlane:
    return DataPlane(settings=cached_settings, model_registry=model_registry)
37 changes: 36 additions & 1 deletion tests/handlers/test_dataplane.py
@@ -3,7 +3,7 @@

from mlserver.errors import ModelNotReady
from mlserver.settings import ModelSettings, ModelParameters
from mlserver.types import MetadataTensor
from mlserver.types import MetadataTensor, InferenceResponse

from ..fixtures import SumModel

@@ -114,3 +114,38 @@ async def test_infer_generates_uuid(data_plane, sum_model, inference_request):

    assert prediction.id is not None
    assert prediction.id == str(uuid.UUID(prediction.id))


async def test_infer_response_cache(cached_data_plane, sum_model, inference_request):
    cache_key = inference_request.json()
    payload = inference_request.copy(deep=True)
    prediction = await cached_data_plane.infer(
        payload=payload, name=sum_model.name, version=sum_model.version
    )

    response_cache = cached_data_plane._get_response_cache()
    assert response_cache is not None
    assert await response_cache.size() == 1

    cache_value = await response_cache.lookup(cache_key)
    cached_response = InferenceResponse.parse_raw(cache_value)
    assert cached_response.model_name == prediction.model_name
    assert cached_response.model_version == prediction.model_version
    assert cached_response.Config == prediction.Config
    assert cached_response.outputs == prediction.outputs

    prediction = await cached_data_plane.infer(
        payload=inference_request, name=sum_model.name, version=sum_model.version
    )

    # The second request should be served from the existing cache entry
    assert await response_cache.size() == 1
    assert cached_response.model_name == prediction.model_name
    assert cached_response.model_version == prediction.model_version
    assert cached_response.Config == prediction.Config
    assert cached_response.outputs == prediction.outputs


async def test_response_cache_disabled(data_plane):
    response_cache = data_plane._get_response_cache()
    assert response_cache is None
4 changes: 3 additions & 1 deletion tests/testdata/model-settings.json
@@ -28,5 +28,7 @@

"parameters": {
"version": "v1.2.3"
}
},

"cache_enabled": true
}
9 changes: 9 additions & 0 deletions tests/testdata/settings-cache.json
@@ -0,0 +1,9 @@
{
  "debug": true,
  "host": "127.0.0.1",
  "parallel_workers": 2,
  "cors_settings": {
    "allow_origins": ["*"]
  },
  "cache_enabled": true
}
