Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/revamp model endpoints #2035

Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
21fc0e1
Moved model versions to their own root route
AlexejPenner Nov 9, 2023
d4d1ee2
Fix the delete endpoint
AlexejPenner Nov 9, 2023
1e08173
Fixed smaller mistakes
AlexejPenner Nov 9, 2023
f00593a
Fixed for Client as well
AlexejPenner Nov 9, 2023
c75ae0e
Fixed lint errors
AlexejPenner Nov 9, 2023
2d770c4
More linting
AlexejPenner Nov 9, 2023
641f408
More linting
AlexejPenner Nov 9, 2023
0e78b3e
More tiny fixes
AlexejPenner Nov 9, 2023
b407f71
Update src/zenml/zen_server/routers/model_versions_endpoints.py
AlexejPenner Nov 9, 2023
51c11a9
More linting
AlexejPenner Nov 9, 2023
a914370
Merge branch 'feature/revamp-model-endpoints' of github.com:zenml-io/…
AlexejPenner Nov 9, 2023
f95be63
Merge branch 'develop' into feature/revamp-model-endpoints
AlexejPenner Nov 13, 2023
7dd030f
Fixed tests and solved conflicts
AlexejPenner Nov 13, 2023
0d1cc73
Fixed linting
AlexejPenner Nov 13, 2023
c21dfbb
Fixed more tests
AlexejPenner Nov 13, 2023
f9f35d0
Merge branch 'develop' into feature/revamp-model-endpoints
AlexejPenner Nov 14, 2023
9695b60
Further refactoring
AlexejPenner Nov 14, 2023
1beb80f
Added raises section
AlexejPenner Nov 14, 2023
d114aea
Fix one failing test
AlexejPenner Nov 14, 2023
b965138
Take "latest" stage into account
AlexejPenner Nov 14, 2023
5a8cc70
Reformatted
AlexejPenner Nov 14, 2023
9a86fe6
Standardize use of list response
AlexejPenner Nov 15, 2023
79eb3a2
Rewrote some tests
AlexejPenner Nov 15, 2023
1d5c3ed
Add clien tests
AlexejPenner Nov 15, 2023
33176ba
Fixed spelling
AlexejPenner Nov 15, 2023
3dc5a4e
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
AlexejPenner Nov 15, 2023
713d610
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
AlexejPenner Nov 15, 2023
bf89710
Tested to work with e2e pipeline
AlexejPenner Nov 15, 2023
1c72fd0
Ugly fixes to get response models in the CLI
AlexejPenner Nov 15, 2023
1d539e0
Auto-update of E2E template
actions-user Nov 15, 2023
ac4c980
Access ModelVersionResponseModels in Client again
AlexejPenner Nov 16, 2023
bfde0ba
Merge branch 'feature/revamp-model-endpoints' of github.com:zenml-io/…
AlexejPenner Nov 16, 2023
bf34c71
Another small fix
AlexejPenner Nov 16, 2023
070e37e
Linted
AlexejPenner Nov 16, 2023
0061553
Merge branch 'feature/OSS-2609-OSS-2575-model-config-is-model-version…
avishniakov Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,12 @@ def delete_model_version(
return

try:
Client().delete_model_version(
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
Client().delete_model_version(
model_version_id=model_version.id,
)
except (KeyError, ValueError) as e:
cli_utils.error(str(e))
Expand Down Expand Up @@ -551,9 +554,13 @@ def _print_artifacts_links_generic(
f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:"
)

model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_artifact_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
only_artifacts=only_artifact_objects,
only_deployments=only_deployments,
Expand Down Expand Up @@ -692,10 +699,13 @@ def list_model_version_pipeline_runs(
cli_utils.title(
f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:"
)
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

links = Client().list_model_version_pipeline_run_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(
**kwargs,
),
Expand Down
111 changes: 78 additions & 33 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
OAuthDeviceStatus,
PermissionType,
SecretScope,
SorterOps,
StackComponentType,
StoreType,
)
Expand Down Expand Up @@ -5486,25 +5487,22 @@ def create_model_version(

def delete_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_id: Union[str, UUID],
model_version_id: UUID,
) -> None:
"""Deletes a model version from Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_id: name or id of the model version to be deleted.
model_version_id: Id of the model version to be deleted.
"""
self.zen_store.delete_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_id,
model_version_id=model_version_id,
)

def get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
Union[str, int, ModelStages, UUID]
] = None,
) -> ModelVersionResponseModel:
"""Get an existing model version from Model Control Plane.
Expand All @@ -5516,17 +5514,78 @@ def get_model_version(

Returns:
The model version of interest.

Raises:
RuntimeError: In case method inputs don't adhere to restrictions.
KeyError: In case no model version with the identifiers exists.
"""
return self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id
or ModelStages.LATEST,
)
if model_version_name_or_number_or_id is None:
model_version_name_or_number_or_id = ModelStages.LATEST

