Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/revamp model endpoints #2035

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
21fc0e1
Moved model versions to their own root route
AlexejPenner Nov 9, 2023
d4d1ee2
Fix the delete endpoint
AlexejPenner Nov 9, 2023
1e08173
Fixed smaller mistakes
AlexejPenner Nov 9, 2023
f00593a
Fixed for Client as well
AlexejPenner Nov 9, 2023
c75ae0e
Fixed lint errors
AlexejPenner Nov 9, 2023
2d770c4
More linting
AlexejPenner Nov 9, 2023
641f408
More linting
AlexejPenner Nov 9, 2023
0e78b3e
More tiny fixes
AlexejPenner Nov 9, 2023
b407f71
Update src/zenml/zen_server/routers/model_versions_endpoints.py
AlexejPenner Nov 9, 2023
51c11a9
More linting
AlexejPenner Nov 9, 2023
a914370
Merge branch 'feature/revamp-model-endpoints' of github.com:zenml-io/…
AlexejPenner Nov 9, 2023
f95be63
Merge branch 'develop' into feature/revamp-model-endpoints
AlexejPenner Nov 13, 2023
7dd030f
Fixed tests and solved conflicts
AlexejPenner Nov 13, 2023
0d1cc73
Fixed linting
AlexejPenner Nov 13, 2023
c21dfbb
Fixed more tests
AlexejPenner Nov 13, 2023
f9f35d0
Merge branch 'develop' into feature/revamp-model-endpoints
AlexejPenner Nov 14, 2023
9695b60
Further refactoring
AlexejPenner Nov 14, 2023
1beb80f
Added raises section
AlexejPenner Nov 14, 2023
d114aea
Fix one failing test
AlexejPenner Nov 14, 2023
b965138
Take "latest" stage into account
AlexejPenner Nov 14, 2023
5a8cc70
Reformatted
AlexejPenner Nov 14, 2023
9a86fe6
Standardize use of list response
AlexejPenner Nov 15, 2023
79eb3a2
Rewrote some tests
AlexejPenner Nov 15, 2023
1d5c3ed
Add clien tests
AlexejPenner Nov 15, 2023
33176ba
Fixed spelling
AlexejPenner Nov 15, 2023
3dc5a4e
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
AlexejPenner Nov 15, 2023
713d610
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
AlexejPenner Nov 15, 2023
bf89710
Tested to work with e2e pipeline
AlexejPenner Nov 15, 2023
1c72fd0
Ugly fixes to get response models in the CLI
AlexejPenner Nov 15, 2023
1d539e0
Auto-update of E2E template
actions-user Nov 15, 2023
ac4c980
Access ModelVersionResponseModels in Client again
AlexejPenner Nov 16, 2023
bfde0ba
Merge branch 'feature/revamp-model-endpoints' of github.com:zenml-io/…
AlexejPenner Nov 16, 2023
bf34c71
Another small fix
AlexejPenner Nov 16, 2023
070e37e
Linted
AlexejPenner Nov 16, 2023
0061553
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
avishniakov Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def update_model_version(
cli_utils.print_table(
[
_model_version_to_print(
Client().zen_store.get_model_version(
Client()._get_model_version(
model_name_or_id=model_version.model_id,
model_version_name_or_number_or_id=stage,
)
Expand Down Expand Up @@ -488,9 +488,12 @@ def delete_model_version(
return

try:
Client().delete_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
Client().delete_model_version(
model_version_id=model_version.id,
)
except (KeyError, ValueError) as e:
cli_utils.error(str(e))
Expand Down Expand Up @@ -518,12 +521,15 @@ def _print_artifacts_links_generic(
only_model_artifacts: If set, only print model artifacts.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().zen_store.get_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
model_version_response_model = Client().zen_store.get_model_version(
model_version_id=model_version.id
)
type_ = (
"data artifacts"
if only_data_artifacts
Expand All @@ -533,11 +539,18 @@ def _print_artifacts_links_generic(
)

if (
(only_data_artifacts and not model_version.data_artifact_ids)
(
only_data_artifacts
and not model_version_response_model.data_artifact_ids
)
or (
only_endpoint_artifacts
and not model_version_response_model.endpoint_artifact_ids
)
or (
only_endpoint_artifacts and not model_version.endpoint_artifact_ids
only_model_artifacts
and not model_version_response_model.model_artifact_ids
)
or (only_model_artifacts and not model_version.model_artifact_ids)
):
cli_utils.declare(f"No {type_} linked to the model version found.")
return
Expand All @@ -546,9 +559,13 @@ def _print_artifacts_links_generic(
f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:"
)

model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_artifact_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
only_data_artifacts=only_data_artifacts,
only_endpoint_artifacts=only_endpoint_artifacts,
Expand Down Expand Up @@ -674,23 +691,29 @@ def list_model_version_pipeline_runs(
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().zen_store.get_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
model_version_response_model = Client().zen_store.get_model_version(
model_version_id=model_version.id
)

if not model_version.pipeline_run_ids:
if not model_version_response_model.pipeline_run_ids:
cli_utils.declare("No pipeline runs attached to model version found.")
return
cli_utils.title(
f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:"
f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:"
)
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_pipeline_run_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(
**kwargs,
),
Expand Down
145 changes: 114 additions & 31 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
OAuthDeviceStatus,
PermissionType,
SecretScope,
SorterOps,
StackComponentType,
StoreType,
)
Expand Down Expand Up @@ -5584,25 +5585,22 @@ def create_model_version(

def delete_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_id: Union[str, UUID],
model_version_id: UUID,
) -> None:
"""Deletes a model version from Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_id: name or id of the model version to be deleted.
model_version_id: Id of the model version to be deleted.
"""
self.zen_store.delete_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_id,
model_version_id=model_version_id,
)

def get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
Union[str, int, ModelStages, UUID]
] = None,
) -> "ModelVersion":
"""Get an existing model version from Model Control Plane.
Expand All @@ -5614,29 +5612,122 @@ def get_model_version(

Returns:
The model version of interest.

Raises:
RuntimeError: In case method inputs don't adhere to restrictions.
KeyError: In case no model version with the identifiers exists.
"""
return self.zen_store.get_model_version(
return self._get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id
or ModelStages.LATEST,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
).to_model_version(suppress_class_validation_warnings=True)

def list_model_versions(
def _get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, ModelStages, UUID]
] = None,
) -> "ModelVersionResponseModel":
"""Get an existing model version from Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped - latest version is retrieved.

Returns:
The model version of interest.

Raises:
RuntimeError: In case method inputs don't adhere to restrictions.
KeyError: In case no model version with the identifiers exists.
"""
if model_version_name_or_number_or_id is None:
model_version_name_or_number_or_id = ModelStages.LATEST

if isinstance(model_version_name_or_number_or_id, UUID):
return self.zen_store.get_model_version(
model_version_id=model_version_name_or_number_or_id
)
elif isinstance(model_version_name_or_number_or_id, int):
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
number=model_version_name_or_number_or_id,
),
).items
elif isinstance(model_version_name_or_number_or_id, str):
if model_version_name_or_number_or_id == ModelStages.LATEST:
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
sort_by=f"{SorterOps.DESCENDING}:number"
),
).items

