From 57c56f170dff7ed5e4a9a82af12f46f2681d8cc4 Mon Sep 17 00:00:00 2001 From: Isabella Basso do Amaral Date: Fri, 5 Jul 2024 12:47:53 -0300 Subject: [PATCH] py: expose listing methods Signed-off-by: Isabella Basso do Amaral --- clients/python/README.md | 25 +++ clients/python/docs/reference.md | 42 +++- clients/python/src/model_registry/_client.py | 44 ++++- clients/python/src/model_registry/core.py | 48 ++++- .../src/model_registry/types/__init__.py | 3 + .../src/model_registry/types/options.py | 4 + .../python/src/model_registry/types/pager.py | 179 ++++++++++++++++++ clients/python/tests/test_client.py | 113 +++++++++++ clients/python/tests/test_core.py | 66 ++++++- 9 files changed, 508 insertions(+), 16 deletions(-) create mode 100644 clients/python/src/model_registry/types/pager.py diff --git a/clients/python/README.md b/clients/python/README.md index dff15f273..2b1e8f823 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -98,6 +98,31 @@ There are caveats to be noted when using this method: ) ``` +### Listing models + +To list models you can use +```py +for model in registry.get_registered_models(): + ... + +# and versions associated with a model +for version in registry.get_model_versions("my-model"): + ... +``` + +To customize sorting order or query limits you can also use + +```py +latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20) +for version in latest_updates: + ... +``` + +You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order. + +> Note that the `limit()` method only sets the page size of each query, not the number of items the loop yields -- even if your limit is 1 +> you will still get all the models, fetched with one query per model. + ## Development Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions. 
diff --git a/clients/python/docs/reference.md b/clients/python/docs/reference.md index 631c87e1f..8ef28ab9f 100644 --- a/clients/python/docs/reference.md +++ b/clients/python/docs/reference.md @@ -176,14 +176,14 @@ models = await mr_client.get_registered_models() versions = await mr_client.get_model_versions("registered_model_id") -# We can get a list of all model artifacts +# We can get a list of the first 20 model artifacts all_model_artifacts = await mr_client.get_model_artifacts() ``` -To limit or order the query, provide a {py:class}`model_registry.types.ListOptions` object. +To limit or sort the query by another parameter, provide a {py:class}`model_registry.types.ListOptions` object. ```py -from model_registry import ListOptions +from model_registry.types import ListOptions options = ListOptions(limit=50) @@ -195,6 +195,42 @@ options = ListOptions.order_by_creation_time(is_asc=False) last_50_models = await mr_client.get_registered_models(options) ``` +You can also use the high-level {py:class}`model_registry.types.Pager` to get an iterator. + +```py +from model_registry.types import Pager + +models = Pager(mr_client.get_registered_models) + +async for model in models: + ... +``` + +Note that the iterator currently only works with methods that take a `ListOptions` argument, so if you want to use a +method that needs additional arguments, you'll need to provide a partial application like in the example below. + +```py +model_version_artifacts = Pager(lambda o: mr_client.get_model_version_artifacts(mv.id, o)) +``` + +> ⚠️ Also note that a [`partial`](https://docs.python.org/3/library/functools.html#functools.partial) definition won't work as the `options` argument is optional, and thus has to be overridden as a positional argument. + +The iterator provides methods for setting up the {py:class}`model_registry.types.ListOptions` that will be used in each +call. 
+ +```py +reverse_model_version_artifacts = model_version_artifacts.order_by_creation_time().descending().limit(100) +``` + +You can also get each page separately and iterate yourself: + +```py +page = await reverse_model_version_artifacts.next_page() +``` + +> Note: the iterator will be automagically sync or async depending on the paging function passed in for initialization. + + ```{eval-rst} .. automodule:: model_registry.core ``` diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 2d263e520..25393d0d7 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -9,7 +9,14 @@ from .core import ModelRegistryAPIClient from .exceptions import StoreError -from .types import ModelArtifact, ModelVersion, RegisteredModel, SupportedTypes +from .types import ( + ListOptions, + ModelArtifact, + ModelVersion, + Pager, + RegisteredModel, + SupportedTypes, +) class ModelRegistry: @@ -327,3 +334,38 @@ def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None: raise StoreError(msg) assert mv.id return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id)) + + def get_registered_models(self) -> Pager[RegisteredModel]: + """Get a pager for registered models. + + Returns: + Iterable pager for registered models. + """ + + def rm_list(options: ListOptions) -> list[RegisteredModel]: + return self.async_runner(self._api.get_registered_models(options)) + + return Pager[RegisteredModel](rm_list) + + def get_model_versions(self, name: str) -> Pager[ModelVersion]: + """Get a pager for model versions. + + Args: + name: Name of the model. + + Returns: + Iterable pager for model versions. + + Raises: + StoreException: If the model does not exist. 
+ """ + if not (rm := self.get_registered_model(name)): + msg = f"Model {name} does not exist" + raise StoreError(msg) + + def rm_versions(options: ListOptions) -> list[ModelVersion]: + # type checkers can't restrict the type inside a nested function: https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions + assert rm.id + return self.async_runner(self._api.get_model_versions(rm.id, options)) + + return Pager[ModelVersion](rm_versions) diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index d31606a91..4586668fa 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -172,11 +172,14 @@ async def get_registered_models( Registered models. """ async with self.get_client() as client: - rms = await client.get_registered_models( + rm_list = await client.get_registered_models( **(options or ListOptions()).as_options() ) - return [RegisteredModel.from_basemodel(rm) for rm in rms.items or []] + if options: + options.next_page_token = rm_list.next_page_token + + return [RegisteredModel.from_basemodel(rm) for rm in rm_list.items or []] async def upsert_model_version( self, model_version: ModelVersion, registered_model_id: str @@ -236,11 +239,14 @@ async def get_model_versions( Model versions. 
""" async with self.get_client() as client: - mvs = await client.get_registered_model_versions( + mv_list = await client.get_registered_model_versions( registered_model_id, **(options or ListOptions()).as_options() ) - return [ModelVersion.from_basemodel(mv) for mv in mvs.items or []] + if options: + options.next_page_token = mv_list.next_page_token + + return [ModelVersion.from_basemodel(mv) for mv in mv_list.items or []] @overload async def get_model_version_by_params( @@ -415,17 +421,43 @@ async def get_model_artifacts( """ async with self.get_client() as client: if model_version_id: - arts = await client.get_model_version_artifacts( + art_list = await client.get_model_version_artifacts( model_version_id, **(options or ListOptions()).as_options() ) + if options: + options.next_page_token = art_list.next_page_token models = [] - for art in arts.items or []: + for art in art_list.items or []: converted = Artifact.validate_artifact(art) if isinstance(converted, ModelArtifact): models.append(converted) return models - mas = await client.get_model_artifacts( + ma_list = await client.get_model_artifacts( **(options or ListOptions()).as_options() ) - return [ModelArtifact.from_basemodel(ma) for ma in mas.items or []] + if options: + options.next_page_token = ma_list.next_page_token + return [ModelArtifact.from_basemodel(ma) for ma in ma_list.items or []] + + async def get_model_version_artifacts( + self, + model_version_id: str, + options: ListOptions | None = None, + ) -> list[Artifact]: + """Fetches model artifacts. + + Args: + model_version_id: ID of the associated model version. + options: Options for listing model artifacts. + + Returns: + Model artifacts. 
+ """ + async with self.get_client() as client: + art_list = await client.get_model_version_artifacts( + model_version_id, **(options or ListOptions()).as_options() + ) + if options: + options.next_page_token = art_list.next_page_token + return [Artifact.validate_artifact(art) for art in art_list.items or []] diff --git a/clients/python/src/model_registry/types/__init__.py b/clients/python/src/model_registry/types/__init__.py index 7d0291971..59329a563 100644 --- a/clients/python/src/model_registry/types/__init__.py +++ b/clients/python/src/model_registry/types/__init__.py @@ -12,6 +12,7 @@ RegisteredModelState, ) from .options import ListOptions +from .pager import Pager __all__ = [ # Artifacts @@ -27,4 +28,6 @@ "SupportedTypes", # Options "ListOptions", + # Pager + "Pager", ] diff --git a/clients/python/src/model_registry/types/options.py b/clients/python/src/model_registry/types/options.py index 3be74ef48..13f8b88a5 100644 --- a/clients/python/src/model_registry/types/options.py +++ b/clients/python/src/model_registry/types/options.py @@ -19,11 +19,13 @@ class ListOptions: limit: Maximum number of objects to return. order_by: Field to order by. is_asc: Whether to order in ascending order. Defaults to True. + next_page_token: Token to use to retrieve next page of results. 
""" limit: int | None = None order_by: OrderByField | None = None is_asc: bool = True + next_page_token: str | None = None @classmethod def order_by_creation_time(cls, **kwargs) -> ListOptions: @@ -49,4 +51,6 @@ def as_options(self) -> dict[str, Any]: options["order_by"] = self.order_by if self.is_asc is not None: options["sort_order"] = SortOrder.ASC if self.is_asc else SortOrder.DESC + if self.next_page_token is not None: + options["next_page_token"] = self.next_page_token return options diff --git a/clients/python/src/model_registry/types/pager.py b/clients/python/src/model_registry/types/pager.py new file mode 100644 index 000000000..a789f248d --- /dev/null +++ b/clients/python/src/model_registry/types/pager.py @@ -0,0 +1,179 @@ +"""Pager for iterating over items.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator, Awaitable, Iterator +from dataclasses import dataclass, field +from typing import Callable, Generic, TypeVar, cast + +from .base import BaseModel +from .options import ListOptions, OrderByField + +T = TypeVar("T", bound=BaseModel) + + +@dataclass +class Pager(Generic[T], Iterator[T], AsyncIterator[T]): + """Pager for iterating over items. + + Assumes that page_fn is a paged function that takes ListOptions and returns a list of items. + """ + + page_fn: ( + Callable[[ListOptions], list[T]] | Callable[[ListOptions], Awaitable[list[T]]] + ) + options: ListOptions = field(default_factory=ListOptions) + + def __post_init__(self): + self.restart() + if asyncio.iscoroutinefunction(self.page_fn): + self.__next__ = NotImplemented + self.next_page = self._anext_page + self.next_item = self._anext_item + else: + self.__anext__ = NotImplemented + self.next_page = self._next_page + self.next_item = self._next_item + + def restart(self) -> Pager[T]: + """Reset the pager. + + This keeps the current options and page function, but resets the internal state. 
+ """ + # as MLMD loops over pages, we need to keep track of the first page or we'll loop forever + self._start = None + self._current_page = None + # tracks the next item on the current page + self._i = 0 + self.options.next_page_token = None + return self + + def order_by_creation_time(self) -> Pager[T]: + """Order items by creation time. + + This resets the pager. + """ + self.options.order_by = OrderByField.CREATE_TIME + return self.restart() + + def order_by_update_time(self) -> Pager[T]: + """Order items by update time. + + This resets the pager. + """ + self.options.order_by = OrderByField.LAST_UPDATE_TIME + return self.restart() + + def order_by_id(self) -> Pager[T]: + """Order items by ID. + + This resets the pager. + """ + self.options.order_by = OrderByField.ID + return self.restart() + + def limit(self, limit: int) -> Pager[T]: + """Limit the number of items to return. + + This resets the pager. + """ + self.options.limit = limit + return self.restart() + + def ascending(self) -> Pager[T]: + """Order items in ascending order. + + This resets the pager. + """ + self.options.is_asc = True + return self.restart() + + def descending(self) -> Pager[T]: + """Order items in descending order. + + This resets the pager. + """ + self.options.is_asc = False + return self.restart() + + def _next_page(self) -> list[T]: + """Get the next page of items. + + This will automatically loop over pages. + """ + return cast(list[T], self.page_fn(self.options)) + + async def _anext_page(self) -> list[T]: + """Get the next page of items. + + This will automatically loop over pages. + """ + return await cast(Awaitable[list[T]], self.page_fn(self.options)) + + def _needs_fetch(self) -> bool: + return not self._current_page or self._i >= len(self._current_page) + + def _next_item(self) -> T: + """Get the next item in the pager. + + This variant won't check for looping, so it's useful for manual iteration/scripting. + + NOTE: This won't check for looping, so use with caution. 
+ If you want to check for looping, use the pythonic `next()`. + """ + if self._needs_fetch(): + self._current_page = self._next_page() + self._i = 0 + assert self._current_page + + item = self._current_page[self._i] + self._i += 1 + return item + + async def _anext_item(self) -> T: + """Get the next item in the pager. + + This variant won't check for looping, so it's useful for manual iteration/scripting. + + NOTE: This won't check for looping, so use with caution. + If you want to check for looping, use the pythonic `next()`. + """ + if self._needs_fetch(): + self._current_page = await self._anext_page() + self._i = 0 + assert self._current_page + + item = self._current_page[self._i] + self._i += 1 + return item + + def __next__(self) -> T: + check_looping = self._needs_fetch() + + item = self._next_item() + + if not self._start: + self._start = self.options.next_page_token + elif check_looping and self.options.next_page_token == self._start: + raise StopIteration + + return item + + async def __anext__(self) -> T: + check_looping = self._needs_fetch() + + item = await self._anext_item() + + if not self._start: + self._start = self.options.next_page_token + elif check_looping and self.options.next_page_token == self._start: + raise StopAsyncIteration + + return item + + def __iter__(self) -> Iterator[T]: + return self + + def __aiter__(self) -> AsyncIterator[T]: + return self diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 698ac68a4..aa8ab39fb 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -1,4 +1,5 @@ import os +from itertools import islice import pytest from model_registry import ModelRegistry, utils @@ -94,6 +95,7 @@ async def test_get(client: ModelRegistry): metadata=metadata, ) + assert rm.id assert (_rm := client.get_registered_model(name)) assert rm.id == _rm.id @@ -109,6 +111,117 @@ async def test_get(client: ModelRegistry): assert ma.id == _ma.id +def 
test_get_registered_models(client: ModelRegistry): + models = 21 + + for name in [f"test_model{i}" for i in range(models)]: + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version="1.0.0", + ) + + rm_iter = client.get_registered_models().limit(10) + i = 0 + prev_tok = None + changes = 0 + with pytest.raises(StopIteration): # noqa: PT012 + while i < 50 and next(rm_iter): + if rm_iter.options.next_page_token != prev_tok: + print( + f"Token changed from {prev_tok} to {rm_iter.options.next_page_token} at {i}" + ) + prev_tok = rm_iter.options.next_page_token + changes += 1 + i += 1 + + assert changes == 3 + assert i == models + + +def test_get_registered_models_and_reset(client: ModelRegistry): + model_count = 6 + page = model_count // 2 + + for name in [f"test_model{i}" for i in range(model_count)]: + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version="1.0.0", + ) + + rm_iter = client.get_registered_models().limit(model_count - 1) + models = [] + for rm in islice(rm_iter, page): + models.append(rm) + assert len(models) == page + rm_iter.restart() + complete = list(rm_iter) + assert len(complete) == model_count + assert complete[:page] == models + + +def test_get_model_versions(client: ModelRegistry): + name = "test_model" + models = 21 + + for v in [f"1.0.{i}" for i in range(models)]: + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=v, + ) + + mv_iter = client.get_model_versions(name).limit(10) + i = 0 + prev_tok = None + changes = 0 + with pytest.raises(StopIteration): # noqa: PT012 + while i < 50 and next(mv_iter): + if mv_iter.options.next_page_token != prev_tok: + print( + f"Token changed from {prev_tok} to {mv_iter.options.next_page_token} at {i}" + ) + prev_tok = mv_iter.options.next_page_token + changes += 1 + i += 1 + + assert changes == 3 + assert 
i == models + + +def test_get_model_versions_and_reset(client: ModelRegistry): + name = "test_model" + + model_count = 6 + page = model_count // 2 + + for v in [f"1.0.{i}" for i in range(model_count)]: + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=v, + ) + + mv_iter = client.get_model_versions(name).limit(model_count - 1) + models = [] + for rm in islice(mv_iter, page): + models.append(rm) + assert len(models) == page + mv_iter.restart() + complete = list(mv_iter) + assert len(complete) == model_count + assert complete[:page] == models + + def test_hf_import(client: ModelRegistry): pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index 9ea6071c6..274638383 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -2,7 +2,13 @@ import pytest from model_registry.core import ModelRegistryAPIClient -from model_registry.types import ModelArtifact, ModelVersion, RegisteredModel +from model_registry.types import ( + DocArtifact, + ModelArtifact, + ModelVersion, + Pager, + RegisteredModel, +) from .conftest import REGISTRY_HOST, REGISTRY_PORT, cleanup @@ -81,6 +87,17 @@ async def test_get_registered_models( assert [registered_model, rm2] == rms +async def test_page_through_registered_models(client: ModelRegistryAPIClient): + models = 6 + for i in range(models): + await client.upsert_registered_model(RegisteredModel(name=f"rm{i}")) + pager = Pager(client.get_registered_models).limit(5) + total = 0 + async for _ in pager: + total += 1 + assert total == models + + async def test_insert_model_version( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -144,7 +161,7 @@ async def test_get_model_version_by_external_id( ): assert ( mv := await client.get_model_version_by_params( - external_id=model_version.external_id + 
external_id=str(model_version.external_id) ) ) assert mv == model_version @@ -163,6 +180,23 @@ async def test_get_model_versions( assert [model_version, mv2] == mvs +async def test_page_through_model_versions( + client: ModelRegistryAPIClient, registered_model: RegisteredModel +): + models = 6 + for i in range(models): + await client.upsert_model_version( + ModelVersion(name=f"mv{i}"), str(registered_model.id) + ) + pager = Pager( + lambda o: client.get_model_versions(str(registered_model.id), o) + ).limit(5) + total = 0 + async for _ in pager: + total += 1 + assert total == models + + async def test_insert_model_artifact( client: ModelRegistryAPIClient, model_version: ModelVersion, @@ -229,7 +263,7 @@ async def test_get_model_artifact_by_name( ): assert ( ma := await client.get_model_artifact_by_params( - name=model.name, model_version_id=str(model_version.id) + name=str(model.name), model_version_id=str(model_version.id) ) ) assert ma == model @@ -239,7 +273,9 @@ async def test_get_model_artifact_by_external_id( client: ModelRegistryAPIClient, model: ModelArtifact ): assert ( - ma := await client.get_model_artifact_by_params(external_id=model.external_id) + ma := await client.get_model_artifact_by_params( + external_id=str(model.external_id) + ) ) assert ma == model @@ -264,3 +300,25 @@ async def test_get_model_artifacts_by_mv_id( mas = await client.get_model_artifacts(str(model_version.id)) assert [model, ma2] == mas + + +async def test_page_through_model_version_artifacts( + client: ModelRegistryAPIClient, + registered_model: RegisteredModel, + model_version: ModelVersion, +): + _ = registered_model + models = 6 + for i in range(models): + if i % 2 == 0: + art = ModelArtifact(name=f"ma{i}", uri="uri") + else: + art = DocArtifact(name=f"ma{i}", uri="uri") + await client.create_model_version_artifact(art, str(model_version.id)) + pager = Pager( + lambda o: client.get_model_version_artifacts(str(model_version.id), o) + ).limit(5) + total = 0 + async for _ in 
pager: + total += 1 + assert total == models