diff --git a/mlserver/cache/cache.py b/mlserver/cache/cache.py
index 81daf1f9b..e55fb03d3 100644
--- a/mlserver/cache/cache.py
+++ b/mlserver/cache/cache.py
@@ -16,3 +16,12 @@ async def lookup(self, key: str) -> str:
         **This method should be overriden to implement your custom cache logic.**
         """
         raise NotImplementedError("lookup() method not implemented")
+
+    async def size(self) -> int:
+        """
+        Method responsible for returning the size of the cache.
+
+
+        **This method should be overridden to implement your custom cache logic.**
+        """
+        raise NotImplementedError("size() method not implemented")
diff --git a/mlserver/cache/local/local.py b/mlserver/cache/local/local.py
index 021af0d5c..46a3a2a2b 100644
--- a/mlserver/cache/local/local.py
+++ b/mlserver/cache/local/local.py
@@ -9,7 +9,7 @@ def __init__(self, size=100):
 
     async def insert(self, key: str, value: str):
         self.cache[key] = value
-        if len(self.cache) > self.size_limit:
+        if await self.size() > self.size_limit:
             # The cache removes the first entry if it overflows (i.e. in FIFO order)
             self.cache.popitem(last=False)
         return None
@@ -19,3 +19,6 @@ async def lookup(self, key: str) -> str:
             return self.cache[key]
         else:
             return ""
+
+    async def size(self) -> int:
+        return len(self.cache)
diff --git a/mlserver/handlers/dataplane.py b/mlserver/handlers/dataplane.py
index 250f2874a..1270c0d1f 100644
--- a/mlserver/handlers/dataplane.py
+++ b/mlserver/handlers/dataplane.py
@@ -29,7 +29,7 @@ class DataPlane:
     def __init__(self, settings: Settings, model_registry: MultiModelRegistry):
         self._settings = settings
         self._model_registry = model_registry
-
+        self._response_cache = None
         if settings.cache_enabled:
             self._response_cache = self._create_response_cache()
         self._inference_middleware = InferenceMiddlewares(
@@ -132,3 +132,6 @@ async def infer(
 
     def _create_response_cache(self) -> ResponseCache:
         return LocalCache(size=self._settings.cache_size)
+
+    def _get_response_cache(self) -> Optional[ResponseCache]:
+        return self._response_cache
diff --git a/tests/handlers/test_dataplane.py b/tests/handlers/test_dataplane.py
index 705c672e7..caa22414a 100644
--- a/tests/handlers/test_dataplane.py
+++ b/tests/handlers/test_dataplane.py
@@ -3,8 +3,8 @@
 
 from mlserver.errors import ModelNotReady
 from mlserver.settings import ModelSettings, ModelParameters
-from mlserver.types import MetadataTensor
-
+from mlserver.types import MetadataTensor, InferenceResponse
+from mlserver.handlers import DataPlane
 from ..fixtures import SumModel
 
 
@@ -114,3 +114,18 @@ async def test_infer_generates_uuid(data_plane, sum_model, inference_request):
 
     assert prediction.id is not None
     assert prediction.id == str(uuid.UUID(prediction.id))
+
+
+async def test_infer_response_cache(data_plane, sum_model, inference_request):
+    prediction = await data_plane.infer(
+        payload=inference_request, name=sum_model.name, version=sum_model.version
+    )
+    response_cache = data_plane._get_response_cache()
+
+    assert response_cache is not None
+    assert await response_cache.size() == 1
+    assert len(
+        InferenceResponse.parse_raw(
+            await response_cache.lookup(inference_request.json())
+        ).outputs
+    ) == len(prediction.outputs)
diff --git a/tests/testdata/model-settings.json b/tests/testdata/model-settings.json
index a6d08e810..a73bd6f07 100644
--- a/tests/testdata/model-settings.json
+++ b/tests/testdata/model-settings.json
@@ -28,5 +28,7 @@
 
   "parameters": {
     "version": "v1.2.3"
-  }
+  },
+
+  "cache_enabled": true
 }
diff --git a/tests/testdata/settings.json b/tests/testdata/settings.json
index 27f7a5ab8..5879855df 100644
--- a/tests/testdata/settings.json
+++ b/tests/testdata/settings.json
@@ -4,5 +4,6 @@
   "parallel_workers": 2,
   "cors_settings": {
     "allow_origins": ["*"]
-  }
+  },
+  "cache_enabled": true
 }
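
For reference, the `LocalCache` exercised by this patch is a FIFO cache backed by an `OrderedDict`: `insert()` evicts the oldest entry once `size()` exceeds the configured limit, `lookup()` returns an empty string on a miss, and all three methods are coroutines. The snippet below is a minimal standalone sketch of that behaviour, not part of the patch; it assumes `LocalCache` can be imported from `mlserver.cache.local` (the import path is not shown in the hunks above).

```python
import asyncio

# Assumed import path; the package layout is not shown in this diff.
from mlserver.cache.local import LocalCache


async def main():
    # Cache capped at two entries; inserts beyond that evict in FIFO order.
    cache = LocalCache(size=2)

    await cache.insert("a", "1")
    await cache.insert("b", "2")
    await cache.insert("c", "3")  # overflows: "a" (the oldest key) is dropped

    assert await cache.size() == 2
    assert await cache.lookup("a") == ""   # miss returns an empty string
    assert await cache.lookup("c") == "3"  # most recent entry is still cached


asyncio.run(main())
```

The base `ResponseCache` only raises `NotImplementedError`, so an alternative backend would override the same three coroutines and could be returned from `DataPlane._create_response_cache()` without touching its callers.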