Skip to content

Commit

Permalink
Feature/revamp model endpoints (#2035)
Browse files Browse the repository at this point in the history
* Moved model versions to their own root route

* Fix the delete endpoint

* Fixed smaller mistakes

* Fixed for Client as well

* Fixed lint errors

* More linting

* More linting

* More tiny fixes

* Update src/zenml/zen_server/routers/model_versions_endpoints.py

Co-authored-by: Andrei Vishniakov <[email protected]>

* More linting

* Fixed tests and solved conflicts

* Fixed linting

* Fixed more tests

* Further refactoring

* Added raises section

* Fix one failing test

* Take "latest" stage into account

* Reformatted

* Standardize use of list response

* Rewrote some tests

* Add clien tests

* Fixed spelling

* Tested to work with e2e pipeline

* Ugly fixes to get response models in the CLI

* Auto-update of E2E template

* Access ModelVersionResponseModels in Client again

* Another small fix

* Linted

---------

Co-authored-by: Andrei Vishniakov <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
3 people authored Nov 16, 2023
1 parent 374ca2f commit 7f1642c
Show file tree
Hide file tree
Showing 16 changed files with 785 additions and 643 deletions.
51 changes: 37 additions & 14 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def update_model_version(
cli_utils.print_table(
[
_model_version_to_print(
Client().zen_store.get_model_version(
Client()._get_model_version(
model_name_or_id=model_version.model_id,
model_version_name_or_number_or_id=stage,
)
Expand Down Expand Up @@ -488,9 +488,12 @@ def delete_model_version(
return

try:
Client().delete_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
Client().delete_model_version(
model_version_id=model_version.id,
)
except (KeyError, ValueError) as e:
cli_utils.error(str(e))
Expand Down Expand Up @@ -518,12 +521,15 @@ def _print_artifacts_links_generic(
only_model_artifacts: If set, only print model artifacts.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().zen_store.get_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
model_version_response_model = Client().zen_store.get_model_version(
model_version_id=model_version.id
)
type_ = (
"data artifacts"
if only_data_artifacts
Expand All @@ -533,11 +539,18 @@ def _print_artifacts_links_generic(
)

if (
(only_data_artifacts and not model_version.data_artifact_ids)
(
only_data_artifacts
and not model_version_response_model.data_artifact_ids
)
or (
only_endpoint_artifacts
and not model_version_response_model.endpoint_artifact_ids
)
or (
only_endpoint_artifacts and not model_version.endpoint_artifact_ids
only_model_artifacts
and not model_version_response_model.model_artifact_ids
)
or (only_model_artifacts and not model_version.model_artifact_ids)
):
cli_utils.declare(f"No {type_} linked to the model version found.")
return
Expand All @@ -546,9 +559,13 @@ def _print_artifacts_links_generic(
f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:"
)

model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_artifact_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
only_data_artifacts=only_data_artifacts,
only_endpoint_artifacts=only_endpoint_artifacts,
Expand Down Expand Up @@ -674,23 +691,29 @@ def list_model_version_pipeline_runs(
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().zen_store.get_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
model_version_response_model = Client().zen_store.get_model_version(
model_version_id=model_version.id
)

if not model_version.pipeline_run_ids:
if not model_version_response_model.pipeline_run_ids:
cli_utils.declare("No pipeline runs attached to model version found.")
return
cli_utils.title(
f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:"
f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:"
)
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_pipeline_run_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(
**kwargs,
),
Expand Down
145 changes: 114 additions & 31 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
OAuthDeviceStatus,
PermissionType,
SecretScope,
SorterOps,
StackComponentType,
StoreType,
)
Expand Down Expand Up @@ -5584,25 +5585,22 @@ def create_model_version(

def delete_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_id: Union[str, UUID],
model_version_id: UUID,
) -> None:
"""Deletes a model version from Model Control Plane.
Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_id: name or id of the model version to be deleted.
model_version_id: Id of the model version to be deleted.
"""
self.zen_store.delete_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_id,
model_version_id=model_version_id,
)

def get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
Union[str, int, ModelStages, UUID]
] = None,
) -> "ModelVersion":
"""Get an existing model version from Model Control Plane.
Expand All @@ -5614,29 +5612,122 @@ def get_model_version(
Returns:
The model version of interest.
Raises:
RuntimeError: In case method inputs don't adhere to restrictions.
KeyError: In case no model version with the identifiers exists.
"""
return self.zen_store.get_model_version(
return self._get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id
or ModelStages.LATEST,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
).to_model_version(suppress_class_validation_warnings=True)

def list_model_versions(
def _get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, ModelStages, UUID]
] = None,
) -> "ModelVersionResponseModel":
"""Get an existing model version from Model Control Plane.
Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped - latest version is retrieved.
Returns:
The model version of interest.
Raises:
RuntimeError: In case method inputs don't adhere to restrictions.
KeyError: In case no model version with the identifiers exists.
"""
if model_version_name_or_number_or_id is None:
model_version_name_or_number_or_id = ModelStages.LATEST

if isinstance(model_version_name_or_number_or_id, UUID):
return self.zen_store.get_model_version(
model_version_id=model_version_name_or_number_or_id
)
elif isinstance(model_version_name_or_number_or_id, int):
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
number=model_version_name_or_number_or_id,
),
).items
elif isinstance(model_version_name_or_number_or_id, str):
if model_version_name_or_number_or_id == ModelStages.LATEST:
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
sort_by=f"{SorterOps.DESCENDING}:number"
),
).items

