Local response cache implementation #1440

Merged (11 commits) on Oct 26, 2023
4 changes: 4 additions & 0 deletions mlserver/cache/__init__.py
@@ -0,0 +1,4 @@
from .cache import ResponseCache
from .local import LocalCache

__all__ = ["ResponseCache", "LocalCache"]
27 changes: 27 additions & 0 deletions mlserver/cache/cache.py
@@ -0,0 +1,27 @@
class ResponseCache:
    async def insert(self, key: str, value: str):
        """
        Method responsible for inserting a value into the cache.


        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("insert() method not implemented")

    async def lookup(self, key: str) -> str:
        """
        Method responsible for returning the value stored under a key in the cache.


        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("lookup() method not implemented")

    async def size(self) -> int:
        """
        Method responsible for returning the number of entries in the cache.


        **This method should be overridden to implement your custom cache logic.**
        """
        raise NotImplementedError("size() method not implemented")
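For illustration only: a custom cache that expires entries after a fixed time-to-live could implement the same interface. The TTLResponseCache name and its ttl parameter are hypothetical, not part of this PR; a minimal sketch:

import time

from mlserver.cache import ResponseCache


class TTLResponseCache(ResponseCache):
    """Hypothetical sketch: entries expire `ttl` seconds after insertion."""

    def __init__(self, ttl: float = 60.0):
        self._ttl = ttl
        self._entries: dict = {}  # key -> (inserted_at, value)

    async def insert(self, key: str, value: str):
        self._entries[key] = (time.monotonic(), value)

    async def lookup(self, key: str) -> str:
        entry = self._entries.get(key)
        if entry is None:
            return ""
        inserted_at, value = entry
        if time.monotonic() - inserted_at > self._ttl:
            # Expired: drop the entry and report a miss (empty string)
            del self._entries[key]
            return ""
        return value

    async def size(self) -> int:
        return len(self._entries)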
3 changes: 3 additions & 0 deletions mlserver/cache/local/__init__.py
@@ -0,0 +1,3 @@
from .local import LocalCache

__all__ = ["LocalCache"]
25 changes: 25 additions & 0 deletions mlserver/cache/local/local.py
@@ -0,0 +1,25 @@
from collections import OrderedDict
from ..cache import ResponseCache


class LocalCache(ResponseCache):
    def __init__(self, size=100):
        self.cache = OrderedDict()
        self.size_limit = size

    async def insert(self, key: str, value: str):
        self.cache[key] = value
        cache_size = await self.size()
        if cache_size > self.size_limit:
            # The cache removes the first entry if it overflows (i.e. in FIFO order)
            self.cache.popitem(last=False)
        return None

    async def lookup(self, key: str) -> str:
        if key in self.cache:
            return self.cache[key]
        else:
            return ""

    async def size(self) -> int:
        return len(self.cache)
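The FIFO eviction above can be exercised end to end with a short standalone script (a sketch, assuming mlserver is installed):

import asyncio

from mlserver.cache.local import LocalCache


async def main():
    cache = LocalCache(size=2)
    await cache.insert("a", "1")
    await cache.insert("b", "2")
    await cache.insert("c", "3")  # overflows the cache: "a" (oldest) is evicted

    assert await cache.lookup("a") == ""   # miss: evicted in FIFO order
    assert await cache.lookup("c") == "3"  # hit
    assert await cache.size() == 2


asyncio.run(main())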
28 changes: 26 additions & 2 deletions mlserver/handlers/dataplane.py
@@ -17,6 +17,7 @@
from ..middleware import InferenceMiddlewares
from ..cloudevents import CloudEventsMiddleware
from ..utils import generate_uuid
from ..cache import ResponseCache, LocalCache


class DataPlane:
@@ -28,7 +29,9 @@ class DataPlane:
    def __init__(self, settings: Settings, model_registry: MultiModelRegistry):
        self._settings = settings
        self._model_registry = model_registry

        self._response_cache = None
        if settings.cache_enabled:
            self._response_cache = self._create_response_cache()
        self._inference_middleware = InferenceMiddlewares(
            CloudEventsMiddleware(settings)
        )
@@ -89,6 +92,9 @@ async def infer(
            model=name, version=version
        ).count_exceptions()

        if self._response_cache is not None:
            cache_key = payload.json()

        with infer_duration, infer_errors:
            if payload.id is None:
                payload.id = generate_uuid()
@@ -101,7 +107,19 @@

            # TODO: Make await optional for sync methods
            with model_context(model.settings):
                if (
                    self._response_cache is not None
                    and model.settings.cache_enabled is not False
                ):
                    cache_value = await self._response_cache.lookup(cache_key)
                    if cache_value != "":
                        prediction = InferenceResponse.parse_raw(cache_value)
                    else:
                        prediction = await model.predict(payload)
                        # Insert the fresh prediction, ignoring any cache insertion errors
                        await self._response_cache.insert(cache_key, prediction.json())
                else:
                    prediction = await model.predict(payload)

        # Ensure ID matches
        prediction.id = payload.id
@@ -111,3 +129,9 @@
        self._ModelInferRequestSuccess.labels(model=name, version=version).inc()

        return prediction

    def _create_response_cache(self) -> ResponseCache:
        return LocalCache(size=self._settings.cache_size)

    def _get_response_cache(self) -> Optional[ResponseCache]:
        return self._response_cache
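To make the caching flow in infer() concrete: the cache key is the JSON-serialised request payload, and a hit is deserialised back into an InferenceResponse. Below is a standalone sketch of that round trip; the tensor values and the "sum-model" response are made up, standing in for a real model.predict() call:

import asyncio

from mlserver.cache.local import LocalCache
from mlserver.types import (
    InferenceRequest,
    InferenceResponse,
    RequestInput,
    ResponseOutput,
)


async def main():
    cache = LocalCache(size=100)

    payload = InferenceRequest(
        inputs=[
            RequestInput(name="input-0", shape=[1, 3], datatype="INT32", data=[1, 2, 3])
        ]
    )
    cache_key = payload.json()  # same key derivation as DataPlane.infer

    if await cache.lookup(cache_key) == "":
        # Cache miss: this is where DataPlane would call model.predict()
        prediction = InferenceResponse(
            model_name="sum-model",
            outputs=[ResponseOutput(name="total", shape=[1], datatype="INT32", data=[6])],
        )
        await cache.insert(cache_key, prediction.json())

    # Second lookup is a hit and round-trips through JSON
    cached = InferenceResponse.parse_raw(await cache.lookup(cache_key))
    assert cached.model_name == "sum-model"


asyncio.run(main())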
11 changes: 11 additions & 0 deletions mlserver/settings.py
@@ -241,6 +241,12 @@ class Config:
    _custom_metrics_server_settings: Optional[dict] = None
    _custom_grpc_server_settings: Optional[dict] = None

    cache_enabled: bool = False
    """Enable caching of model predictions."""

    cache_size: int = 100
    """Cache size to be used if caching is enabled."""


class ModelParameters(BaseSettings):
"""
@@ -393,3 +399,8 @@ def version(self) -> Optional[str]:
    # However, it's also possible to override them manually.
    parameters: Optional[ModelParameters] = None
    """Extra parameters for each instance of this model."""

    cache_enabled: bool = False
    """Enable caching for a specific model. This parameter can be used to
    disable caching for a specific model when server-level caching is enabled.
    If server-level caching is disabled, this value has no effect."""
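For example, assuming server-level caching is enabled, an individual model could opt out through its own model-settings.json; the name and implementation values below are placeholders:

{
  "name": "my-model",
  "implementation": "models.MyModel",
  "cache_enabled": false
}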
Empty file added tests/cache/__init__.py
11 changes: 11 additions & 0 deletions tests/cache/conftest.py
@@ -0,0 +1,11 @@
import pytest

from mlserver.cache.local import LocalCache
from mlserver.cache import ResponseCache

CACHE_SIZE = 10


@pytest.fixture
def local_cache() -> ResponseCache:
    return LocalCache(size=CACHE_SIZE)
39 changes: 39 additions & 0 deletions tests/cache/test_local.py
@@ -0,0 +1,39 @@
from string import ascii_lowercase

from .conftest import CACHE_SIZE


async def test_local_cache_lookup(local_cache):
    assert await local_cache.size() == 0
    assert await local_cache.lookup("unknown key") == ""
    assert await local_cache.size() == 0


async def test_local_cache_insert(local_cache):
    assert await local_cache.size() == 0

    await local_cache.insert("key", "value")
    assert await local_cache.lookup("key") == "value"

    assert await local_cache.size() == 1

    await local_cache.insert("new key", "new value")
    assert await local_cache.lookup("key") == "value"
    assert await local_cache.lookup("new key") == "new value"

    assert await local_cache.size() == 2


async def test_local_cache_rotate(local_cache):
    # Insert the letters of the alphabet in a loop
    for key, symbol in enumerate(ascii_lowercase):
        await local_cache.insert(str(key), symbol)

        if key < CACHE_SIZE:
            assert await local_cache.size() == key + 1
            assert await local_cache.lookup(str(key)) == symbol
        else:
            assert await local_cache.size() == CACHE_SIZE
            assert await local_cache.lookup(str(key)) == symbol
            assert await local_cache.lookup(str(key - CACHE_SIZE)) == ""
21 changes: 21 additions & 0 deletions tests/handlers/conftest.py
@@ -1,10 +1,31 @@
import pytest
import os

from mlserver.settings import Settings
from mlserver.handlers import DataPlane
from mlserver.handlers.custom import CustomHandler
from mlserver.registry import MultiModelRegistry
from prometheus_client.registry import CollectorRegistry

from ..fixtures import SumModel
from ..conftest import TESTDATA_PATH


@pytest.fixture
def custom_handler(sum_model: SumModel) -> CustomHandler:
    return CustomHandler(rest_path="/my-custom-endpoint")


@pytest.fixture
def cached_settings() -> Settings:
    settings_path = os.path.join(TESTDATA_PATH, "settings-cache.json")
    return Settings.parse_file(settings_path)


@pytest.fixture
def cached_data_plane(
    cached_settings: Settings,
    model_registry: MultiModelRegistry,
    prometheus_registry: CollectorRegistry,
) -> DataPlane:
    return DataPlane(settings=cached_settings, model_registry=model_registry)
37 changes: 36 additions & 1 deletion tests/handlers/test_dataplane.py
@@ -3,7 +3,7 @@

from mlserver.errors import ModelNotReady
from mlserver.settings import ModelSettings, ModelParameters
from mlserver.types import MetadataTensor
from mlserver.types import MetadataTensor, InferenceResponse

from ..fixtures import SumModel

@@ -114,3 +114,38 @@ async def test_infer_generates_uuid(data_plane, sum_model, inference_request):

    assert prediction.id is not None
    assert prediction.id == str(uuid.UUID(prediction.id))


async def test_infer_response_cache(cached_data_plane, sum_model, inference_request):
    cache_key = inference_request.json()
    payload = inference_request.copy(deep=True)
    prediction = await cached_data_plane.infer(
        payload=payload, name=sum_model.name, version=sum_model.version
    )

    response_cache = cached_data_plane._get_response_cache()
    assert response_cache is not None
    assert await response_cache.size() == 1

    cache_value = await response_cache.lookup(cache_key)
    cached_response = InferenceResponse.parse_raw(cache_value)
    assert cached_response.model_name == prediction.model_name
    assert cached_response.model_version == prediction.model_version
    assert cached_response.Config == prediction.Config
    assert cached_response.outputs == prediction.outputs

    prediction = await cached_data_plane.infer(
        payload=inference_request, name=sum_model.name, version=sum_model.version
    )

    # The second request is served from the existing cache entry
    assert await response_cache.size() == 1
    assert cached_response.model_name == prediction.model_name
    assert cached_response.model_version == prediction.model_version
    assert cached_response.Config == prediction.Config
    assert cached_response.outputs == prediction.outputs


async def test_response_cache_disabled(data_plane):
    response_cache = data_plane._get_response_cache()
    assert response_cache is None
4 changes: 3 additions & 1 deletion tests/testdata/model-settings.json
@@ -28,5 +28,7 @@

"parameters": {
"version": "v1.2.3"
}
},

"cache_enabled": true
}
9 changes: 9 additions & 0 deletions tests/testdata/settings-cache.json
@@ -0,0 +1,9 @@
{
  "debug": true,
  "host": "127.0.0.1",
  "parallel_workers": 2,
  "cors_settings": {
    "allow_origins": ["*"]
  },
  "cache_enabled": true
}
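As a quick sanity check, these settings can be loaded the same way the cached_settings fixture above does; a sketch (cache_size falls back to its default of 100 since the file does not set it):

from mlserver.settings import Settings

settings = Settings.parse_file("tests/testdata/settings-cache.json")
assert settings.cache_enabled
assert settings.cache_size == 100  # default from Settings, not set in the file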