Skip to content

Commit

Permalink
py: client: add update method for supported types (#344)
Browse files Browse the repository at this point in the history
* py: client: add update method for supported types

Signed-off-by: Isabella do Amaral <[email protected]>

* py: test update preserves attrs

Signed-off-by: Isabella do Amaral <[email protected]>

* Update clients/python/tests/test_client.py

Co-authored-by: Matteo Mortari <[email protected]>
Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
Co-authored-by: Matteo Mortari <[email protected]>
  • Loading branch information
isinyaaa and tarilabs authored Sep 5, 2024
1 parent cc1fd8c commit 5755b77
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
6 changes: 6 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ model = registry.get_registered_model("my-model")
version = registry.get_model_version("my-model", "2.0.0")

experiment = registry.get_model_artifact("my-model", "2.0.0")

# change is not reflected on pushed model version
version.description = "Updated model version"

# you can update it using
registry.update(version)
```

### Importing from S3
Expand Down
19 changes: 18 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Any, get_args
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
Expand All @@ -18,6 +18,9 @@
SupportedTypes,
)

ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)


class ModelRegistry:
"""Model registry client."""
Expand Down Expand Up @@ -191,6 +194,20 @@ def register_model(

return rm

def update(self, model: TModel) -> TModel:
"""Update a model."""
if not model.id:
msg = "Model must have an ID"
raise StoreError(msg)
if not isinstance(model, get_args(ModelTypes)):
msg = f"Model must be one of {get_args(ModelTypes)}"
raise StoreError(msg)
if isinstance(model, RegisteredModel):
return self.async_runner(self._api.upsert_registered_model(model))
if isinstance(model, ModelVersion):
return self.async_runner(self._api.upsert_model_version(model, model.id))
return self.async_runner(self._api.upsert_model_artifact(model, model.id))

def register_hf_model(
self,
repo: str,
Expand Down
64 changes: 64 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from model_registry import ModelRegistry, utils
from model_registry.exceptions import StoreError
from model_registry.types import ModelArtifact


def test_secure_client():
Expand Down Expand Up @@ -77,6 +78,69 @@ def test_register_existing_version(client: ModelRegistry):
client.register_model(**params)


@pytest.mark.e2e
async def test_update_models(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
rm = client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
)
assert rm.id

mr_api = client._api
mv = await mr_api.get_model_version_by_params(rm.id, version)
assert mv
assert mv.id
ma = await mr_api.get_model_artifact_by_params(name, mv.id)
assert ma

new_description = "updated description"
rm.description = new_description
mv.description = new_description
ma.description = new_description
assert client.update(rm).description == new_description
assert client.update(mv).description == new_description
assert client.update(ma).description == new_description


@pytest.mark.e2e
async def test_update_preserves_model_info(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
uri = "s3"
model_fmt_name = "test_format"
model_fmt_version = "test_version"
rm = client.register_model(
name,
uri,
model_format_name=model_fmt_name,
model_format_version=model_fmt_version,
version=version,
)
assert rm.id

mr_api = client._api
mv = await mr_api.get_model_version_by_params(rm.id, version)
assert mv
assert mv.id
ma = await mr_api.get_model_artifact_by_params(name, mv.id)
assert ma

new_description = "updated description"
ma = ModelArtifact(id=ma.id, uri=uri, description=new_description)

updated_ma = client.update(ma)
assert updated_ma.description == new_description
assert updated_ma.uri == uri
assert updated_ma.id == ma.id
assert updated_ma.model_format_name == model_fmt_name
assert updated_ma.model_format_version == model_fmt_version


@pytest.mark.e2e
async def test_get(client: ModelRegistry):
name = "test_model"
Expand Down

0 comments on commit 5755b77

Please sign in to comment.