if len(model_versions) > 1:
model_versions = [model_versions[0]]
else:
model_versions = []
elif model_version_name_or_number_or_id in ModelStages.values():
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
stage=model_version_name_or_number_or_id
),
).items
else:
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
name=model_version_name_or_number_or_id
),
).items
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"of the correct type."
)

if len(model_versions) == 1:
return model_versions[0]
elif len(model_versions) == 0:
raise KeyError(
f"No model version found for model "
f"`{model_name_or_id}` with version identifier "
f"`{model_version_name_or_number_or_id}`."
)
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"unique for model `{model_name_or_id}`."
)

def list_model_versions(
self,
sort_by: str = "number",
page: int = PAGINATION_STARTING_PAGE,
size: int = PAGE_SIZE_DEFAULT,
logical_operator: LogicalOperators = LogicalOperators.AND,
created: Optional[Union[datetime, str]] = None,
updated: Optional[Union[datetime, str]] = None,
model_name_or_id: Optional[Union[str, UUID]] = None,
name: Optional[str] = None,
number: Optional[int] = None,
stage: Optional[Union[str, ModelStages]] = None,
) -> Page["ModelVersionResponseModel"]:
"""Get model versions by filter from Model Control Plane.
Args:
sort_by: The column to sort by
page: The page of items
size: The maximum size of all pages
logical_operator: Which logical operator to use [and, or]
created: Use to filter by time of creation
updated: Use the last updated date for filtering
model_name_or_id: name or id of the model containing the model version.
sort_by: The column to sort by
page: The page of items
Expand All @@ -5652,6 +5743,12 @@ def list_model_versions(
A page object with all model versions.
"""
model_version_filter_model = ModelVersionFilterModel(
page=page,
size=size,
sort_by=sort_by,
logical_operator=logical_operator,
created=created,
updated=updated,
name=name,
number=number,
stage=stage,
Expand Down Expand Up @@ -5714,28 +5811,21 @@ def update_model_version(

def list_model_version_artifact_links(
self,
model_name_or_id: Union[str, UUID],
model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionArtifactResponseModel]:
"""Get model version to artifact links by filter in Model Control Plane.
Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_artifact_link_filter_model: All filter parameters including pagination
params.
Returns:
A page of all model version to artifact links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_artifact_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
)

Expand All @@ -5747,28 +5837,21 @@ def list_model_version_artifact_links(

def list_model_version_pipeline_run_links(
self,
model_name_or_id: Union[str, UUID],
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionPipelineRunResponseModel]:
"""Get all model version to pipeline run links by filter.
Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_pipeline_run_link_filter_model: All filter parameters including pagination
params.
Returns:
A page of all model version to pipeline run links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_pipeline_run_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
)

Expand Down
7 changes: 3 additions & 4 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def _link_to_model_version(

# Create the model version artifact link using the ZenML client
existing_links = client.list_model_version_artifact_links(
model_name_or_id=model_version.model_id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
user_id=client.active_user.id,
workspace_id=client.active_workspace.id,
Expand All @@ -175,9 +174,9 @@ def _link_to_model_version(
logger.warning(
f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
)

client.zen_store.delete_model_version_artifact_link(
model_name_or_id=model_version.model_id,
model_version_name_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_name_or_id=artifact_name,
)
else:
Expand Down
6 changes: 5 additions & 1 deletion src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,14 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
from zenml.client import Client

zenml_client = Client()
return zenml_client.zen_store.get_model_version(
mv = zenml_client._get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
if not self._id:
self._id = mv.id

return mv

def _get_or_create_model_version(self) -> "ModelVersionResponseModel":
"""This method should get or create a model and a model version from Model Control Plane.
Expand Down
Loading

0 comments on commit 7f1642c

Please sign in to comment.