if isinstance(model_version_name_or_number_or_id, UUID):
return self.zen_store.get_model_version(
model_version_id=model_version_name_or_number_or_id
)
elif isinstance(model_version_name_or_number_or_id, int):
model_versions = self.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
number=model_version_name_or_number_or_id
),
).items
elif isinstance(model_version_name_or_number_or_id, str):
if model_version_name_or_number_or_id == ModelStages.LATEST:
model_versions_page = self.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
sort_by=f"{SorterOps.DESCENDING}:number",
),
)

if model_versions_page.size > 0:
model_versions = [model_versions_page.items[0]]
else:
model_versions = []
elif model_version_name_or_number_or_id in ModelStages.values():
model_versions = self.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
stage=model_version_name_or_number_or_id
),
).items
else:
model_versions = self.list_model_versions(
model_name_or_id=model_name_or_id,
model_version_filter_model=ModelVersionFilterModel(
name=model_version_name_or_number_or_id
),
).items
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"of the correct type."
)

if len(model_versions) == 1:
return model_versions[0]
elif len(model_versions) == 0:
raise KeyError(
f"No model version found for model "
f"`{model_name_or_id}` with version identifier "
f"`{model_version_name_or_number_or_id}`."
)
else:
raise RuntimeError(
f"The model version identifier "
f"`{model_version_name_or_number_or_id}` is not"
f"unique for model `{model_name_or_id}`."
)

def list_model_versions(
self,
model_name_or_id: Union[str, UUID],
model_version_filter_model: ModelVersionFilterModel,
model_name_or_id: Optional[Union[str, UUID]] = None,
) -> Page[ModelVersionResponseModel]:
"""Get model versions by filter from Model Control Plane.

Expand Down Expand Up @@ -5570,28 +5629,21 @@ def update_model_version(

def list_model_version_artifact_links(
self,
model_name_or_id: Union[str, UUID],
model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionArtifactResponseModel]:
"""Get model version to artifact links by filter in Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_artifact_link_filter_model: All filter parameters including pagination
params.

Returns:
A page of all model version to artifact links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_artifact_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
)

Expand All @@ -5603,28 +5655,21 @@ def list_model_version_artifact_links(

def list_model_version_pipeline_run_links(
self,
model_name_or_id: Union[str, UUID],
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
model_version_id: UUID,
) -> Page[ModelVersionPipelineRunResponseModel]:
"""Get all model version to pipeline run links by filter.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
model_version_id: id of the model version to be retrieved.
model_version_pipeline_run_link_filter_model: All filter parameters including pagination
params.

Returns:
A page of all model version to pipeline run links.
"""
mv = self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)
return self.zen_store.list_model_version_pipeline_run_links(
model_name_or_id=mv.model.id,
model_version_name_or_id=mv.id,
model_version_id=model_version_id,
model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
)

Expand Down
7 changes: 3 additions & 4 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def _link_to_model_version(

# Create the model version artifact link using the ZenML client
existing_links = client.list_model_version_artifact_links(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
user_id=client.active_user.id,
workspace_id=client.active_workspace.id,
Expand All @@ -162,9 +161,9 @@ def _link_to_model_version(
logger.warning(
f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
)

client.zen_store.delete_model_version_artifact_link(
model_name_or_id=model_version.model.id,
model_version_name_or_id=model_version.id,
model_version_id=model_version.id,
model_version_artifact_link_name_or_id=artifact_name,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/models/model_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def _update_default_running_version_name(self) -> None:
class ModelVersionFilterModel(ModelScopedFilterModel):
"""Filter Model for Model Version."""

name: Optional[Union[str, UUID]] = Field(
name: Optional[str] = Field(
default=None,
description="The name of the Model Version",
)
Expand Down
5 changes: 2 additions & 3 deletions src/zenml/new/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,14 +945,13 @@ def delete_running_versions_without_recovery(
new_version_request.model_config.delete_new_version_on_failure
and new_version_request.model_config.was_created_in_this_run
):
model = Client().get_model_version(
model_version_model = Client().get_model_version(
model_name_or_id=model_name,
model_version_name_or_number_or_id=model_version
or constants.RUNNING_MODEL_VERSION,
)
Client().delete_model_version(
model_name_or_id=model_name,
model_version_name_or_id=model.id,
model_version_id=model_version_model.id
)

def get_runs(self, **kwargs: Any) -> List[PipelineRunResponseModel]:
Expand Down
Loading
Loading