diff --git a/clients/python/README.md b/clients/python/README.md index 0c174fc02..c1deb6720 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -146,18 +146,23 @@ for version in registry.get_model_versions("my-model"): ... ``` -To customize sorting order or query limits you can also use +You can also use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order ```py -latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20) +latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending() for version in latest_updates: ... ``` -You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order. +By default, all queries will be `ascending`, but this method is also available for explicitness. -> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1 -> you will still get all the models, with one query each. +> Note: You can also set the `page_size()` that you want the Pager to use when invoking the Model Registry backend. +> When using it as an iterator, it will automatically manage pages for you. + +#### Implementation notes + +The pager will manage pages for you in order to prevent infinite looping. +Currently, the Model Registry backend treats model lists as a circular buffer, and **will not end iteration** for you. ## Development diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 0cf2a7a74..5484a7953 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections.abc import Mapping from pathlib import Path from typing import Any, TypeVar, Union, get_args from warnings import warn @@ -138,7 +139,7 @@ def register_model( author: str | None = None, owner: str | None = None, description: str | None = None, - metadata: dict[str, SupportedTypes] | None = None, + metadata: Mapping[str, SupportedTypes] | None = None, ) -> RegisteredModel: """Register a model. diff --git a/clients/python/src/model_registry/types/base.py b/clients/python/src/model_registry/types/base.py index bf1b8dd9b..df1166e87 100644 --- a/clients/python/src/model_registry/types/base.py +++ b/clients/python/src/model_registry/types/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Union, get_args from pydantic import BaseModel, ConfigDict @@ -35,7 +35,7 @@ class BaseResourceModel(BaseModel, ABC): external_id: str | None = None create_time_since_epoch: str | None = None last_update_time_since_epoch: str | None = None - custom_properties: dict[str, SupportedTypes] | None = None + custom_properties: Mapping[str, SupportedTypes] | None = None @abstractmethod def create(self, **kwargs) -> Any: diff --git a/clients/python/src/model_registry/types/pager.py b/clients/python/src/model_registry/types/pager.py index 11dae8ab5..a5e2e5e9b 100644 --- a/clients/python/src/model_registry/types/pager.py +++ b/clients/python/src/model_registry/types/pager.py @@ -73,12 +73,15 @@ def order_by_id(self) -> Pager[T]: self.options.order_by = OrderByField.ID return self.restart() - def limit(self, limit: int) -> Pager[T]: - """Limit the number of items to return. + def page_size(self, n: int) -> Pager[T]: + """Set the page size for each request. This resets the pager. """ - self.options.limit = limit + if n < 1: + msg = f"Page size must be at least 1, got {n}" + raise ValueError(msg) + self.options.limit = n return self.restart() def ascending(self) -> Pager[T]: diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 002126da5..bb958cf50 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -72,10 +72,10 @@ def test_register_existing_version(client: ModelRegistry): "model_format_version": "test_version", "version": "1.0.0", } - client.register_model(**params) + client.register_model(**params, metadata=None) with pytest.raises(StoreError): - client.register_model(**params) + client.register_model(**params, metadata=None) @pytest.mark.e2e @@ -124,8 +124,10 @@ async def test_update_logical_model_with_labels(client: ModelRegistry): ) assert rm.id mv = client.get_model_version(name, version) + assert mv assert mv.id ma = client.get_model_artifact(name, version) + assert ma assert ma.id rm_labels = { @@ -149,9 +151,15 @@ async def test_update_logical_model_with_labels(client: ModelRegistry): ma.custom_properties = ma_labels client.update(ma) - assert client.get_registered_model(name).custom_properties == rm_labels - assert client.get_model_version(name, version).custom_properties == mv_labels - assert client.get_model_artifact(name, version).custom_properties == ma_labels + rm = client.get_registered_model(name) + assert rm + assert rm.custom_properties == rm_labels + mv = client.get_model_version(name, version) + assert mv + assert mv.custom_properties == mv_labels + ma = client.get_model_artifact(name, version) + assert ma + assert ma.custom_properties == ma_labels @pytest.mark.e2e @@ -232,7 +240,7 @@ def test_get_registered_models(client: ModelRegistry): version="1.0.0", ) - rm_iter = client.get_registered_models().limit(10) + rm_iter = client.get_registered_models().page_size(10) i = 0 prev_tok = None changes = 0 @@ -315,6 +323,17 @@ def test_get_registered_models_order_by(client: ModelRegistry): assert i == models + # or if descending is explicitly set + i = 0 + for rm, by_update in zip( + rms, + client.get_registered_models().order_by_update_time().descending(), + ): + assert rm.id == by_update.id + i += 1 + + assert i == models + @pytest.mark.e2e def test_get_registered_models_and_reset(client: ModelRegistry): @@ -330,7 +349,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry): version="1.0.0", ) - rm_iter = client.get_registered_models().limit(model_count - 1) + rm_iter = client.get_registered_models().page_size(model_count - 1) models = [] for rm in islice(rm_iter, page): models.append(rm) @@ -355,7 +374,7 @@ def test_get_model_versions(client: ModelRegistry): version=v, ) - mv_iter = client.get_model_versions(name).limit(10) + mv_iter = client.get_model_versions(name).page_size(10) i = 0 prev_tok = None changes = 0 @@ -430,6 +449,18 @@ def test_get_model_versions_order_by(client: ModelRegistry): assert mv.id == by_update.id i += 1 + assert i == models + + i = 0 + for mv, by_update in zip( + mvs, + client.get_model_versions(name).order_by_update_time().descending(), + ): + assert mv.id == by_update.id + i += 1 + + assert i == models + @pytest.mark.e2e def test_get_model_versions_and_reset(client: ModelRegistry): @@ -447,7 +478,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry): version=v, ) - mv_iter = client.get_model_versions(name).limit(model_count - 1) + mv_iter = client.get_model_versions(name).page_size(model_count - 1) models = [] for rm in islice(mv_iter, page): models.append(rm) diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index 75c52f36f..8796d8013 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -76,6 +76,7 @@ async def test_get_registered_model_by_external_id( client: ModelRegistryAPIClient, registered_model: RegisteredModel, ): + assert registered_model.external_id assert ( rm := await client.get_registered_model_by_params( external_id=registered_model.external_id @@ -99,7 +100,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient): models = 6 for i in range(models): await client.upsert_registered_model(RegisteredModel(name=f"rm{i}")) - pager = Pager(client.get_registered_models).limit(5) + pager = Pager(client.get_registered_models).page_size(5) total = 0 async for _ in pager: total += 1 @@ -205,7 +206,7 @@ async def test_page_through_model_versions( ) pager = Pager( lambda o: client.get_model_versions(str(registered_model.id), o) - ).limit(5) + ).page_size(5) total = 0 async for _ in pager: total += 1 @@ -227,7 +228,8 @@ async def test_insert_model_artifact( "service_account_name": "test service account", } ma = await client.upsert_model_artifact( - ModelArtifact(**props), str(model_version.id) + ModelArtifact(**props), # type: ignore + str(model_version.id), ) assert ma.id assert ma.name == "test model" @@ -340,7 +342,7 @@ async def test_page_through_model_version_artifacts( await client.create_model_version_artifact(art, str(model_version.id)) pager = Pager( lambda o: client.get_model_version_artifacts(str(model_version.id), o) - ).limit(5) + ).page_size(5) total = 0 async for _ in pager: total += 1