Update cache tests
Signed-off-by: Sachin Varghese <[email protected]>
SachinVarghese committed Oct 22, 2023
1 parent ffc5ae1 commit 31caf16
Showing 8 changed files with 113 additions and 13 deletions.
3 changes: 2 additions & 1 deletion mlserver/cache/local/local.py
@@ -9,7 +9,8 @@ def __init__(self, size=100):

     async def insert(self, key: str, value: str):
         self.cache[key] = value
-        if self.size() > self.size_limit:
+        cache_size = await self.size()
+        if cache_size > self.size_limit:
             # The cache removes the first entry if it overflows (i.e. in FIFO order)
             self.cache.popitem(last=False)
         return None
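The fix here is that size() is a coroutine, so its result must be awaited before the comparison; the old code compared a coroutine object against size_limit. For context, a minimal sketch of how the patched LocalCache plausibly fits together. The OrderedDict store and popitem call appear in this hunk; lookup and size are inferred from the new tests below and may differ from the actual source file:

from collections import OrderedDict


class LocalCache:
    def __init__(self, size=100):
        # Bounded FIFO store: the oldest entry is evicted first
        self.cache = OrderedDict()
        self.size_limit = size

    async def insert(self, key: str, value: str):
        self.cache[key] = value
        cache_size = await self.size()
        if cache_size > self.size_limit:
            # The cache removes the first entry if it overflows (i.e. in FIFO order)
            self.cache.popitem(last=False)
        return None

    async def lookup(self, key: str) -> str:
        # Missing keys resolve to an empty string, as asserted in the tests below
        return self.cache.get(key, "")

    async def size(self) -> int:
        return len(self.cache)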
Empty file added tests/cache/__init__.py
11 changes: 11 additions & 0 deletions tests/cache/conftest.py
@@ -0,0 +1,11 @@
import pytest

from mlserver.cache.local import LocalCache
from mlserver.cache import ResponseCache

CACHE_SIZE = 10


@pytest.fixture
def local_cache() -> ResponseCache:
    return LocalCache(size=CACHE_SIZE)
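The fixture is annotated with the ResponseCache interface rather than the concrete LocalCache, so the tests stay implementation-agnostic. A rough sketch of what that interface presumably declares, inferred from how the tests call it (the real mlserver.cache.ResponseCache may differ):

from abc import ABC, abstractmethod


class ResponseCache(ABC):
    @abstractmethod
    async def insert(self, key: str, value: str):
        ...

    @abstractmethod
    async def lookup(self, key: str) -> str:
        ...

    @abstractmethod
    async def size(self) -> int:
        ...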
39 changes: 39 additions & 0 deletions tests/cache/test_local.py
@@ -0,0 +1,39 @@
from string import ascii_lowercase

from .conftest import CACHE_SIZE


async def test_local_cache_lookup(local_cache):
    assert await local_cache.size() == 0
    assert await local_cache.lookup("unknown key") == ""
    assert await local_cache.size() == 0


async def test_local_cache_insert(local_cache):
    assert await local_cache.size() == 0

    await local_cache.insert("key", "value")
    assert await local_cache.lookup("key") == "value"

    assert await local_cache.size() == 1

    await local_cache.insert("new key", "new value")
    assert await local_cache.lookup("key") == "value"
    assert await local_cache.lookup("new key") == "new value"

    assert await local_cache.size() == 2


async def test_local_cache_rotate(local_cache):
    # Insert the alphabet one letter at a time
    for key, symbol in enumerate(ascii_lowercase):
        await local_cache.insert(str(key), symbol)

        if key < CACHE_SIZE:
            assert await local_cache.size() == key + 1
            assert await local_cache.lookup(str(key)) == symbol
        else:
            assert await local_cache.size() == CACHE_SIZE
            assert await local_cache.lookup(str(key)) == symbol
            # The entry inserted CACHE_SIZE steps earlier has been evicted
            assert await local_cache.lookup(str(key - CACHE_SIZE)) == ""
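Note that these tests are bare async def coroutines with no explicit event-loop handling, which suggests pytest-asyncio (or an equivalent plugin) is enabled project-wide. A hypothetical configuration excerpt, assuming pytest-asyncio's auto mode; MLServer's actual pytest setup may declare this differently:

# setup.cfg (illustrative)
[tool:pytest]
asyncio_mode = auto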
21 changes: 21 additions & 0 deletions tests/handlers/conftest.py
@@ -1,10 +1,31 @@
import pytest
import os

from mlserver.settings import Settings
from mlserver.handlers import DataPlane
from mlserver.handlers.custom import CustomHandler
from mlserver.registry import MultiModelRegistry
from prometheus_client.registry import CollectorRegistry

from ..fixtures import SumModel
from ..conftest import TESTDATA_PATH


@pytest.fixture
def custom_handler(sum_model: SumModel) -> CustomHandler:
    return CustomHandler(rest_path="/my-custom-endpoint")


@pytest.fixture
def cached_settings() -> Settings:
    settings_path = os.path.join(TESTDATA_PATH, "settings-cache.json")
    return Settings.parse_file(settings_path)


@pytest.fixture
def cached_data_plane(
    cached_settings: Settings,
    model_registry: MultiModelRegistry,
    prometheus_registry: CollectorRegistry,
) -> DataPlane:
    return DataPlane(settings=cached_settings, model_registry=model_registry)
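The cached_settings fixture loads a dedicated settings file with cache_enabled turned on, so the cached_data_plane it feeds gets a live response cache while the default data_plane fixture stays cache-free. An illustrative check of what this wiring yields (fixture names as defined above; the assertion mirrors the tests below):

def test_cached_data_plane_has_cache(cached_data_plane: DataPlane):
    # With cache_enabled set in settings-cache.json, a response cache is created
    assert cached_data_plane._get_response_cache() is not None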
40 changes: 30 additions & 10 deletions tests/handlers/test_dataplane.py
@@ -116,16 +116,36 @@ async def test_infer_generates_uuid(data_plane, sum_model, inference_request):
     assert prediction.id == str(uuid.UUID(prediction.id))


-async def test_infer_response_cache(data_plane, sum_model, inference_request):
-    prediction = await data_plane.infer(
-        payload=inference_request, name=sum_model.name, version=sum_model.version
+async def test_infer_response_cache(cached_data_plane, sum_model, inference_request):
+    cache_key = inference_request.json()
+    payload = inference_request.copy(deep=True)
+    prediction = await cached_data_plane.infer(
+        payload=payload, name=sum_model.name, version=sum_model.version
     )
-    response_cache = data_plane._get_response_cache()
+
+    response_cache = cached_data_plane._get_response_cache()
     assert response_cache is not None
-    assert response_cache.size() == 1
-    assert len(
-        InferenceResponse.parse_raw(
-            response_cache.lookup(inference_request.json())
-        ).outputs
-    ) == len(prediction.outputs)
+    assert await response_cache.size() == 1
+
+    cache_value = await response_cache.lookup(cache_key)
+    cached_response = InferenceResponse.parse_raw(cache_value)
+    assert cached_response.model_name == prediction.model_name
+    assert cached_response.model_version == prediction.model_version
+    assert cached_response.Config == prediction.Config
+    assert cached_response.outputs == prediction.outputs
+
+    prediction = await cached_data_plane.infer(
+        payload=inference_request, name=sum_model.name, version=sum_model.version
+    )
+
+    # Using existing cache value
+    assert await response_cache.size() == 1
+    assert cached_response.model_name == prediction.model_name
+    assert cached_response.model_version == prediction.model_version
+    assert cached_response.Config == prediction.Config
+    assert cached_response.outputs == prediction.outputs


 async def test_response_cache_disabled(data_plane):
     response_cache = data_plane._get_response_cache()
     assert response_cache is None
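The test snapshots cache_key before calling infer and passes a deep copy of the request, presumably because infer can mutate the payload in place (for example by generating an id), which would change its JSON serialization. Rough pseudocode of the lookup-then-insert flow the test exercises, pieced together from the assertions; the real DataPlane.infer additionally resolves models and handles middleware and errors, and run_model below is a hypothetical stand-in for the actual model call:

from mlserver.types import InferenceResponse


async def infer_with_cache(data_plane, payload, name, version):
    cache = data_plane._get_response_cache()
    cache_key = payload.json()

    if cache is not None:
        cached = await cache.lookup(cache_key)
        if cached != "":
            # Cache hit: skip the model entirely
            return InferenceResponse.parse_raw(cached)

    response = await run_model(payload, name, version)  # hypothetical helper

    if cache is not None:
        await cache.insert(cache_key, response.json())
    return response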
9 changes: 9 additions & 0 deletions tests/testdata/settings-cache.json
@@ -0,0 +1,9 @@
{
  "debug": true,
  "host": "127.0.0.1",
  "parallel_workers": 2,
  "cors_settings": {
    "allow_origins": ["*"]
  },
  "cache_enabled": true
}
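This file mirrors the shared tests/testdata/settings.json but keeps cache_enabled set, so only tests that opt in through the cached_settings fixture run against a cache-enabled server. Loading it standalone, for illustration:

from mlserver.settings import Settings

settings = Settings.parse_file("tests/testdata/settings-cache.json")
assert settings.cache_enabled is True
assert settings.parallel_workers == 2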
3 changes: 1 addition & 2 deletions tests/testdata/settings.json
@@ -4,6 +4,5 @@
   "parallel_workers": 2,
   "cors_settings": {
     "allow_origins": ["*"]
-  },
-  "cache_enabled": true
+  }
 }
