Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

py: client: add update method for supported types #344

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ model = registry.get_registered_model("my-model")
version = registry.get_model_version("my-model", "2.0.0")

experiment = registry.get_model_artifact("my-model", "2.0.0")

# change is not reflected on pushed model version
version.description = "Updated model version"

# you can update it using
registry.update(version)
```

### Importing from S3
Expand Down
19 changes: 18 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Any, get_args
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
Expand All @@ -18,6 +18,9 @@
SupportedTypes,
)

ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)


class ModelRegistry:
"""Model registry client."""
Expand Down Expand Up @@ -191,6 +194,20 @@ def register_model(

return rm

def update(self, model: TModel) -> TModel:
"""Update a model."""
if not model.id:
msg = "Model must have an ID"
raise StoreError(msg)
if not isinstance(model, get_args(ModelTypes)):
msg = f"Model must be one of {get_args(ModelTypes)}"
raise StoreError(msg)
if isinstance(model, RegisteredModel):
return self.async_runner(self._api.upsert_registered_model(model))
if isinstance(model, ModelVersion):
return self.async_runner(self._api.upsert_model_version(model, model.id))
return self.async_runner(self._api.upsert_model_artifact(model, model.id))

def register_hf_model(
self,
repo: str,
Expand Down
64 changes: 64 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from model_registry import ModelRegistry, utils
from model_registry.exceptions import StoreError
from model_registry.types import ModelArtifact


def test_secure_client():
Expand Down Expand Up @@ -77,6 +78,69 @@ def test_register_existing_version(client: ModelRegistry):
client.register_model(**params)


@pytest.mark.e2e
async def test_update_models(client: ModelRegistry):
tarilabs marked this conversation as resolved.
Show resolved Hide resolved
name = "test_model"
version = "1.0.0"
rm = client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
)
assert rm.id

mr_api = client._api
mv = await mr_api.get_model_version_by_params(rm.id, version)
assert mv
assert mv.id
ma = await mr_api.get_model_artifact_by_params(name, mv.id)
assert ma

new_description = "updated description"
rm.description = new_description
mv.description = new_description
ma.description = new_description
assert client.update(rm).description == new_description
assert client.update(mv).description == new_description
assert client.update(ma).description == new_description


@pytest.mark.e2e
async def test_update_preserves_model_info(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
uri = "s3"
model_fmt_name = "test_format"
model_fmt_version = "test_version"
rm = client.register_model(
name,
uri,
model_format_name=model_fmt_name,
model_format_version=model_fmt_version,
version=version,
)
assert rm.id

mr_api = client._api
mv = await mr_api.get_model_version_by_params(rm.id, version)
assert mv
assert mv.id
ma = await mr_api.get_model_artifact_by_params(name, mv.id)
assert ma

new_description = "updated description"
ma = ModelArtifact(id=ma.id, uri=uri, description=new_description)

updated_ma = client.update(ma)
assert updated_ma.description == new_description
assert updated_ma.uri == uri
assert updated_ma.id == ma.id
assert updated_ma.model_format_name == model_fmt_name
assert updated_ma.model_format_version == model_fmt_version


@pytest.mark.e2e
async def test_get(client: ModelRegistry):
name = "test_model"
Expand Down
Loading