Skip to content

Commit

Permalink
py: expose listing methods
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Jul 18, 2024
1 parent 2c5f4e8 commit 57c56f1
Show file tree
Hide file tree
Showing 9 changed files with 508 additions and 16 deletions.
25 changes: 25 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ There are caveats to be noted when using this method:
)
```

### Listing models

To list models you can use
```py
for model in registry.get_registered_models():
...

# and versions associated with a model
for version in registry.get_model_versions("my-model"):
...
```

To customize sorting order or query limits you can also use

```py
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20)
for version in latest_updates:
...
```

You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order.

> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1
> you will still get all the models, with one query each.
## Development

Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions.
Expand Down
42 changes: 39 additions & 3 deletions clients/python/docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,14 @@ models = await mr_client.get_registered_models()

versions = await mr_client.get_model_versions("registered_model_id")

# We can get a list of all model artifacts
# We can get a list of the first 20 model artifacts
all_model_artifacts = await mr_client.get_model_artifacts()
```

To limit or order the query, provide a {py:class}`model_registry.types.ListOptions` object.
To limit or sort the query by another parameter, provide a {py:class}`model_registry.types.ListOptions` object.

```py
from model_registry import ListOptions
from model_registry.types import ListOptions

options = ListOptions(limit=50)

Expand All @@ -195,6 +195,42 @@ options = ListOptions.order_by_creation_time(is_asc=False)
last_50_models = await mr_client.get_registered_models(options)
```

You can also use the high-level {py:class}`model_registry.types.Pager` to get an iterator.

```py
from model_registry.types import Pager

models = Pager(mr_client.get_registered_models)

async for model in models:
...
```

Note that the iterator currently only works with methods that take a `ListOptions` argument, so if you want to use a
method that needs additional arguments, you'll need to provide a partial application like in the example below.

```py
model_version_artifacts = Pager(lambda o: mr_client.get_model_version_artifacts(mv.id, o))
```

> ⚠️ Also note that a [`partial`](https://docs.python.org/3/library/functools.html#functools.partial) definition won't work as the `options` argument is optional, and thus has to be overriden as a positional argument.
The iterator provides methods for setting up the {py:class}`model_registry.types.ListOptions` that will be used in each
call.

```py
reverse_model_version_artifacts = model_version_artifacts.order_by_creation_time().descending().limit(100)
```

You can also get each page separately and iterate yourself:

```py
page = await reverse_model_version_artifacts.next_page()
```

> Note: the iterator will be automagically sync or async depending on the paging function passed in for initialization.

```{eval-rst}
.. automodule:: model_registry.core
```
Expand Down
44 changes: 43 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import ModelArtifact, ModelVersion, RegisteredModel, SupportedTypes
from .types import (
ListOptions,
ModelArtifact,
ModelVersion,
Pager,
RegisteredModel,
SupportedTypes,
)


class ModelRegistry:
Expand Down Expand Up @@ -327,3 +334,38 @@ def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
raise StoreError(msg)
assert mv.id
return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id))

def get_registered_models(self) -> Pager[RegisteredModel]:
"""Get a pager for registered models.
Returns:
Iterable pager for registered models.
"""

def rm_list(options: ListOptions) -> list[RegisteredModel]:
return self.async_runner(self._api.get_registered_models(options))

return Pager[RegisteredModel](rm_list)

def get_model_versions(self, name: str) -> Pager[ModelVersion]:
"""Get a pager for model versions.
Args:
name: Name of the model.
Returns:
Iterable pager for model versions.
Raises:
StoreException: If the model does not exist.
"""
if not (rm := self.get_registered_model(name)):
msg = f"Model {name} does not exist"
raise StoreError(msg)

def rm_versions(options: ListOptions) -> list[ModelVersion]:
# type checkers can't restrict the type inside a nested function: https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert rm.id
return self.async_runner(self._api.get_model_versions(rm.id, options))

return Pager[ModelVersion](rm_versions)
48 changes: 40 additions & 8 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,14 @@ async def get_registered_models(
Registered models.
"""
async with self.get_client() as client:
rms = await client.get_registered_models(
rm_list = await client.get_registered_models(
**(options or ListOptions()).as_options()
)

return [RegisteredModel.from_basemodel(rm) for rm in rms.items or []]
if options:
options.next_page_token = rm_list.next_page_token

return [RegisteredModel.from_basemodel(rm) for rm in rm_list.items or []]

async def upsert_model_version(
self, model_version: ModelVersion, registered_model_id: str
Expand Down Expand Up @@ -236,11 +239,14 @@ async def get_model_versions(
Model versions.
"""
async with self.get_client() as client:
mvs = await client.get_registered_model_versions(
mv_list = await client.get_registered_model_versions(
registered_model_id, **(options or ListOptions()).as_options()
)

return [ModelVersion.from_basemodel(mv) for mv in mvs.items or []]
if options:
options.next_page_token = mv_list.next_page_token

return [ModelVersion.from_basemodel(mv) for mv in mv_list.items or []]

@overload
async def get_model_version_by_params(
Expand Down Expand Up @@ -415,17 +421,43 @@ async def get_model_artifacts(
"""
async with self.get_client() as client:
if model_version_id:
arts = await client.get_model_version_artifacts(
art_list = await client.get_model_version_artifacts(
model_version_id, **(options or ListOptions()).as_options()
)
if options:
options.next_page_token = art_list.next_page_token
models = []
for art in arts.items or []:
for art in art_list.items or []:
converted = Artifact.validate_artifact(art)
if isinstance(converted, ModelArtifact):
models.append(converted)
return models

mas = await client.get_model_artifacts(
ma_list = await client.get_model_artifacts(
**(options or ListOptions()).as_options()
)
return [ModelArtifact.from_basemodel(ma) for ma in mas.items or []]
if options:
options.next_page_token = ma_list.next_page_token
return [ModelArtifact.from_basemodel(ma) for ma in ma_list.items or []]

async def get_model_version_artifacts(
self,
model_version_id: str,
options: ListOptions | None = None,
) -> list[Artifact]:
"""Fetches model artifacts.
Args:
model_version_id: ID of the associated model version.
options: Options for listing model artifacts.
Returns:
Model artifacts.
"""
async with self.get_client() as client:
art_list = await client.get_model_version_artifacts(
model_version_id, **(options or ListOptions()).as_options()
)
if options:
options.next_page_token = art_list.next_page_token
return [Artifact.validate_artifact(art) for art in art_list.items or []]
3 changes: 3 additions & 0 deletions clients/python/src/model_registry/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RegisteredModelState,
)
from .options import ListOptions
from .pager import Pager

__all__ = [
# Artifacts
Expand All @@ -27,4 +28,6 @@
"SupportedTypes",
# Options
"ListOptions",
# Pager
"Pager",
]
4 changes: 4 additions & 0 deletions clients/python/src/model_registry/types/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ class ListOptions:
limit: Maximum number of objects to return.
order_by: Field to order by.
is_asc: Whether to order in ascending order. Defaults to True.
next_page_token: Token to use to retrieve next page of results.
"""

limit: int | None = None
order_by: OrderByField | None = None
is_asc: bool = True
next_page_token: str | None = None

@classmethod
def order_by_creation_time(cls, **kwargs) -> ListOptions:
Expand All @@ -49,4 +51,6 @@ def as_options(self) -> dict[str, Any]:
options["order_by"] = self.order_by
if self.is_asc is not None:
options["sort_order"] = SortOrder.ASC if self.is_asc else SortOrder.DESC
if self.next_page_token is not None:
options["next_page_token"] = self.next_page_token
return options
Loading

0 comments on commit 57c56f1

Please sign in to comment.