diff --git a/clients/python/README.md b/clients/python/README.md index 2b1e8f823..9e97e1ca3 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -37,6 +37,12 @@ model = registry.get_registered_model("my-model") version = registry.get_model_version("my-model", "2.0.0") experiment = registry.get_model_artifact("my-model", "2.0.0") + +# change is not reflected on pushed model version +version.description = "Updated model version" + +# you can update it using +registry.update(version) ``` ### Importing from S3 diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 25393d0d7..c1d79111e 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -4,7 +4,7 @@ import os from pathlib import Path -from typing import Any, get_args +from typing import Any, TypeVar, Union, get_args from warnings import warn from .core import ModelRegistryAPIClient @@ -18,6 +18,9 @@ SupportedTypes, ) +ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact] +TModel = TypeVar("TModel", bound=ModelTypes) + class ModelRegistry: """Model registry client.""" @@ -191,6 +194,20 @@ def register_model( return rm + def update(self, model: TModel) -> TModel: + """Update a model.""" + if not model.id: + msg = "Model must have an ID" + raise StoreError(msg) + if not isinstance(model, get_args(ModelTypes)): + msg = f"Model must be one of {get_args(ModelTypes)}" + raise StoreError(msg) + if isinstance(model, RegisteredModel): + return self.async_runner(self._api.upsert_registered_model(model)) + if isinstance(model, ModelVersion): + return self.async_runner(self._api.upsert_model_version(model, model.id)) + return self.async_runner(self._api.upsert_model_artifact(model, model.id)) + def register_hf_model( self, repo: str, diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index ba2cdbef6..99223a088 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -5,6 +5,7 @@ from model_registry import ModelRegistry, utils from model_registry.exceptions import StoreError +from model_registry.types import ModelArtifact def test_secure_client(): @@ -77,6 +78,69 @@ def test_register_existing_version(client: ModelRegistry): client.register_model(**params) +@pytest.mark.e2e +async def test_update_models(client: ModelRegistry): + name = "test_model" + version = "1.0.0" + rm = client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=version, + ) + assert rm.id + + mr_api = client._api + mv = await mr_api.get_model_version_by_params(rm.id, version) + assert mv + assert mv.id + ma = await mr_api.get_model_artifact_by_params(name, mv.id) + assert ma + + new_description = "updated description" + rm.description = new_description + mv.description = new_description + ma.description = new_description + assert client.update(rm).description == new_description + assert client.update(mv).description == new_description + assert client.update(ma).description == new_description + + +@pytest.mark.e2e +async def test_update_preserves_model_info(client: ModelRegistry): + name = "test_model" + version = "1.0.0" + uri = "s3" + model_fmt_name = "test_format" + model_fmt_version = "test_version" + rm = client.register_model( + name, + uri, + model_format_name=model_fmt_name, + model_format_version=model_fmt_version, + version=version, + ) + assert rm.id + + mr_api = client._api + mv = await mr_api.get_model_version_by_params(rm.id, version) + assert mv + assert mv.id + ma = await mr_api.get_model_artifact_by_params(name, mv.id) + assert ma + + new_description = "updated description" + ma = ModelArtifact(id=ma.id, uri=uri, description=new_description) + + updated_ma = client.update(ma) + assert updated_ma.description == new_description + assert updated_ma.uri == uri + assert updated_ma.id == ma.id + assert updated_ma.model_format_name == model_fmt_name + assert updated_ma.model_format_version == model_fmt_version + + @pytest.mark.e2e async def test_get(client: ModelRegistry): name = "test_model"