if len(model_versions) > 1:
model_versions = [model_versions[0]]
else:
model_versions = []
elif model_version_name_or_number_or_id in ModelStages.values():
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
stage=model_version_name_or_number_or_id
),
).items
else:
model_versions = self.zen_store.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
name=model_version_name_or_number_or_id
),
).items
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"of the correct type."
)

if len(model_versions) == 1:
return model_versions[0]
elif len(model_versions) == 0:
raise KeyError(
f"No model version found for model "
f"`{model_name_or_id}` with version identifier "
f"`{model_version_name_or_number_or_id}`."
)
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"unique for model `{model_name_or_id}`."
)

def list_model_versions(
self,
sort_by: str = "number",
page: int = PAGINATION_STARTING_PAGE,
size: int = PAGE_SIZE_DEFAULT,
logical_operator: LogicalOperators = LogicalOperators.AND,
created: Optional[Union[datetime, str]] = None,
updated: Optional[Union[datetime, str]] = None,
model_name_or_id: Optional[Union[str, UUID]] = None,
name: Optional[str] = None,
number: Optional[int] = None,
stage: Optional[Union[str, ModelStages]] = None,
) -> Page["ModelVersionResponseModel"]:
"""Get model versions by filter from Model Control Plane.

Args:
sort_by: The column to sort by
page: The page of items
size: The maximum size of all pages
logical_operator: Which logical operator to use [and, or]
created: Use to filter by time of creation
updated: Use the last updated date for filtering
model_name_or_id: name or id of the model containing the model version.
sort_by: The column to sort by
page: The page of items
Expand All @@ -5652,6 +5743,12 @@ def list_model_versions(
A page object with all model versions.
"""
model_version_filter_model = ModelVersionFilterModel(
page=page,
size=size,
sort_by=sort_by,
logical_operator=logical_operator,
created=created,
updated=updated,
name=name,
number=number,
stage=stage,
Expand Down Expand Up @@ -5714,28 +5811,21 @@ def update_model_version(

def list_model_version_artifact_links(
self,
model_name_or_id: Union[str, UUID],
model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionArtifactResponseModel]:
"""Get model version to artifact links by filter in Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_artifact_link_filter_model: All filter parameters including pagination
params.

Returns:
A page of all model version to artifact links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_artifact_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
)

Expand All @@ -5747,28 +5837,21 @@ def list_model_version_artifact_links(

def list_model_version_pipeline_run_links(
self,
model_name_or_id: Union[str, UUID],
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionPipelineRunResponseModel]:
"""Get all model version to pipeline run links by filter.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_pipeline_run_link_filter_model: All filter parameters including pagination
params.

Returns:
A page of all model version to pipeline run links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_pipeline_run_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
)

Expand Down
7 changes: 3 additions & 4 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def _link_to_model_version(

# Create the model version artifact link using the ZenML client
existing_links = client.list_model_version_artifact_links(
model_name_or_id=model_version.model_id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
user_id=client.active_user.id,
workspace_id=client.active_workspace.id,
Expand All @@ -175,9 +174,9 @@ def _link_to_model_version(
logger.warning(
f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
)

client.zen_store.delete_model_version_artifact_link(
model_name_or_id=model_version.model_id,
model_version_name_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_name_or_id=artifact_name,
)
else:
Expand Down
6 changes: 5 additions & 1 deletion src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,14 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
from zenml.client import Client

zenml_client = Client()
return zenml_client.zen_store.get_model_version(
mv = zenml_client._get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
if not self._id:
self._id = mv.id

return mv

def _get_or_create_model_version(self) -> "ModelVersionResponseModel":
"""This method should get or create a model and a model version from Model Control Plane.
Expand Down
Loading
Loading