From 21fc0e13370ed225c3c0dbb8b5b3d4d217f5e7f8 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 15:20:23 +0100 Subject: [PATCH 01/28] Moved model versions to their own root route --- .../routers/model_versions_endpoints.py | 272 ++++++++++++++++++ .../zen_server/routers/models_endpoints.py | 194 +------------ src/zenml/zen_server/zen_server_api.py | 2 + src/zenml/zen_stores/rest_zen_store.py | 51 ++-- src/zenml/zen_stores/sql_zen_store.py | 51 ++-- src/zenml/zen_stores/zen_store_interface.py | 26 +- 6 files changed, 326 insertions(+), 270 deletions(-) create mode 100644 src/zenml/zen_server/routers/model_versions_endpoints.py diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py new file mode 100644 index 00000000000..5e19be30bf8 --- /dev/null +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -0,0 +1,272 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Endpoint definitions for models.""" + +from typing import Union +from uuid import UUID + +from fastapi import APIRouter, Depends, Security + +from zenml.constants import ( + API, + ARTIFACTS, + MODEL_VERSIONS, + RUNS, + VERSION_1, +) +from zenml.enums import PermissionType +from zenml.models import ( + ModelVersionArtifactFilterModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunResponseModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, +) +from zenml.models.page_model import Page +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.utils import ( + handle_exceptions, + make_dependable, + zen_store, +) + +######### +# Models +######### + +router = APIRouter( + prefix=API + VERSION_1 + MODEL_VERSIONS, + tags=["model_versions"], + responses={401: error_response}, +) + +################# +# Model Versions +################# + + +@router.get( + "", + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_versions( + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get model versions according to query filters. + + Args: + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model versions according to query filters. + """ + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) + + +@router.get( + "/{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_model_version( + model_version_id: UUID, + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> ModelVersionResponseModel: + """Get a model version by ID. + + Args: + model_version_id: id of the model version to be retrieved. + is_number: If the model_version_name_or_number_or_id is a version number + + Returns: + The model version with the given name or ID. + """ + return zen_store().get_model_version( + model_name_or_id=model_version_id, + ) + + +@router.put( + "{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_model_version( + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + """ + return zen_store().update_model_version( + model_version_id=model_version_id, + model_version_update_model=model_version_update_model, + ) + + +@router.delete( + "{model_version_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to delete. + """ + zen_store().delete_model_version( + model_name_or_id, model_version_name_or_id + ) + + +########################## +# Model Version Artifacts +########################## + + +@router.get( + "{model_version_id}" + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_artifact_links( + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. + + Args: + model_version_name_or_id: The name or ID of the model version containing links. + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to artifact links according to query filters. + """ + return zen_store().list_model_version_artifact_links( + model_version_id=model_version_name_or_id, + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, + ) + + +@router.delete( + "{model_version_id}" + + ARTIFACTS + + "{model_version_artifact_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_artifact_link( + model_version_id: UUID, + model_version_artifact_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_version_id: ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + """ + zen_store().delete_model_version_artifact_link( + model_version_id, + model_version_artifact_link_name_or_id, + ) + + +############################## +# Model Version Pipeline Runs +############################## + + +@router.get( + "{model_version_name_or_id}" + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_pipeline_run_links( + model_version_id: UUID, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline run links according to query filters. + + Args: + model_version_id: ID of the model version containing the link. + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, + and filtering + + Returns: + The model version to pipeline run links according to query filters. + """ + return zen_store().list_model_version_pipeline_run_links( + model_version_id=model_version_id, + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, + ) + + +@router.delete( + "{model_version_id}" + + RUNS + + "/{model_version_pipeline_run_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_pipeline_run_link( + model_version_id: UUID, + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_version_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. + """ + zen_store().delete_model_version_pipeline_run_link( + model_version_id=model_version_id, + model_version_pipeline_run_link_name_or_id=model_version_pipeline_run_link_name_or_id, + ) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 4f224793309..3688a4324a0 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -20,11 +20,9 @@ from zenml.constants import ( API, - ARTIFACTS, LATEST_MODEL_VERSION_PLACEHOLDER, MODEL_VERSIONS, MODELS, - RUNS, VERSION_1, ) from zenml.enums import ModelStages, PermissionType @@ -32,13 +30,8 @@ ModelFilterModel, ModelResponseModel, ModelUpdateModel, - ModelVersionArtifactFilterModel, - ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionPipelineRunFilterModel, - ModelVersionPipelineRunResponseModel, ModelVersionResponseModel, - ModelVersionUpdateModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -171,6 +164,8 @@ def list_model_versions( ) -> Page[ModelVersionResponseModel]: """Get model versions according to query filters. + This endpoint serves the purpose of allowing scoped filtering by model_id. + Args: model_name_or_id: The name or ID of the model to list in. model_version_filter_model: Filter model used for pagination, sorting, @@ -218,188 +213,3 @@ def get_model_version( if not is_number else int(model_version_name_or_number_or_id), ) - - -@router.put( - "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}", - response_model=ModelVersionResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def update_model_version( - model_version_id: UUID, - model_version_update_model: ModelVersionUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> ModelVersionResponseModel: - """Get all model versions by filter. - - Args: - model_version_id: The ID of model version to be updated. - model_version_update_model: The model version to be updated. - - Returns: - An updated model version. - """ - return zen_store().update_model_version( - model_version_id=model_version_id, - model_version_update_model=model_version_update_model, - ) - - -@router.delete( - "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Delete a model by name or ID. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version to delete. - """ - zen_store().delete_model_version( - model_name_or_id, model_version_name_or_id - ) - - -########################## -# Model Version Artifacts -########################## - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + ARTIFACTS, - response_model=Page[ModelVersionArtifactResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_model_version_artifact_links( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( - make_dependable(ModelVersionArtifactFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionArtifactResponseModel]: - """Get model version to artifact links according to query filters. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version containing links. - model_version_artifact_link_filter_model: Filter model used for pagination, sorting, - filtering - - Returns: - The model version to artifact links according to query filters. - """ - return zen_store().list_model_version_artifact_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, - model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, - ) - - -@router.delete( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + ARTIFACTS - + "/{model_version_artifact_link_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version_artifact_link( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_artifact_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a model version link. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. - """ - zen_store().delete_model_version_artifact_link( - model_name_or_id, - model_version_name_or_id, - model_version_artifact_link_name_or_id, - ) - - -############################## -# Model Version Pipeline Runs -############################## - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + RUNS, - response_model=Page[ModelVersionPipelineRunResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_model_version_pipeline_run_links( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( - make_dependable(ModelVersionPipelineRunFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionPipelineRunResponseModel]: - """Get model version to pipeline run links according to query filters. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, - and filtering - - Returns: - The model version to pipeline run links according to query filters. - """ - return zen_store().list_model_version_pipeline_run_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, - model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, - ) - - -@router.delete( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + RUNS - + "/{model_version_pipeline_run_link_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version_pipeline_run_link( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_pipeline_run_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a model version link. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. - """ - zen_store().delete_model_version_pipeline_run_link( - model_name_or_id, - model_version_name_or_id, - model_version_pipeline_run_link_name_or_id, - ) diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 17dafdedc41..ad069feaf3a 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -36,6 +36,7 @@ code_repositories_endpoints, devices_endpoints, flavors_endpoints, + model_versions_endpoints, models_endpoints, pipeline_builds_endpoints, pipeline_deployments_endpoints, @@ -229,6 +230,7 @@ def dashboard(request: Request) -> Any: app.include_router(pipeline_deployments_endpoints.router) app.include_router(code_repositories_endpoints.router) app.include_router(models_endpoints.router) +app.include_router(model_versions_endpoints.router) def get_root_static_files() -> List[str]: diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 1c74dc1e26a..fb71ac34171 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2458,8 +2458,8 @@ def get_model_version( def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -2471,11 +2471,18 @@ def list_model_versions( Returns: A page of all model versions. """ - return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", - response_model=ModelVersionResponseModel, - filter_model=model_version_filter_model, - ) + if model_name_or_id: + return self._list_paginated_resources( + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + filter_model=model_version_filter_model, + ) + else: + return self._list_paginated_resources( + route=f"{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + filter_model=model_version_filter_model, + ) def update_model_version( self, @@ -2522,15 +2529,13 @@ def create_model_version_artifact_link( def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -2538,27 +2543,25 @@ def list_model_version_artifact_links( A page of all model version to artifact links. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}", response_model=ModelVersionArtifactResponseModel, filter_model=model_version_artifact_link_filter_model, ) def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. """ self._delete_resource( resource_id=model_version_artifact_link_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}", ) ############################### @@ -2586,15 +2589,13 @@ def create_model_version_pipeline_run_link( def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -2602,27 +2603,25 @@ def list_model_version_pipeline_run_links( A page of all model version to pipeline run links. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{RUNS}", response_model=ModelVersionPipelineRunResponseModel, filter_model=model_version_pipeline_run_link_filter_model, ) def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. """ self._delete_resource( resource_id=model_version_pipeline_run_link_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{RUNS}", ) # ------------------ diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f4a504b0163..e6b26813244 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5802,8 +5802,8 @@ def get_model_version( def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -5816,7 +5816,9 @@ def list_model_versions( A page of all model versions. """ with Session(self.engine) as session: - model_version_filter_model.set_scope_model(model_name_or_id) + if model_name_or_id: + model_version_filter_model.set_scope_model(model_name_or_id) + query = select(ModelVersionSchema) return self.filter_and_paginate( session=session, @@ -6058,15 +6060,13 @@ def create_model_version_artifact_link( def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -6074,12 +6074,8 @@ def list_model_version_artifact_links( A page of all model version to artifact links. """ with Session(self.engine) as session: - # issue: https://github.com/tiangolo/sqlmodel/issues/109 - model_version_artifact_link_filter_model.set_scope_model( - model_name_or_id - ) model_version_artifact_link_filter_model.set_scope_model_version( - model_version_name_or_id + model_version_id ) if model_version_artifact_link_filter_model.only_artifacts: query = ( @@ -6137,25 +6133,20 @@ def list_model_version_artifact_links( def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: KeyError: specified ID or name not found. """ with Session(self.engine) as session: - self.get_model(model_name_or_id) - model_version = self.get_model_version( - model_name_or_id, model_version_name_or_id - ) + model_version = self.get_model_version(model_version_id) query = select(ModelVersionArtifactSchema).where( ModelVersionArtifactSchema.model_version_id == model_version.id ) @@ -6239,15 +6230,13 @@ def create_model_version_pipeline_run_link( def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -6255,11 +6244,8 @@ def list_model_version_pipeline_run_links( A page of all model version to pipeline run links. """ with Session(self.engine) as session: - model_version_pipeline_run_link_filter_model.set_scope_model( - model_name_or_id - ) model_version_pipeline_run_link_filter_model.set_scope_model_version( - model_version_name_or_id + model_version_id ) return self.filter_and_paginate( session=session, @@ -6270,25 +6256,20 @@ def list_model_version_pipeline_run_links( def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. Raises: KeyError: specified ID not found. """ with Session(self.engine) as session: - self.get_model(model_name_or_id) - model_version = self.get_model_version( - model_name_or_id, model_version_name_or_id - ) + model_version = self.get_model_version(model_version_id) query = select(ModelVersionPipelineRunSchema).where( ModelVersionPipelineRunSchema.model_version_id == model_version.id diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index bc0512a9340..14ff7097f83 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1831,8 +1831,8 @@ def get_model_version( @abstractmethod def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -1888,15 +1888,13 @@ def create_model_version_artifact_link( @abstractmethod def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -1907,15 +1905,13 @@ def list_model_version_artifact_links( @abstractmethod def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: @@ -1944,15 +1940,13 @@ def create_model_version_pipeline_run_link( @abstractmethod def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -1963,15 +1957,13 @@ def list_model_version_pipeline_run_links( @abstractmethod def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. Raises: From d4d1ee28c2dbe89641ca6edd8d13d9c120c58141 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 15:30:50 +0100 Subject: [PATCH 02/28] Fix the delete endpoint --- .../routers/model_versions_endpoints.py | 19 +++++++--------- src/zenml/zen_stores/rest_zen_store.py | 10 ++++----- src/zenml/zen_stores/sql_zen_store.py | 22 +++++-------------- src/zenml/zen_stores/zen_store_interface.py | 6 ++--- 4 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index 5e19be30bf8..ace57cc3150 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -99,7 +99,6 @@ def get_model_version( Args: model_version_id: id of the model version to be retrieved. - is_number: If the model_version_name_or_number_or_id is a version number Returns: The model version with the given name or ID. @@ -136,23 +135,21 @@ def update_model_version( @router.delete( - "{model_version_name_or_id}", + "{model_version_id}", responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions def delete_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), ) -> None: """Delete a model by name or ID. Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version to delete. + model_version_id: The name or ID of the model version to delete. """ zen_store().delete_model_version( - model_name_or_id, model_version_name_or_id + model_version_id ) @@ -162,7 +159,7 @@ def delete_model_version( @router.get( - "{model_version_id}" + ARTIFACTS, + "/{model_version_id}" + ARTIFACTS, response_model=Page[ModelVersionArtifactResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, ) @@ -191,7 +188,7 @@ def list_model_version_artifact_links( @router.delete( - "{model_version_id}" + "/{model_version_id}" + ARTIFACTS + "{model_version_artifact_link_name_or_id}", responses={401: error_response, 404: error_response, 422: error_response}, @@ -220,7 +217,7 @@ def delete_model_version_artifact_link( @router.get( - "{model_version_name_or_id}" + RUNS, + "/{model_version_name_or_id}" + RUNS, response_model=Page[ModelVersionPipelineRunResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, ) @@ -249,7 +246,7 @@ def list_model_version_pipeline_run_links( @router.delete( - "{model_version_id}" + "/{model_version_id}" + RUNS + "/{model_version_pipeline_run_link_name_or_id}", responses={401: error_response, 404: error_response, 422: error_response}, diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index fb71ac34171..f9d85223411 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2413,18 +2413,16 @@ def create_model_version( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: name or id of the model version to be deleted. """ self._delete_resource( - resource_id=model_version_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + resource_id=model_version_id, + route=f"{MODEL_VERSIONS}", ) def get_model_version( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index e6b26813244..1fc86051515 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5829,37 +5829,25 @@ def list_model_versions( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: name or id of the model version to be deleted. Raises: KeyError: specified ID or name not found. """ with Session(self.engine) as session: - model = self.get_model(model_name_or_id) query = select(ModelVersionSchema).where( - ModelVersionSchema.model_id == model.id - ) - try: - UUID(str(model_version_name_or_id)) - query = query.where( - ModelVersionSchema.id == model_version_name_or_id - ) - except ValueError: - query = query.where( - ModelVersionSchema.name == model_version_name_or_id + ModelVersionSchema.id == model_version_id ) model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to delete model version with name `{model_version_name_or_id}`: " - f"No model version with this name found." + f"Unable to delete model version with id `{model_version_id}`: " + f"No model version with this id found." ) session.delete(model_version) session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 14ff7097f83..e46ebea4707 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1793,14 +1793,12 @@ def create_model_version( @abstractmethod def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: id of the model version to be deleted. Raises: KeyError: specified ID or name not found. From 1e081739a20830e767985bd5b90c1638420936c1 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 15:35:06 +0100 Subject: [PATCH 03/28] Fixed smaller mistakes --- .../zen_server/routers/model_versions_endpoints.py | 12 +++++------- src/zenml/zen_stores/sql_zen_store.py | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index ace57cc3150..c2381a54201 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -148,9 +148,7 @@ def delete_model_version( Args: model_version_id: The name or ID of the model version to delete. """ - zen_store().delete_model_version( - model_version_id - ) + zen_store().delete_model_version(model_version_id) ########################## @@ -165,7 +163,7 @@ def delete_model_version( ) @handle_exceptions def list_model_version_artifact_links( - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), @@ -174,7 +172,7 @@ def list_model_version_artifact_links( """Get model version to artifact links according to query filters. Args: - model_version_name_or_id: The name or ID of the model version containing links. + model_version_id: ID of the model version containing links. model_version_artifact_link_filter_model: Filter model used for pagination, sorting, filtering @@ -182,7 +180,7 @@ def list_model_version_artifact_links( The model version to artifact links according to query filters. """ return zen_store().list_model_version_artifact_links( - model_version_id=model_version_name_or_id, + model_version_id=model_version_id, model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -217,7 +215,7 @@ def delete_model_version_artifact_link( @router.get( - "/{model_version_name_or_id}" + RUNS, + "/{model_version_id}" + RUNS, response_model=Page[ModelVersionPipelineRunResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1fc86051515..5bd127af013 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5841,8 +5841,8 @@ def delete_model_version( """ with Session(self.engine) as session: query = select(ModelVersionSchema).where( - ModelVersionSchema.id == model_version_id - ) + ModelVersionSchema.id == model_version_id + ) model_version = session.exec(query).first() if model_version is None: raise KeyError( From f00593aec6848c098049254bead6fedecb8bf787 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 15:51:56 +0100 Subject: [PATCH 04/28] Fixed for Client as well --- src/zenml/cli/model.py | 22 ++++++++++----- src/zenml/client.py | 43 +++++++----------------------- src/zenml/model/artifact_config.py | 3 +-- 3 files changed, 27 insertions(+), 41 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 5f72018e43d..9d6daf33939 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -479,9 +479,12 @@ def delete_model_version( return try: - Client().delete_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ) + Client().delete_model_version( + model_version_id=model_version.id, ) except (KeyError, ValueError) as e: cli_utils.error(str(e)) @@ -535,9 +538,13 @@ def _print_artifacts_links_generic( f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" ) + model_version = Client().get_model_version( + model_name_or_id=model_name_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ) + links = Client().list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_artifacts=only_artifact_objects, only_deployments=only_deployments, @@ -676,10 +683,13 @@ def list_model_version_pipeline_runs( cli_utils.title( f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:" ) + model_version = Client().get_model_version( + model_name_or_id=model_name_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ) links = Client().list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel( **kwargs, ), diff --git a/src/zenml/client.py b/src/zenml/client.py index 37ba6ef2d97..7128cf4670d 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5060,18 +5060,15 @@ def create_model_version( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version from Model Control Plane. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: Id of the model version to be deleted. """ self.zen_store.delete_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, ) def get_model_version( @@ -5098,8 +5095,8 @@ def get_model_version( def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get model versions by filter from Model Control Plane. @@ -5143,31 +5140,21 @@ def update_model_version( def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_id: UUID, ) -> Page[ModelVersionArtifactResponseModel]: """Get model version to artifact links by filter in Model Control Plane. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. + model_version_id: id of the model version to be retrieved. model_version_artifact_link_filter_model: All filter parameters including pagination params. Returns: A page of all model version to artifact links. """ - mv = self.zen_store.get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, - ) return self.zen_store.list_model_version_artifact_links( - model_name_or_id=mv.model.id, - model_version_name_or_id=mv.id, + model_version_id=model_version_id, model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -5179,31 +5166,21 @@ def list_model_version_artifact_links( def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_id: UUID, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. + model_version_id: id of the model version to be retrieved. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. Returns: A page of all model version to pipeline run links. """ - mv = self.zen_store.get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, - ) return self.zen_store.list_model_version_pipeline_run_links( - model_name_or_id=mv.model.id, - model_version_name_or_id=mv.id, + model_version_id=model_version_id, model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index e9ad32b2f45..a1aa4ad4fe5 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -144,8 +144,7 @@ def _link_to_model_version( # Create the model version artifact link using the ZenML client existing_links = client.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=client.active_user.id, workspace_id=client.active_workspace.id, From c75ae0e6304f956c54b3d3eb1181fed02412fa8d Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 16:21:33 +0100 Subject: [PATCH 05/28] Fixed lint errors --- src/zenml/model/artifact_config.py | 4 ++-- src/zenml/new/pipelines/pipeline.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index a1aa4ad4fe5..8e31f0882df 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -162,9 +162,9 @@ def _link_to_model_version( logger.warning( f"Existing artifact link(s) `{artifact_name}` found and will be deleted." ) + client.zen_store.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id=artifact_name, ) else: diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 0a24bd341b7..e25d4b24cce 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -952,13 +952,12 @@ def delete_running_versions_without_recovery( new_version_request.model_config.delete_new_version_on_failure and new_version_request.model_config.version is not None ): - model = Client().get_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=new_version_request.model_config.version, ) Client().delete_model_version( - model_name_or_id=model_name, - model_version_name_or_id=model.id, + model_version_id=model_version.id ) def get_runs(self, **kwargs: Any) -> List[PipelineRunResponseModel]: From 2d770c46c51703c1d68166e6dc35b18ffcf6140f Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 16:32:55 +0100 Subject: [PATCH 06/28] More linting --- .../routers/workspaces_endpoints.py | 36 +++++++++---------- src/zenml/zen_stores/rest_zen_store.py | 2 +- src/zenml/zen_stores/zen_store_interface.py | 2 +- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index f3e834d58d7..fbd5d18bad6 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1353,7 +1353,7 @@ def create_model_version_artifact_link( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + ARTIFACTS, response_model=Page[ModelVersionArtifactResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, @@ -1361,8 +1361,7 @@ def create_model_version_artifact_link( @handle_exceptions def list_workspace_model_version_artifact_links( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), @@ -1371,9 +1370,8 @@ def list_workspace_model_version_artifact_links( """Get model version to artifact links according to query filters. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: Name or ID of the model version. model_version_artifact_link_filter_model: Filter model used for pagination, sorting, filtering @@ -1383,8 +1381,7 @@ def list_workspace_model_version_artifact_links( workspace_id = zen_store().get_workspace(workspace_name_or_id).id model_version_artifact_link_filter_model.set_scope_workspace(workspace_id) return zen_store().list_model_version_artifact_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -1392,8 +1389,6 @@ def list_workspace_model_version_artifact_links( @router.post( WORKSPACES + "/{workspace_name_or_id}" - + MODELS - + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" + RUNS, @@ -1403,8 +1398,7 @@ def list_workspace_model_version_artifact_links( @handle_exceptions def create_model_version_pipeline_run_link( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, auth_context: AuthContext = Security( authorize, scopes=[PermissionType.WRITE] @@ -1413,9 +1407,8 @@ def create_model_version_pipeline_run_link( """Create a new model version to pipeline run link. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: ID of the model version. model_version_pipeline_run_link: The model version to pipeline run link to create. auth_context: Authentication context. @@ -1429,6 +1422,12 @@ def create_model_version_pipeline_run_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) + if model_version__id != model_version_pipeline_run_link.model_version: + raise IllegalOperationError( + f"The model version id in your path `{model_version__id}` does not " + f"match the model version specified in the request model " + f"`{model_version_pipeline_run_link.model_version}`" + ) if model_version_pipeline_run_link.workspace != workspace.id: raise IllegalOperationError( @@ -1451,7 +1450,7 @@ def create_model_version_pipeline_run_link( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + RUNS, response_model=Page[ModelVersionPipelineRunResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, @@ -1459,8 +1458,7 @@ def create_model_version_pipeline_run_link( @handle_exceptions def list_workspace_model_version_pipeline_run_links( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( make_dependable(ModelVersionPipelineRunFilterModel) ), @@ -1469,9 +1467,8 @@ def list_workspace_model_version_pipeline_run_links( """Get model version to pipeline links according to query filters. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: ID of the model version. model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, filtering @@ -1483,7 +1480,6 @@ def list_workspace_model_version_pipeline_run_links( workspace_id ) return zen_store().list_model_version_pipeline_run_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index f9d85223411..7d047137a8d 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2413,7 +2413,7 @@ def create_model_version( def delete_model_version( self, - model_version_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index e46ebea4707..424e88af901 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1793,7 +1793,7 @@ def create_model_version( @abstractmethod def delete_model_version( self, - model_version_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. From 641f408afbdb4b00dcc4022d72d05dcedb3e6a78 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 16:33:11 +0100 Subject: [PATCH 07/28] More linting --- src/zenml/zen_server/routers/workspaces_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index fbd5d18bad6..d850f3f4c11 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1422,9 +1422,9 @@ def create_model_version_pipeline_run_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) - if model_version__id != model_version_pipeline_run_link.model_version: + if model_version_id != model_version_pipeline_run_link.model_version: raise IllegalOperationError( - f"The model version id in your path `{model_version__id}` does not " + f"The model version id in your path `{model_version_id}` does not " f"match the model version specified in the request model " f"`{model_version_pipeline_run_link.model_version}`" ) From 0e78b3e6af0e45e36b44c774ae09eb85ad3e6394 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 16:36:12 +0100 Subject: [PATCH 08/28] More tiny fixes --- .../zen_server/routers/workspaces_endpoints.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index d850f3f4c11..1d610667323 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1295,8 +1295,6 @@ def create_model_version( @router.post( WORKSPACES + "/{workspace_name_or_id}" - + MODELS - + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" + ARTIFACTS, @@ -1306,8 +1304,7 @@ def create_model_version( @handle_exceptions def create_model_version_artifact_link( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], model_version_artifact_link: ModelVersionArtifactRequestModel, auth_context: AuthContext = Security( authorize, scopes=[PermissionType.WRITE] @@ -1316,9 +1313,8 @@ def create_model_version_artifact_link( """Create a new model version to artifact link. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: Name or ID of the model version. model_version_artifact_link: The model version to artifact link to create. auth_context: Authentication context. @@ -1331,6 +1327,12 @@ def create_model_version_artifact_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) + if model_version_id != model_version_artifact_link.model_version: + raise IllegalOperationError( + f"The model version id in your path `{model_version_id}` does not " + f"match the model version specified in the request model " + f"`{model_version_artifact_link.model_version}`" + ) if model_version_artifact_link.workspace != workspace.id: raise IllegalOperationError( From b407f7191843461153ab8557b67c89e5c97cc85f Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 16:42:32 +0100 Subject: [PATCH 09/28] Update src/zenml/zen_server/routers/model_versions_endpoints.py Co-authored-by: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> --- src/zenml/zen_server/routers/model_versions_endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index c2381a54201..abfdc352247 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -188,7 +188,7 @@ def list_model_version_artifact_links( @router.delete( "/{model_version_id}" + ARTIFACTS - + "{model_version_artifact_link_name_or_id}", + + "/{model_version_artifact_link_name_or_id}", responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions From 51c11a996ce3ae8c3c5ceddb1b4403f171142e7f Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 9 Nov 2023 17:18:50 +0100 Subject: [PATCH 10/28] More linting --- src/zenml/zen_server/routers/model_versions_endpoints.py | 4 ++-- src/zenml/zen_server/routers/workspaces_endpoints.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index c2381a54201..341c3eb3bee 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -140,7 +140,7 @@ def update_model_version( ) @handle_exceptions def delete_model_version( - model_version_id: Union[str, UUID], + model_version_id: UUID, _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), ) -> None: """Delete a model by name or ID. @@ -163,7 +163,7 @@ def delete_model_version( ) @handle_exceptions def list_model_version_artifact_links( - model_version_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 1d610667323..97511306ce9 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1363,7 +1363,7 @@ def create_model_version_artifact_link( @handle_exceptions def list_workspace_model_version_artifact_links( workspace_name_or_id: Union[str, UUID], - model_version_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), @@ -1460,7 +1460,7 @@ def create_model_version_pipeline_run_link( @handle_exceptions def list_workspace_model_version_pipeline_run_links( workspace_name_or_id: Union[str, UUID], - model_version_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( make_dependable(ModelVersionPipelineRunFilterModel) ), From 7dd030f1834b49f909ece78ad252fa77235c2b91 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Mon, 13 Nov 2023 15:13:01 +0100 Subject: [PATCH 11/28] Fixed tests and solved conflicts --- .../functional/model/test_artifact_config.py | 44 +++++++------------ 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index e4389fc6186..c9d778980b2 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -86,8 +86,7 @@ def test_link_minimalistic(): mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" links = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -147,8 +146,7 @@ def test_link_multiple_named_outputs(): mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" al = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -193,8 +191,7 @@ def test_link_multiple_named_outputs_without_links(): mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" artifact_links = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -270,16 +267,14 @@ def test_link_multiple_named_outputs_with_self_context_and_caching(): multi_named_pipeline_from_self(run_count == 2) al1 = client.list_model_version_artifact_links( - model_name_or_id=mv1.model.id, - model_version_name_or_number_or_id=mv1.id, + model_version_id=mv1.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, ), ) al2 = client.list_model_version_artifact_links( - model_name_or_id=mv2.model.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -298,8 +293,7 @@ def test_link_multiple_named_outputs_with_self_context_and_caching(): for mv, al in zip([mv1, mv2], [al1, al2]): for al_ in al: client.zen_store.delete_model_version_artifact_link( - model_name_or_id=mv.model.id, - model_version_name_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_name_or_id=al_.id, ) @@ -391,8 +385,7 @@ def test_link_multiple_named_outputs_with_mixed_linkage(): for mv in mvs: artifact_links.append( client.list_model_version_artifact_links( - model_name_or_id=mv.model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -459,8 +452,7 @@ def test_link_no_versioning(): simple_pipeline_no_versioning() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -473,8 +465,7 @@ def test_link_no_versioning(): simple_pipeline_no_versioning() al2 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -531,8 +522,7 @@ def test_link_with_versioning(): simple_pipeline_with_versioning() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -545,8 +535,7 @@ def test_link_with_versioning(): simple_pipeline_with_versioning() al2 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -651,8 +640,7 @@ def test_link_with_manual_linkage(pipeline: Callable): pipeline() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -663,8 +651,7 @@ def test_link_with_manual_linkage(pipeline: Callable): assert al1[0].name == "1" al2 = client.list_model_version_artifact_links( - model_name_or_id=model2.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -761,8 +748,7 @@ def test_link_with_manual_linkage_flexible_config( simple_pipeline_with_manual_linkage_flexible_config(artifact_config) links = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -864,7 +850,7 @@ def _inner_pipeline(force_disable_cache: bool = False): def test_artifacts_linked_from_cache_steps_same_id(): """Test that artifacts are linked from cache steps with same id. - This case appears if cached step is executed inside same model version + This case appears if cached step is executed inside same model version, and we need to silently pass linkage without failing on same id. """ From 0d1cc73baaf78c29edff6dce9156145940fbab37 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Mon, 13 Nov 2023 15:47:06 +0100 Subject: [PATCH 12/28] Fixed linting --- src/zenml/new/pipelines/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index b5d24aa0f66..b861883cd08 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -945,13 +945,13 @@ def delete_running_versions_without_recovery( new_version_request.model_config.delete_new_version_on_failure and new_version_request.model_config.was_created_in_this_run ): - model_version = Client().get_model_version( + model_version_model = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version or constants.RUNNING_MODEL_VERSION, ) Client().delete_model_version( - model_version_id=model_version.id + model_version_id=model_version_model.id ) def get_runs(self, **kwargs: Any) -> List[PipelineRunResponseModel]: From c21dfbb6ec11198c0e32e9ae2add4a9e056c566a Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Mon, 13 Nov 2023 17:53:00 +0100 Subject: [PATCH 13/28] Fixed more tests --- .../functional/zen_stores/test_zen_store.py | 61 +++++++------------ 1 file changed, 21 insertions(+), 40 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 67650402f6c..756da8b82bf 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -4090,19 +4090,18 @@ def test_list_not_empty(self): def test_delete_not_found(self): """Test that delete fails if not found.""" - with ModelVersionContext() as model: + with ModelVersionContext(): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="1.0.0", + model_version_id=uuid4(), ) def test_delete_found(self): """Test that delete works, if model version exists.""" with ModelVersionContext() as model: zs = Client().zen_store - zs.create_model_version( + mv = zs.create_model_version( ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, @@ -4111,8 +4110,7 @@ def test_delete_found(self): ) ) zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="great one", + model_version_id=mv.id, ) with pytest.raises(KeyError): zs.get_model_version( @@ -4535,8 +4533,7 @@ def test_link_create_overwrite_deleted(self): assert al1.link_version == 1 assert al1.artifact == artifacts[0].id zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id=al1.id, ) al2 = zs.create_model_version_artifact_link( @@ -4625,8 +4622,7 @@ def test_link_create_single_version_of_same_output_name_from_different_steps( ) links = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( pipeline_name="pipeline", name="output", @@ -4654,13 +4650,11 @@ def test_link_delete_found(self): ) ) zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id="link", ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4670,8 +4664,7 @@ def test_link_delete_not_found(self): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id="link", ) @@ -4679,8 +4672,7 @@ def test_link_list_empty(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4692,8 +4684,7 @@ def test_link_list_populated(self): ): zs = Client().zen_store mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4718,15 +4709,13 @@ def test_link_list_populated(self): ) ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == len(artifacts) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_artifacts=True ), @@ -4738,8 +4727,7 @@ def test_link_list_populated(self): ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_model_objects=True ), @@ -4747,8 +4735,7 @@ def test_link_list_populated(self): assert len(mvls) == 1 and mvls[0].name == "link2" mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_deployments=True ), @@ -4882,12 +4869,9 @@ def test_link_delete_found(self): pipeline_run=prs[0].id, ) ) - zs.delete_model_version_pipeline_run_link( - model_version.model.id, model_version.id, "link" - ) + zs.delete_model_version_pipeline_run_link(model_version.id, "link") mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4897,15 +4881,14 @@ def test_link_delete_not_found(self): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version_pipeline_run_link( - model_version.model.id, model_version.id, "link" + model_version.id, "link" ) def test_link_list_empty(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4917,8 +4900,7 @@ def test_link_list_populated(self): ): zs = Client().zen_store mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4934,8 +4916,7 @@ def test_link_list_populated(self): ) ) mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 2 From 9695b60e109170551b6ae41a7762dac6b1082d0a Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Tue, 14 Nov 2023 15:39:40 +0100 Subject: [PATCH 14/28] Further refactoring --- src/zenml/client.py | 57 +++++++++++++++-- src/zenml/models/model_models.py | 2 +- .../routers/model_versions_endpoints.py | 6 +- .../zen_server/routers/models_endpoints.py | 38 +---------- src/zenml/zen_stores/rest_zen_store.py | 26 ++------ src/zenml/zen_stores/sql_zen_store.py | 64 +++++-------------- src/zenml/zen_stores/zen_store_interface.py | 13 ++-- 7 files changed, 83 insertions(+), 123 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index e6279d4b197..fc24abe9ff9 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5501,7 +5501,7 @@ def get_model_version( self, model_name_or_id: Union[str, UUID], model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] + Union[str, int, ModelStages, UUID] ] = None, ) -> ModelVersionResponseModel: """Get an existing model version from Model Control Plane. @@ -5514,11 +5514,56 @@ def get_model_version( Returns: The model version of interest. """ - return self.zen_store.get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id - or ModelStages.LATEST, - ) + if model_version_name_or_number_or_id is None: + model_version_name_or_number_or_id = ModelStages.LATEST + + if isinstance(model_version_name_or_number_or_id, UUID): + return self.zen_store.get_model_version( + model_version_id=model_version_name_or_number_or_id + ) + elif isinstance(model_version_name_or_number_or_id, int): + model_versions = self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + number=model_version_name_or_number_or_id + ), + ) + elif isinstance(model_version_name_or_number_or_id, str): + if model_version_name_or_number_or_id in ModelStages.values(): + model_versions = self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + stage=model_version_name_or_number_or_id + ), + ) + else: + model_versions = self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + name=model_version_name_or_number_or_id + ), + ) + else: + raise RuntimeError( + f"The model version identifier " + f"`{model_version_name_or_number_or_id}` is not" + f"of the correct type." + ) + + if len(model_versions) == 1: + return model_versions[0] + elif len(model_versions) == 0: + raise KeyError( + f"No model version found for model " + f"`{model_name_or_id}` with version identifier " + f"`{model_version_name_or_number_or_id}`." + ) + else: + raise RuntimeError( + f"The model version identifier " + f"`{model_version_name_or_number_or_id}` is not" + f"unique for model `{model_name_or_id}`." + ) def list_model_versions( self, diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 8edd9599b2c..a946ff5ad34 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -452,7 +452,7 @@ def _update_default_running_version_name(self) -> None: class ModelVersionFilterModel(ModelScopedFilterModel): """Filter Model for Model Version.""" - name: Optional[Union[str, UUID]] = Field( + name: Optional[str] = Field( default=None, description="The name of the Model Version", ) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index c72b1980cdd..65332c614ea 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -104,12 +104,12 @@ def get_model_version( The model version with the given name or ID. """ return zen_store().get_model_version( - model_name_or_id=model_version_id, + model_version_id=model_version_id, ) @router.put( - "{model_version_id}", + "/{model_version_id}", response_model=ModelVersionResponseModel, responses={401: error_response, 404: error_response, 422: error_response}, ) @@ -135,7 +135,7 @@ def update_model_version( @router.delete( - "{model_version_id}", + "/{model_version_id}", responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 3688a4324a0..dae908ccdae 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -20,12 +20,11 @@ from zenml.constants import ( API, - LATEST_MODEL_VERSION_PLACEHOLDER, MODEL_VERSIONS, MODELS, VERSION_1, ) -from zenml.enums import ModelStages, PermissionType +from zenml.enums import PermissionType from zenml.models import ( ModelFilterModel, ModelResponseModel, @@ -178,38 +177,3 @@ def list_model_versions( model_name_or_id=model_name_or_id, model_version_filter_model=model_version_filter_model, ) - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_number_or_id}", - response_model=ModelVersionResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[ - str, int, UUID, ModelStages - ] = LATEST_MODEL_VERSION_PLACEHOLDER, - is_number: bool = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> ModelVersionResponseModel: - """Get a model version by name or ID. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. - is_number: If the model_version_name_or_number_or_id is a version number - - Returns: - The model version with the given name or ID. - """ - return zen_store().get_model_version( - model_name_or_id, - model_version_name_or_number_or_id - if not is_number - else int(model_version_name_or_number_or_id), - ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 9515ea1b766..28c535b27c9 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -81,7 +81,6 @@ WORKSPACES, ) from zenml.enums import ( - ModelStages, OAuthGrantTypes, SecretsStoreType, StoreType, @@ -2689,32 +2688,21 @@ def delete_model_version( ) def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. Returns: The model version of interest. """ return self._get_resource( - resource_id=model_version_name_or_number_or_id - or ModelStages.LATEST, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + resource_id=model_version_id, + route=MODEL_VERSIONS, response_model=ModelVersionResponseModel, - params={ - "is_number": isinstance( - model_version_name_or_number_or_id, int - ) - }, ) def list_model_versions( @@ -2785,7 +2773,7 @@ def create_model_version_artifact_link( return self._create_workspace_scoped_resource( resource=model_version_artifact_link, response_model=ModelVersionArtifactResponseModel, - route=f"{MODELS}/{model_version_artifact_link.model}{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", ) def list_model_version_artifact_links( @@ -2845,7 +2833,7 @@ def create_model_version_pipeline_run_link( return self._create_workspace_scoped_resource( resource=model_version_pipeline_run_link, response_model=ModelVersionPipelineRunResponseModel, - route=f"{MODELS}/{model_version_pipeline_run_link.model}{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", ) def list_model_version_pipeline_run_links( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 66ca649a9a8..53c3511090e 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -6539,18 +6539,13 @@ def create_model_version( return mv def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. Returns: The model version of interest. @@ -6559,47 +6554,17 @@ def get_model_version( KeyError: specified ID or name not found. """ with Session(self.engine) as session: - model = self.get_model(model_name_or_id) - query = select(ModelVersionSchema).where( - ModelVersionSchema.model_id == model.id + model_version = self._get_schema_by_name_or_id( + object_name_or_id=model_version_id, + schema_class=ModelVersionSchema, + schema_name="model_version", + session=session, ) - if model_version_name_or_number_or_id is None: - model_version_name_or_number_or_id = ModelStages.LATEST - if ( - str(model_version_name_or_number_or_id) - == ModelStages.LATEST.value - ): - query = query.order_by(ModelVersionSchema.created.desc()) # type: ignore[attr-defined] - elif model_version_name_or_number_or_id in [ - stage.value for stage in ModelStages - ]: - query = query.where( - ModelVersionSchema.stage - == model_version_name_or_number_or_id - ) - elif isinstance(model_version_name_or_number_or_id, int): - query = query.where( - ModelVersionSchema.number - == model_version_name_or_number_or_id - ) - - else: - try: - UUID(str(model_version_name_or_number_or_id)) - query = query.where( - ModelVersionSchema.id - == model_version_name_or_number_or_id - ) - except ValueError: - query = query.where( - ModelVersionSchema.name - == model_version_name_or_number_or_id - ) - model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to get model version with identifier `{model_version_name_or_number_or_id}`: " - f"No model version with this identifier found." + f"Unable to get model version with ID " + f"`{model_version_id}`: No model version with this " + f"ID found." ) return ModelVersionSchema.to_model(model_version) @@ -6620,7 +6585,8 @@ def list_model_versions( """ with Session(self.engine) as session: if model_name_or_id: - model_version_filter_model.set_scope_model(model_name_or_id) + model = self.get_model(model_name_or_id) + model_version_filter_model.set_scope_model(model.id) query = select(ModelVersionSchema) return self.filter_and_paginate( @@ -7060,7 +7026,9 @@ def delete_model_version_pipeline_run_link( KeyError: specified ID not found. """ with Session(self.engine) as session: - model_version = self.get_model_version(model_version_id) + model_version = self.get_model_version( + model_version_id=model_version_id + ) query = select(ModelVersionPipelineRunSchema).where( ModelVersionPipelineRunSchema.model_version_id == model_version.id diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 1cfa343c75c..e0946e1954c 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -16,7 +16,6 @@ from typing import List, Optional, Tuple, Union from uuid import UUID -from zenml.enums import ModelStages from zenml.models import ( APIKeyFilterModel, APIKeyRequestModel, @@ -2034,18 +2033,14 @@ def delete_model_version( @abstractmethod def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. + Returns: The model version of interest. From 1beb80f29dd53a62fc7f58e8e356325a0277e891 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Tue, 14 Nov 2023 17:17:48 +0100 Subject: [PATCH 15/28] Added raises section --- src/zenml/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/zenml/client.py b/src/zenml/client.py index fc24abe9ff9..123d8537664 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5513,6 +5513,10 @@ def get_model_version( Returns: The model version of interest. + + Raises: + RuntimeError: In case method inputs don't adhere to restrictions. + KeyError: In case no model version with the identifiers exists. """ if model_version_name_or_number_or_id is None: model_version_name_or_number_or_id = ModelStages.LATEST From d114aeac06ee045632b597af9d97cb8816dd7209 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Tue, 14 Nov 2023 17:50:03 +0100 Subject: [PATCH 16/28] Fix one failing test --- tests/integration/functional/model/test_model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 9ee4c171579..211c6cce017 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -155,7 +155,7 @@ def test_model_fetch_model_and_version_latest(self): mc = ModelConfig(name=MODEL_NAME, version=ModelStages.LATEST) mv = mc.get_or_create_model_version() - assert mv.name == "1.0.0" + assert mv.name == "latest" def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" From b9651383a6a3730d12972044e080af7223691212 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Tue, 14 Nov 2023 18:42:29 +0100 Subject: [PATCH 17/28] Take "latest" stage into account --- src/zenml/client.py | 11 +++++++++-- .../integration/functional/model/test_model_config.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 123d8537664..ecaa575b86a 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -58,7 +58,7 @@ PermissionType, SecretScope, StackComponentType, - StoreType, + StoreType, SorterOps, ) from zenml.exceptions import ( AuthorizationException, @@ -5533,7 +5533,14 @@ def get_model_version( ), ) elif isinstance(model_version_name_or_number_or_id, str): - if model_version_name_or_number_or_id in ModelStages.values(): + if model_version_name_or_number_or_id == ModelStages.LATEST: + model_versions = [self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + sort_by=f"{SorterOps.DESCENDING}:created", + ), + ).items[0]] + elif model_version_name_or_number_or_id in ModelStages.values(): model_versions = self.list_model_versions( model_name_or_id=model_name_or_id, model_version_filter_model=ModelVersionFilterModel( diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 211c6cce017..9ee4c171579 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -155,7 +155,7 @@ def test_model_fetch_model_and_version_latest(self): mc = ModelConfig(name=MODEL_NAME, version=ModelStages.LATEST) mv = mc.get_or_create_model_version() - assert mv.name == "latest" + assert mv.name == "1.0.0" def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" From 5a8cc70cbfefbcc987f3b4f0d81acacbfa597e8c Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Tue, 14 Nov 2023 22:08:45 +0100 Subject: [PATCH 18/28] Reformatted --- src/zenml/client.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index ecaa575b86a..dd04915092b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -57,8 +57,9 @@ OAuthDeviceStatus, PermissionType, SecretScope, + SorterOps, StackComponentType, - StoreType, SorterOps, + StoreType, ) from zenml.exceptions import ( AuthorizationException, @@ -5534,12 +5535,14 @@ def get_model_version( ) elif isinstance(model_version_name_or_number_or_id, str): if model_version_name_or_number_or_id == ModelStages.LATEST: - model_versions = [self.list_model_versions( - model_name_or_id=model_name_or_id, - model_version_filter_model=ModelVersionFilterModel( - sort_by=f"{SorterOps.DESCENDING}:created", - ), - ).items[0]] + model_versions = [ + self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + sort_by=f"{SorterOps.DESCENDING}:created", + ), + ).items[0] + ] elif model_version_name_or_number_or_id in ModelStages.values(): model_versions = self.list_model_versions( model_name_or_id=model_name_or_id, From 9a86fe6331568fdb0e9ed6e45c4f3805c9ed70f8 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 01:27:08 +0100 Subject: [PATCH 19/28] Standardize use of list response --- src/zenml/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index dd04915092b..487eb82a78b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5532,7 +5532,7 @@ def get_model_version( model_version_filter_model=ModelVersionFilterModel( number=model_version_name_or_number_or_id ), - ) + ).items elif isinstance(model_version_name_or_number_or_id, str): if model_version_name_or_number_or_id == ModelStages.LATEST: model_versions = [ @@ -5549,14 +5549,14 @@ def get_model_version( model_version_filter_model=ModelVersionFilterModel( stage=model_version_name_or_number_or_id ), - ) + ).items else: model_versions = self.list_model_versions( model_name_or_id=model_name_or_id, model_version_filter_model=ModelVersionFilterModel( name=model_version_name_or_number_or_id ), - ) + ).items else: raise RuntimeError( f"The model version identifier " From 79eb3a2d62bbde6df9e12821630ce92e4163b8e3 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 10:05:28 +0100 Subject: [PATCH 20/28] Rewrote some tests --- src/zenml/client.py | 19 +-- tests/integration/functional/test_client.py | 45 ++++++- .../functional/zen_stores/test_zen_store.py | 122 ++++++++---------- 3 files changed, 107 insertions(+), 79 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 487eb82a78b..472748402de 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5535,14 +5535,17 @@ def get_model_version( ).items elif isinstance(model_version_name_or_number_or_id, str): if model_version_name_or_number_or_id == ModelStages.LATEST: - model_versions = [ - self.list_model_versions( - model_name_or_id=model_name_or_id, - model_version_filter_model=ModelVersionFilterModel( - sort_by=f"{SorterOps.DESCENDING}:created", - ), - ).items[0] - ] + model_versions_page = self.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + sort_by=f"{SorterOps.DESCENDING}:number", + ), + ) + + if model_versions_page.size > 0: + model_versions = model_versions_page.items[0] + else: + model_versions = [] elif model_version_name_or_number_or_id in ModelStages.values(): model_versions = self.list_model_versions( model_name_or_id=model_name_or_id, diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 31f151743c0..377d7986327 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -30,7 +30,7 @@ from zenml.client import Client from zenml.config.pipeline_spec import PipelineSpec from zenml.config.source import Source -from zenml.enums import SecretScope, StackComponentType +from zenml.enums import ModelStages, SecretScope, StackComponentType from zenml.exceptions import ( EntityExistsError, IllegalOperationError, @@ -42,6 +42,7 @@ from zenml.metadata.metadata_types import MetadataTypeEnum from zenml.models import ( ComponentResponseModel, + ModelRequestModel, PipelineBuildRequestModel, PipelineDeploymentRequestModel, PipelineRequestModel, @@ -1168,3 +1169,45 @@ def test_basic_crud_for_entity( # This means the test already succeeded and deleted the entity, # nothing to do here pass + + +def test_latest_not_found(clean_client): + """Test that get latest fails if not found.""" + + cl = Client() + model_name = "super_model" + + cl.create_model( + model=ModelRequestModel( + name=model_name, + user=cl.active_user.id, + workspace=cl.active_workspace.id, + ) + ) + + with pytest.raises(KeyError): + Client().get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + + +def test_stage_not_found(clean_client): + """Test that get latest fails if not found.""" + + cl = Client() + model_name = "super_model" + + cl.create_model( + model=ModelRequestModel( + name=model_name, + user=cl.active_user.id, + workspace=cl.active_workspace.id, + ) + ) + + with pytest.raises(KeyError): + Client().get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=ModelStages.STAGING, + ) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 71c09a7ad7c..e716b042b04 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -4045,12 +4045,11 @@ def test_create_no_model(self): def test_get_not_found(self): """Test that get fails if not found.""" - with ModelVersionContext() as model: + with ModelVersionContext(): zs = Client().zen_store with pytest.raises(KeyError): zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id="1.0.0", + model_version_id=uuid4(), ) def test_get_found(self): @@ -4065,10 +4064,12 @@ def test_get_found(self): name="great one", ) ) - mv2 = zs.get_model_version( + mv2 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id="great one", - ) + model_version_filter_model=ModelVersionFilterModel( + name="great one" + ), + ).items[0] assert mv1.id == mv2.id def test_list_empty(self): @@ -4133,11 +4134,13 @@ def test_delete_found(self): zs.delete_model_version( model_version_id=mv.id, ) - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id="great one", - ) + mvl = zs.list_model_versions( + model_name_or_id=model.id, + model_version_filter_model=ModelVersionFilterModel( + name="great one" + ), + ).items + assert len(mvl) == 0 def test_update_not_found(self): """Test that update fails if not found.""" @@ -4181,15 +4184,19 @@ def test_update_not_forced(self): force=False, ), ) - mv2 = zs.get_model_version( + mv2 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id="staging", - ) + model_version_filter_model=ModelVersionFilterModel( + stage="staging" + ), + ).items[0] assert mv1.id == mv2.id - mv3 = zs.get_model_version( + mv3 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.STAGING, - ) + model_version_filter_model=ModelVersionFilterModel( + stage=ModelStages.STAGING + ), + ).items[0] assert mv1.id == mv3.id def test_in_stage_not_found(self): @@ -4205,21 +4212,14 @@ def test_in_stage_not_found(self): ) ) - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.STAGING, - ) + mvl = zs.list_model_versions( + model_name_or_id=model.id, + model_version_filter_model=ModelVersionFilterModel( + stage=ModelStages.STAGING + ), + ).items - def test_latest_not_found(self): - """Test that get latest fails if not found.""" - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.LATEST, - ) + assert len(mvl) == 0 def test_latest_found(self): """Test that get latest works, if model version exists.""" @@ -4243,7 +4243,7 @@ def test_latest_found(self): ) ) found_latest = zs.get_model_version( - model_name_or_id=model.id, + model_version_id=model.id, ) assert latest.id == found_latest.id @@ -4277,8 +4277,7 @@ def test_update_forced(self): ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "staging" ) @@ -4294,22 +4293,19 @@ def test_update_forced(self): assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "archived" ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, ).stage == "staging" ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, ).name == "I changed that..." ) @@ -4326,18 +4322,12 @@ def test_update_public_interface(self): name=RUNNING_MODEL_VERSION, ) ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, - ).stage - is None - ) + + assert mv1.stage is None mv1.set_stage("staging") assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "staging" ) @@ -4345,8 +4335,7 @@ def test_update_public_interface(self): mv1._update_default_running_version_name() assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.id, + model_version_id=mv1.id, ).name == "1" ) @@ -4413,10 +4402,10 @@ def test_get_found_by_number(self): """Test that get works by integer version number.""" with ModelVersionContext(create_version=True) as model_version: zs = Client().zen_store - found = zs.get_model_version( + found = zs.list_model_versions( model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=1, - ) + model_version_filter_model=ModelVersionFilterModel(number=1), + ).items[0] assert found.id == model_version.id assert found.number == 1 assert found.name == model_version.name @@ -4425,18 +4414,13 @@ def test_get_not_found_by_number(self): """Test that get fails by integer version number, if not found and by string version number, cause treated as name.""" with ModelVersionContext(create_version=True) as model_version: zs = Client().zen_store - # no version numbered as 2 - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=2, - ) - # cannot fetch by string number - treated as name - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id="1", - ) + + found = zs.list_model_versions( + model_name_or_id=model_version.model.id, + model_version_filter_model=ModelVersionFilterModel(number=1), + ).items + + assert len(found) == 0 class TestModelVersionArtifactLinks: @@ -4764,8 +4748,7 @@ def test_link_list_populated(self): assert len(mvls) == 1 and mvls[0].name == "link3" mv = zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, ) assert len(mv.model_object_ids) == 1 @@ -4943,8 +4926,7 @@ def test_link_list_populated(self): assert len(mvls) == 2 mv = zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, ) assert len(mv.pipeline_run_ids) == 2 From 1d5c3ed06ceede2070dc580bc3a5fb318d11e0d6 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 10:34:14 +0100 Subject: [PATCH 21/28] Add clien tests --- src/zenml/client.py | 2 +- tests/integration/functional/test_client.py | 126 +++++++++++++++++--- 2 files changed, 109 insertions(+), 19 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 472748402de..a090e9475f3 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5543,7 +5543,7 @@ def get_model_version( ) if model_versions_page.size > 0: - model_versions = model_versions_page.items[0] + model_versions = [model_versions_page.items[0]] else: model_versions = [] elif model_version_name_or_number_or_id in ModelStages.values(): diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 377d7986327..fd9b46017f5 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -43,6 +43,9 @@ from zenml.models import ( ComponentResponseModel, ModelRequestModel, + ModelVersionRequestModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildRequestModel, PipelineDeploymentRequestModel, PipelineRequestModel, @@ -1171,43 +1174,130 @@ def test_basic_crud_for_entity( pass -def test_latest_not_found(clean_client): - """Test that get latest fails if not found.""" +def _create_some_model_version( + client: Client, + model_name: str = "aria_cat_supermodel", + model_version_name: str = "1.0.0", +) -> ModelVersionResponseModel: + model = client.create_model( + model=ModelRequestModel( + name=model_name, + user=client.active_user.id, + workspace=client.active_workspace.id, + ) + ) + return client.create_model_version( + ModelVersionRequestModel( + user=client.active_user.id, + workspace=client.active_workspace.id, + model=model.id, + name=model_version_name, + ) + ) + + +def test_get_by_latest(clean_client): + """Test that model version can be retrieved with latest.""" cl = Client() - model_name = "super_model" + mv1 = _create_some_model_version(client=cl) - cl.create_model( - model=ModelRequestModel( - name=model_name, + # latest returns the only model + mv2 = Client().get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + assert mv2 == mv1 + + # after second model version, latest should point to it + mv3 = cl.create_model_version( + ModelVersionRequestModel( user=cl.active_user.id, workspace=cl.active_workspace.id, + model=mv1.model.id, + name="2.0.0", ) ) + mv4 = Client().get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + assert mv4 != mv1 + assert mv4 == mv3 + + +def test_get_by_stage(clean_client): + """Test that model version can be retrived by stage.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) + + cl.update_model_version( + model_version_id=mv1.id, + model_version_update_model=ModelVersionUpdateModel( + model=mv1.model.id, stage=ModelStages.STAGING, force=True + ), + ) + + mv2 = cl.get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.STAGING, + ) + + assert mv1 == mv2 + + +def test_stage_not_found(clean_client): + """Test that attempting to get model version fails if none at the given stage.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) with pytest.raises(KeyError): Client().get_model_version( - model_name_or_id=model_name, - model_version_name_or_number_or_id=ModelStages.LATEST, + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.STAGING, ) -def test_stage_not_found(clean_client): - """Test that get latest fails if not found.""" +def test_get_model_version_by_name(clean_client): + """Test that model version can be retrieved by its name.""" cl = Client() - model_name = "super_model" + model_name = "aria_cat_super_model" + model_version_name = "1.0" - cl.create_model( - model=ModelRequestModel( - name=model_name, - user=cl.active_user.id, - workspace=cl.active_workspace.id, - ) + mv1 = _create_some_model_version( + client=cl, model_name=model_name, model_version_name=model_version_name ) + mv2 = Client().get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=model_version_name, + ) + assert mv1 == mv2 + with pytest.raises(KeyError): Client().get_model_version( model_name_or_id=model_name, - model_version_name_or_number_or_id=ModelStages.STAGING, + model_version_name_or_number_or_id="blupus_cat_super_model", + ) + + +def test_get_model_version_by_index(clean_client): + """Test that model version can be retrieved by its index.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) + + mv2 = cl.get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=1, + ) + assert mv1 == mv2 + + with pytest.raises(KeyError): + cl.get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=2, ) From 33176ba5f3f44706b2fe45efbc2508bd00b295d5 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 12:01:40 +0100 Subject: [PATCH 22/28] Fixed spelling --- tests/integration/functional/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index fd9b46017f5..edee717f8ba 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -1227,7 +1227,7 @@ def test_get_by_latest(clean_client): def test_get_by_stage(clean_client): - """Test that model version can be retrived by stage.""" + """Test that model version can be retrieved by stage.""" cl = Client() mv1 = _create_some_model_version(client=cl) From bf897109cb25c68b7f38127a9e93b32d8fab6ef4 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 17:58:40 +0100 Subject: [PATCH 23/28] Tested to work with e2e pipeline --- src/zenml/models/model_models.py | 5 ++++- src/zenml/zen_server/routers/workspaces_endpoints.py | 11 +++++++---- src/zenml/zen_stores/rest_zen_store.py | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 0c73c00341c..43957fc3693 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -223,7 +223,7 @@ def to_model_version( """ from zenml.model.model_version import ModelVersion - return ModelVersion( + mv = ModelVersion( name=self.model.name, license=self.model.license, description=self.description, @@ -237,6 +237,9 @@ def to_model_version( was_created_in_this_run=was_created_in_this_run, suppress_class_validation_warnings=suppress_class_validation_warnings, ) + mv._id = self.id + + return mv @property def model_artifacts(self) -> Dict[str, Dict[str, ArtifactResponseModel]]: diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 97511306ce9..db463cfdc00 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1296,7 +1296,7 @@ def create_model_version( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + ARTIFACTS, response_model=ModelVersionArtifactResponseModel, responses={401: error_response, 409: error_response, 422: error_response}, @@ -1327,7 +1327,8 @@ def create_model_version_artifact_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) - if model_version_id != model_version_artifact_link.model_version: + if str(model_version_id) != str(model_version_artifact_link.model_version): + breakpoint() raise IllegalOperationError( f"The model version id in your path `{model_version_id}` does not " f"match the model version specified in the request model " @@ -1392,7 +1393,7 @@ def list_workspace_model_version_artifact_links( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + RUNS, response_model=ModelVersionPipelineRunResponseModel, responses={401: error_response, 409: error_response, 422: error_response}, @@ -1424,7 +1425,9 @@ def create_model_version_pipeline_run_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) - if model_version_id != model_version_pipeline_run_link.model_version: + if str(model_version_id) != str( + model_version_pipeline_run_link.model_version + ): raise IllegalOperationError( f"The model version id in your path `{model_version_id}` does not " f"match the model version specified in the request model " diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 28c535b27c9..bc0593d185f 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2728,7 +2728,7 @@ def list_model_versions( ) else: return self._list_paginated_resources( - route=f"{MODEL_VERSIONS}", + route=MODEL_VERSIONS, response_model=ModelVersionResponseModel, filter_model=model_version_filter_model, ) @@ -2751,7 +2751,7 @@ def update_model_version( return self._update_resource( resource_id=model_version_id, resource_update=model_version_update_model, - route=f"{MODELS}/{model_version_update_model.model}{MODEL_VERSIONS}", + route=MODEL_VERSIONS, response_model=ModelVersionResponseModel, ) From 1c72fd0d828f53f0378007ed202d71e267a6ee5e Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Wed, 15 Nov 2023 20:58:38 +0100 Subject: [PATCH 24/28] Ugly fixes to get response models in the CLI --- src/zenml/cli/model.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index f331ff4de52..2b2921d3c5b 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -426,12 +426,15 @@ def update_model_version( ) except RuntimeError: if not force: + mv = Client().get_model_version( + model_name_or_id=model_version.model_id, + model_version_name_or_number_or_id=stage, + ) cli_utils.print_table( [ _model_version_to_print( Client().zen_store.get_model_version( - model_name_or_id=model_version.model_id, - model_version_name_or_number_or_id=stage, + model_version_id=mv.id ) ) ] @@ -521,12 +524,15 @@ def _print_artifacts_links_generic( only_model_artifacts: If set, only print model artifacts. **kwargs: Keyword arguments to filter models. """ - model_version = Client().zen_store.get_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) + model_version_response_model = Client().zen_store.get_model_version( + model_version_id=model_version.id + ) type_ = ( "data artifacts" if only_data_artifacts @@ -536,11 +542,18 @@ def _print_artifacts_links_generic( ) if ( - (only_data_artifacts and not model_version.data_artifact_ids) + ( + only_data_artifacts + and not model_version_response_model.data_artifact_ids + ) or ( - only_endpoint_artifacts and not model_version.endpoint_artifact_ids + only_endpoint_artifacts + and not model_version_response_model.endpoint_artifact_ids + ) + or ( + only_model_artifacts + and not model_version_response_model.model_artifact_ids ) - or (only_model_artifacts and not model_version.model_artifact_ids) ): cli_utils.declare(f"No {type_} linked to the model version found.") return @@ -681,18 +694,21 @@ def list_model_version_pipeline_runs( Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ - model_version = Client().zen_store.get_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) + model_version_response_model = Client().zen_store.get_model_version( + model_version_id=model_version.id + ) - if not model_version.pipeline_run_ids: + if not model_version_response_model.pipeline_run_ids: cli_utils.declare("No pipeline runs attached to model version found.") return cli_utils.title( - f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:" + f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:" ) model_version = Client().get_model_version( model_name_or_id=model_name_or_id, From 1d539e0b8f4b345a02ce29f741e98882e9cf3c34 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 15 Nov 2023 20:41:59 +0000 Subject: [PATCH 25/28] Auto-update of E2E template --- examples/e2e/artifacts/__init__.py | 16 --- examples/e2e/artifacts/materializer.py | 103 ----------------- examples/e2e/artifacts/model_metadata.py | 77 ------------- examples/e2e/config.py | 80 ------------- examples/e2e/inference_config.yaml | 35 ------ .../inference_get_current_version.py | 49 -------- .../e2e/steps/promotion/promote_get_metric.py | 68 ------------ .../steps/promotion/promote_get_versions.py | 69 ------------ .../promote_metric_compare_promoter.py | 102 ----------------- ...tric_compare_promoter_in_model_registry.py | 105 ------------------ ...te_model_version_in_model_control_plane.py | 43 ------- examples/e2e/train_config.yaml | 94 ---------------- examples/e2e/utils/misc.py | 48 -------- examples/e2e/utils/model_versions.py | 62 ----------- 14 files changed, 951 deletions(-) delete mode 100644 examples/e2e/artifacts/__init__.py delete mode 100644 examples/e2e/artifacts/materializer.py delete mode 100644 examples/e2e/artifacts/model_metadata.py delete mode 100644 examples/e2e/config.py delete mode 100644 examples/e2e/inference_config.yaml delete mode 100644 examples/e2e/steps/inference/inference_get_current_version.py delete mode 100644 examples/e2e/steps/promotion/promote_get_metric.py delete mode 100644 examples/e2e/steps/promotion/promote_get_versions.py delete mode 100644 examples/e2e/steps/promotion/promote_metric_compare_promoter.py delete mode 100644 examples/e2e/steps/promotion/promote_metric_compare_promoter_in_model_registry.py delete mode 100644 examples/e2e/steps/promotion/promote_model_version_in_model_control_plane.py delete mode 100644 examples/e2e/train_config.yaml delete mode 100644 examples/e2e/utils/misc.py delete mode 100644 examples/e2e/utils/model_versions.py diff --git a/examples/e2e/artifacts/__init__.py b/examples/e2e/artifacts/__init__.py deleted file mode 100644 index 9df569975a0..00000000000 --- a/examples/e2e/artifacts/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/examples/e2e/artifacts/materializer.py b/examples/e2e/artifacts/materializer.py deleted file mode 100644 index f62720593e9..00000000000 --- a/examples/e2e/artifacts/materializer.py +++ /dev/null @@ -1,103 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import json -import os -from typing import Type - -from artifacts.model_metadata import ModelMetadata -from zenml.enums import ArtifactType -from zenml.io import fileio -from zenml.materializers.base_materializer import BaseMaterializer - - -class ModelMetadataMaterializer(BaseMaterializer): - ASSOCIATED_TYPES = (ModelMetadata,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.STATISTICS - - def load(self, data_type: Type[ModelMetadata]) -> ModelMetadata: - """Read from artifact store. - - Args: - data_type: What type the artifact data should be loaded as. - - Raises: - ValueError: on deserialization issue - - Returns: - Read artifact. - """ - super().load(data_type) - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - import sklearn.ensemble - import sklearn.linear_model - import sklearn.tree - - modules = [sklearn.ensemble, sklearn.linear_model, sklearn.tree] - - with fileio.open(os.path.join(self.uri, "data.json"), "r") as f: - data_json = json.loads(f.read()) - class_name = data_json["model_class"] - cls = None - for module in modules: - if cls := getattr(module, class_name, None): - break - if cls is None: - raise ValueError( - f"Cannot deserialize `{class_name}` using {self.__class__.__name__}. " - f"Only classes from modules {[m.__name__ for m in modules]} " - "are supported" - ) - data = ModelMetadata(cls) - if "search_grid" in data_json: - data.search_grid = data_json["search_grid"] - if "params" in data_json: - data.params = data_json["params"] - if "metric" in data_json: - data.metric = data_json["metric"] - ### YOUR CODE ENDS HERE ### - - return data - - def save(self, data: ModelMetadata) -> None: - """Write to artifact store. - - Args: - data: The data of the artifact to save. - """ - super().save(data) - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - # Dump the model metadata directly into the artifact store as a JSON file - data_json = dict() - with fileio.open(os.path.join(self.uri, "data.json"), "w") as f: - data_json["model_class"] = data.model_class.__name__ - if data.search_grid: - data_json["search_grid"] = {} - for k, v in data.search_grid.items(): - if type(v) == range: - data_json["search_grid"][k] = list(v) - else: - data_json["search_grid"][k] = v - if data.params: - data_json["params"] = data.params - if data.metric: - data_json["metric"] = data.metric - f.write(json.dumps(data_json)) - ### YOUR CODE ENDS HERE ### diff --git a/examples/e2e/artifacts/model_metadata.py b/examples/e2e/artifacts/model_metadata.py deleted file mode 100644 index 67abcde90dd..00000000000 --- a/examples/e2e/artifacts/model_metadata.py +++ /dev/null @@ -1,77 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from typing import Any, Dict - -from sklearn.base import ClassifierMixin - - -class ModelMetadata: - """A custom artifact that stores model metadata. - - A model metadata object gathers together information that is collected - about the model being trained in a training pipeline run. This data type - is used for one of the artifacts returned by the model evaluation step. - - This is an example of a *custom artifact data type*: a type returned by - one of the pipeline steps that isn't natively supported by the ZenML - framework. Custom artifact data types are a common occurrence in ZenML, - usually encountered in one of the following circumstances: - - - you use a third party library that is not covered as a ZenML integration - and you model one or more step artifacts from the data types provided by - this library (e.g. datasets, models, data validation profiles, model - evaluation results/reports etc.) - - you need to use one of your own data types as a step artifact and it is - not one of the basic Python artifact data types supported by the ZenML - framework (e.g. str, int, float, dictionaries, lists, etc.) - - you want to extend one of the artifact data types already natively - supported by ZenML (e.g. pandas.DataFrame or sklearn.ClassifierMixin) - to customize it with your own data and/or behavior. - - In all above cases, the ZenML framework lacks one very important piece of - information: it doesn't "know" how to convert the data into a format that - can be saved in the artifact store (e.g. on a filesystem or persistent - storage service like S3 or GCS). Saving and loading artifacts from the - artifact store is something called "materialization" in ZenML terms and - you need to provide this missing information in the form of a custom - materializer - a class that implements loading/saving artifacts from/to - the artifact store. Take a look at the `materializers` folder to see how a - custom materializer is implemented for this artifact data type. - - More information about custom step artifact data types and ZenML - materializers is available in the docs: - - https://docs.zenml.io/user-guide/advanced-guide/artifact-management/handle-custom-data-types - - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - def __init__( - self, - model_class: ClassifierMixin, - search_grid: Dict[str, Any] = None, - params: Dict[str, Any] = None, - metric: float = None, - ) -> None: - self.model_class = model_class - self.search_grid = search_grid - self.params = params - self.metric = metric - - ### YOUR CODE ENDS HERE ### diff --git a/examples/e2e/config.py b/examples/e2e/config.py deleted file mode 100644 index 1201d896d41..00000000000 --- a/examples/e2e/config.py +++ /dev/null @@ -1,80 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from artifacts.model_metadata import ModelMetadata -from pydantic import BaseConfig -from sklearn.ensemble import RandomForestClassifier -from sklearn.tree import DecisionTreeClassifier - -from zenml.config import DockerSettings -from zenml.integrations.constants import ( - AWS, - EVIDENTLY, - KUBEFLOW, - KUBERNETES, - MLFLOW, - SKLEARN, - SLACK, -) -from zenml.model_registries.base_model_registry import ModelVersionStage - -PIPELINE_SETTINGS = dict( - docker=DockerSettings( - required_integrations=[ - AWS, - EVIDENTLY, - KUBEFLOW, - KUBERNETES, - MLFLOW, - SKLEARN, - SLACK, - ], - ) -) - -DEFAULT_PIPELINE_EXTRAS = dict(notify_on_success=False, notify_on_failure=True) - - -class MetaConfig(BaseConfig): - pipeline_name_training = "e2e_use_case_training" - pipeline_name_batch_inference = "e2e_use_case_batch_inference" - mlflow_model_name = "e2e_use_case_model" - target_env = ModelVersionStage.STAGING - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - # This set contains all the models that you want to evaluate - # during hyperparameter tuning stage. - model_search_space = { - ModelMetadata( - RandomForestClassifier, - search_grid=dict( - criterion=["gini", "entropy"], - max_depth=[2, 4, 6, 8, 10, 12], - min_samples_leaf=range(1, 10), - n_estimators=range(50, 500, 25), - ), - ), - ModelMetadata( - DecisionTreeClassifier, - search_grid=dict( - criterion=["gini", "entropy"], - max_depth=[2, 4, 6, 8, 10, 12], - min_samples_leaf=range(1, 10), - ), - ), - } diff --git a/examples/e2e/inference_config.yaml b/examples/e2e/inference_config.yaml deleted file mode 100644 index 31a7eb1f86f..00000000000 --- a/examples/e2e/inference_config.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -settings: - docker: - required_integrations: - - aws - - evidently - - kubeflow - - kubernetes - - mlflow - - sklearn - - slack -extra: - mlflow_model_name: e2e_use_case - target_env: Staging - notify_on_success: False - notify_on_failure: True -model_config: - name: e2e_use_case - version: staging diff --git a/examples/e2e/steps/inference/inference_get_current_version.py b/examples/e2e/steps/inference/inference_get_current_version.py deleted file mode 100644 index 7b3bbcdf353..00000000000 --- a/examples/e2e/steps/inference/inference_get_current_version.py +++ /dev/null @@ -1,49 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from config import MetaConfig -from typing_extensions import Annotated - -from zenml import step -from zenml.client import Client -from zenml.logger import get_logger - -logger = get_logger(__name__) - -model_registry = Client().active_stack.model_registry - - -@step -def inference_get_current_version() -> Annotated[str, "model_version"]: - """Get currently tagged model version for deployment. - - Returns: - The model version of currently tagged model in Registry. - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - - current_version = model_registry.list_model_versions( - name=MetaConfig.mlflow_model_name, - stage=MetaConfig.target_env, - )[0].version - logger.info( - f"Current model version in `{MetaConfig.target_env.value}` is `{current_version}`" - ) - - return current_version diff --git a/examples/e2e/steps/promotion/promote_get_metric.py b/examples/e2e/steps/promotion/promote_get_metric.py deleted file mode 100644 index 63ed25aa06c..00000000000 --- a/examples/e2e/steps/promotion/promote_get_metric.py +++ /dev/null @@ -1,68 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import pandas as pd -from sklearn.metrics import accuracy_score -from typing_extensions import Annotated - -from zenml import step -from zenml.client import Client -from zenml.integrations.mlflow.services import MLFlowDeploymentService -from zenml.logger import get_logger - -logger = get_logger(__name__) - -model_registry = Client().active_stack.model_registry - - -@step -def promote_get_metric( - dataset_tst: pd.DataFrame, - deployment_service: MLFlowDeploymentService, -) -> Annotated[float, "metric"]: - """Get metric for comparison for one model deployment. - - This is an example of a metric calculation step. It get a model deployment - service and computes metric on recent test dataset. - - This step is parameterized, which allows you to configure the step - independently of the step code, before running it in a pipeline. - In this example, the step can be configured to use different input data. - See the documentation for more information: - - https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines - - Args: - dataset_tst: The test dataset. - deployment_service: Model version deployment. - - Returns: - Metric value for a given deployment on test set. - - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - X = dataset_tst.drop(columns=["target"]) - y = dataset_tst["target"].to_numpy() - logger.info("Evaluating model metrics...") - - predictions = deployment_service.predict(request=X) - metric = accuracy_score(y, predictions) - deployment_service.deprovision(force=True) - ### YOUR CODE ENDS HERE ### - return metric diff --git a/examples/e2e/steps/promotion/promote_get_versions.py b/examples/e2e/steps/promotion/promote_get_versions.py deleted file mode 100644 index 95920da7eca..00000000000 --- a/examples/e2e/steps/promotion/promote_get_versions.py +++ /dev/null @@ -1,69 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from typing import Tuple - -from typing_extensions import Annotated - -from zenml import get_step_context, step -from zenml.client import Client -from zenml.logger import get_logger -from zenml.model_registries.base_model_registry import ModelVersionStage - -logger = get_logger(__name__) - -model_registry = Client().active_stack.model_registry - - -@step -def promote_get_versions() -> ( - Tuple[Annotated[str, "latest_version"], Annotated[str, "current_version"]] -): - """Step to get latest and currently tagged model version from Model Registry. - - This is an example of a model version extraction step. It will retrieve 2 model - versions from Model Registry: latest and currently promoted to target - environment (Production, Staging, etc). - - Returns: - The model versions: latest and current. If not current version - returns same - for both. - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - pipeline_extra = get_step_context().pipeline_run.config.extra - none_versions = model_registry.list_model_versions( - name=pipeline_extra["mlflow_model_name"], - stage=None, - ) - latest_versions = none_versions[0].version - logger.info(f"Latest model version is {latest_versions}") - - target_versions = model_registry.list_model_versions( - name=pipeline_extra["mlflow_model_name"], - stage=ModelVersionStage(pipeline_extra["target_env"]), - ) - current_version = latest_versions - if target_versions: - current_version = target_versions[0].version - logger.info(f"Currently promoted model version is {current_version}") - else: - logger.info("No currently promoted model version found.") - ### YOUR CODE ENDS HERE ### - - return latest_versions, current_version diff --git a/examples/e2e/steps/promotion/promote_metric_compare_promoter.py b/examples/e2e/steps/promotion/promote_metric_compare_promoter.py deleted file mode 100644 index 18fd335a995..00000000000 --- a/examples/e2e/steps/promotion/promote_metric_compare_promoter.py +++ /dev/null @@ -1,102 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from config import MetaConfig - -from zenml import step -from zenml.client import Client -from zenml.logger import get_logger -from zenml.model_registries.base_model_registry import ModelVersionStage - -logger = get_logger(__name__) - -model_registry = Client().active_stack.model_registry - - -@step -def promote_metric_compare_promoter( - latest_metric: float, - current_metric: float, - latest_version: str, - current_version: str, -): - """Try to promote trained model. - - This is an example of a model promotion step. It gets precomputed - metrics for 2 model version: latest and currently promoted to target environment - (Production, Staging, etc) and compare than in order to define - if newly trained model is performing better or not. If new model - version is better by metric - it will get relevant - tag, otherwise previously promoted model version will remain. - - If the latest version is the only one - it will get promoted automatically. - - This step is parameterized, which allows you to configure the step - independently of the step code, before running it in a pipeline. - In this example, the step can be configured to use different input data. - See the documentation for more information: - - https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines - - Args: - latest_metric: Recently trained model metric results. - current_metric: Previously promoted model metric results. - latest_version: Recently trained model version. - current_version:Previously promoted model version. - - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - should_promote = True - - if latest_version == current_version: - logger.info("No current model version found - promoting latest") - else: - logger.info( - f"Latest model metric={latest_metric:.6f}\n" - f"Current model metric={current_metric:.6f}" - ) - if latest_metric > current_metric: - logger.info( - "Latest model versions outperformed current versions - promoting latest" - ) - else: - logger.info( - "Current model versions outperformed latest versions - keeping current" - ) - should_promote = False - - promoted_version = current_version - if should_promote: - if latest_version != current_version: - model_registry.update_model_version( - name=MetaConfig.mlflow_model_name, - version=current_version, - stage=ModelVersionStage.ARCHIVED, - ) - model_registry.update_model_version( - name=MetaConfig.mlflow_model_name, - version=latest_version, - stage=MetaConfig.target_env, - ) - promoted_version = latest_version - - logger.info( - f"Current model version in `{MetaConfig.target_env.value}` is `{promoted_version}`" - ) - ### YOUR CODE ENDS HERE ### diff --git a/examples/e2e/steps/promotion/promote_metric_compare_promoter_in_model_registry.py b/examples/e2e/steps/promotion/promote_metric_compare_promoter_in_model_registry.py deleted file mode 100644 index 0973f6fe5ee..00000000000 --- a/examples/e2e/steps/promotion/promote_metric_compare_promoter_in_model_registry.py +++ /dev/null @@ -1,105 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing_extensions import Annotated, Tuple - -from zenml import get_step_context, step -from zenml.client import Client -from zenml.logger import get_logger -from zenml.model_registries.base_model_registry import ModelVersionStage - -logger = get_logger(__name__) - -model_registry = Client().active_stack.model_registry - - -@step -def promote_metric_compare_promoter_in_model_registry( - latest_metric: float, - current_metric: float, - latest_version: str, - current_version: str, -) -> Tuple[ - Annotated[bool, "was_promoted"], Annotated[int, "promoted_version"] -]: - """Try to promote trained model. - - This is an example of a model promotion step. It gets precomputed - metrics for 2 model version: latest and currently promoted to target environment - (Production, Staging, etc) and compare than in order to define - if newly trained model is performing better or not. If new model - version is better by metric - it will get relevant - tag, otherwise previously promoted model version will remain. - - If the latest version is the only one - it will get promoted automatically. - - This step is parameterized, which allows you to configure the step - independently of the step code, before running it in a pipeline. - In this example, the step can be configured to use different input data. - See the documentation for more information: - - https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines - - Args: - latest_metric: Recently trained model metric results. - current_metric: Previously promoted model metric results. - latest_version: Recently trained model version. - current_version:Previously promoted model version. - - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - pipeline_extra = get_step_context().pipeline_run.config.extra - should_promote = True - - if latest_version == current_version: - logger.info("No current model version found - promoting latest") - else: - logger.info( - f"Latest model metric={latest_metric:.6f}\n" - f"Current model metric={current_metric:.6f}" - ) - if latest_metric >= current_metric: - logger.info( - "Latest model versions outperformed current versions - promoting latest" - ) - else: - logger.info( - "Current model versions outperformed latest versions - keeping current" - ) - should_promote = False - - promoted_version = current_version - if should_promote: - if latest_version != current_version: - model_registry.update_model_version( - name=pipeline_extra["mlflow_model_name"], - version=current_version, - stage=ModelVersionStage.ARCHIVED, - ) - model_registry.update_model_version( - name=pipeline_extra["mlflow_model_name"], - version=latest_version, - stage=ModelVersionStage(pipeline_extra["target_env"]), - ) - promoted_version = latest_version - - logger.info( - f"Current model version in `{pipeline_extra['target_env']}` is `{promoted_version}`" - ) - ### YOUR CODE ENDS HERE ### - return should_promote, int(promoted_version) diff --git a/examples/e2e/steps/promotion/promote_model_version_in_model_control_plane.py b/examples/e2e/steps/promotion/promote_model_version_in_model_control_plane.py deleted file mode 100644 index 10994ee7dc9..00000000000 --- a/examples/e2e/steps/promotion/promote_model_version_in_model_control_plane.py +++ /dev/null @@ -1,43 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from zenml import get_step_context, step -from zenml.logger import get_logger - -logger = get_logger(__name__) - - -@step -def promote_model_version_in_model_control_plane(promotion_decision: bool): - """Step to promote current model version to target environment in Model Control Plane. - - Args: - promotion_decision: Whether to promote current model version to target environment - """ - - ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### - if promotion_decision: - target_env = ( - get_step_context().pipeline_run.config.extra["target_env"].lower() - ) - model_version = get_step_context().model_config._get_model_version() - model_version.set_stage(stage=target_env, force=True) - logger.info(f"Current model version was promoted to '{target_env}'.") - else: - logger.info("Current model version was not promoted.") - ### YOUR CODE ENDS HERE ### diff --git a/examples/e2e/train_config.yaml b/examples/e2e/train_config.yaml deleted file mode 100644 index 03360da2bd5..00000000000 --- a/examples/e2e/train_config.yaml +++ /dev/null @@ -1,94 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -settings: - docker: - required_integrations: - - aws - - evidently - - kubeflow - - kubernetes - - mlflow - - sklearn - - slack -extra: - mlflow_model_name: e2e_use_case - target_env: Staging - notify_on_success: False - notify_on_failure: True - # This set contains all the models that you want to evaluate - # during hyperparameter tuning stage. - model_search_space: - random_forest: - model_package: sklearn.ensemble - model_class: RandomForestClassifier - search_grid: - criterion: - - gini - - entropy - max_depth: - - 2 - - 4 - - 6 - - 8 - - 10 - - 12 - min_samples_leaf: - range: - start: 1 - end: 10 - n_estimators: - range: - start: 50 - end: 500 - step: 25 - decision_tree: - model_package: sklearn.tree - model_class: DecisionTreeClassifier - search_grid: - criterion: - - gini - - entropy - max_depth: - - 2 - - 4 - - 6 - - 8 - - 10 - - 12 - min_samples_leaf: - range: - start: 1 - end: 10 -model_config: - name: e2e_use_case - license: apache - description: e2e_use_case E2E Batch Use Case - audience: All ZenML users - use_cases: | - The ZenML E2E project project demonstrates how the most important steps of - the ML Production Lifecycle can be implemented in a reusable way remaining - agnostic to the underlying infrastructure, and shows how to integrate them together - into pipelines for Training and Batch Inference purposes. - ethics: No impact. - tags: - - e2e - - batch - - sklearn - - from template - - ZenML delivered - create_new_model_version: true \ No newline at end of file diff --git a/examples/e2e/utils/misc.py b/examples/e2e/utils/misc.py deleted file mode 100644 index 34f9e4ca035..00000000000 --- a/examples/e2e/utils/misc.py +++ /dev/null @@ -1,48 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import string - -import pandas as pd -from sklearn.datasets import make_classification - - -def generate_random_data(n_samples: int) -> pd.DataFrame: - """Generate random data for model input. - - Args: - n_samples: Number of records to generate. - - Returns: - pd.DataFrame: Generated dataset for classification task. - """ - n_features = 20 - X, y = make_classification( - n_samples=n_samples, - n_features=n_features, - n_classes=2, - random_state=42, - ) - dataset = pd.concat( - [ - pd.DataFrame(X, columns=list(string.ascii_uppercase[:n_features])), - pd.Series(y, name="target"), - ], - axis=1, - ) - return dataset diff --git a/examples/e2e/utils/model_versions.py b/examples/e2e/utils/model_versions.py deleted file mode 100644 index 65e720e28d0..00000000000 --- a/examples/e2e/utils/model_versions.py +++ /dev/null @@ -1,62 +0,0 @@ -# Apache Software License 2.0 -# -# Copyright (c) ZenML GmbH 2023. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Tuple - -from typing_extensions import Annotated - -from zenml import get_step_context -from zenml.model import ModelConfig -from zenml.models.model_models import ModelVersionResponseModel - - -def get_model_versions( - target_env: str, -) -> Tuple[ - Annotated[ModelVersionResponseModel, "latest_version"], - Annotated[ModelVersionResponseModel, "current_version"], -]: - """Get latest and currently promoted model versions from Model Control Plane. - - Args: - target_env: Target stage to search for currently promoted version - - Returns: - Latest and currently promoted model versions from the Model Control Plane - """ - latest_version = get_step_context().model_config._get_model_version() - try: - current_version = ModelConfig( - name=latest_version.model.name, version=target_env - )._get_model_version() - except KeyError: - current_version = latest_version - - return latest_version, current_version - - -def get_model_registry_version(model_version: ModelVersionResponseModel): - """Get model version in model registry from metadata of a model in the Model Control Plane. - - Args: - model_version: the Model Control Plane version response - """ - return ( - model_version.get_model_object("model") - .metadata["model_registry_version"] - .value - ) From ac4c980587a4aa43da7745667eae82fd6e87d476 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 16 Nov 2023 01:20:50 +0100 Subject: [PATCH 26/28] Access ModelVersionResponseModels in Client again --- src/zenml/cli/model.py | 9 +-- src/zenml/client.py | 60 +++++++++++++++---- src/zenml/model/model_version.py | 13 ++-- .../functional/model/test_artifact_config.py | 9 +-- tests/integration/functional/test_client.py | 18 ++---- .../functional/zen_stores/test_zen_store.py | 6 +- .../functional/zen_stores/utils.py | 4 +- 7 files changed, 69 insertions(+), 50 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 2b2921d3c5b..027727bc09a 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -426,15 +426,12 @@ def update_model_version( ) except RuntimeError: if not force: - mv = Client().get_model_version( - model_name_or_id=model_version.model_id, - model_version_name_or_number_or_id=stage, - ) cli_utils.print_table( [ _model_version_to_print( - Client().zen_store.get_model_version( - model_version_id=mv.id + Client()._get_model_version( + model_name_or_id=model_version.model_id, + model_version_name_or_number_or_id=stage, ) ) ] diff --git a/src/zenml/client.py b/src/zenml/client.py index 15781ea0aba..17b6e401a4a 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -161,6 +161,7 @@ ModelVersionPipelineRunFilterModel, ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, + ModelVersionResponseModel, ModelVersionUpdateModel, ) from zenml.models.page_model import Page @@ -5604,39 +5605,74 @@ def get_model_version( RuntimeError: In case method inputs don't adhere to restrictions. KeyError: In case no model version with the identifiers exists. """ + return self._get_model_version( + model_name_or_id=model_name_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ).to_model_version(suppress_class_validation_warnings=True) + + def _get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_number_or_id: Optional[ + Union[str, int, ModelStages, UUID] + ] = None, + ) -> "ModelVersionResponseModel": + """Get an existing model version from Model Control Plane. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + If skipped - latest version is retrieved. + + Returns: + The model version of interest. + + Raises: + RuntimeError: In case method inputs don't adhere to restrictions. + KeyError: In case no model version with the identifiers exists. + """ + if model_version_name_or_number_or_id is None: model_version_name_or_number_or_id = ModelStages.LATEST if isinstance(model_version_name_or_number_or_id, UUID): return self.zen_store.get_model_version( model_version_id=model_version_name_or_number_or_id - ).to_model_version(suppress_class_validation_warnings=True) + ) elif isinstance(model_version_name_or_number_or_id, int): - model_versions = self.list_model_versions( + model_versions = self.zen_store.list_model_versions( model_name_or_id=model_name_or_id, - number=model_version_name_or_number_or_id, + model_version_filter_model=ModelVersionFilterModel( + number=model_version_name_or_number_or_id, + ), ) elif isinstance(model_version_name_or_number_or_id, str): if model_version_name_or_number_or_id == ModelStages.LATEST: - model_versions = self.list_model_versions( + model_versions = self.zen_store.list_model_versions( model_name_or_id=model_name_or_id, - sort_by=f"{SorterOps.DESCENDING}:number", - ) + model_version_filter_model=ModelVersionFilterModel( + sort_by=f"{SorterOps.DESCENDING}:number" + ), + ).items if len(model_versions) > 1: model_versions = [model_versions[0]] else: model_versions = [] elif model_version_name_or_number_or_id in ModelStages.values(): - model_versions = self.list_model_versions( + model_versions = self.zen_store.list_model_versions( model_name_or_id=model_name_or_id, - stage=model_version_name_or_number_or_id, - ) + model_version_filter_model=ModelVersionFilterModel( + stage=model_version_name_or_number_or_id + ), + ).items else: - model_versions = self.list_model_versions( + model_versions = self.zen_store.list_model_versions( model_name_or_id=model_name_or_id, - name=model_version_name_or_number_or_id, - ) + model_version_filter_model=ModelVersionFilterModel( + name=model_version_name_or_number_or_id + ), + ).items else: raise RuntimeError( f"The model version identifier " diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index cab7b7653bd..4be47bab2b6 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -370,15 +370,14 @@ def _get_model_version(self) -> "ModelVersionResponseModel": from zenml.client import Client zenml_client = Client() + mv = zenml_client._get_model_version( + model_name_or_id=self.name, + model_version_name_or_number_or_id=self.version, + ) if not self._id: - mv = zenml_client.get_model_version( - model_name_or_id=self.name, - model_version_name_or_number_or_id=self.version, - ) self._id = mv._id - return zenml_client.zen_store.get_model_version( - model_version_id=self._id - ) + + return mv def _get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index 6070cc2b94f..7ed72a76bc7 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -763,10 +763,9 @@ def _inner_pipeline(force_disable_cache: bool = False): ModelVersion(name="bar")._get_or_create_model_version() _inner_pipeline(i != 1) - mv = client.get_model_version( + mvrm = client._get_model_version( model_name_or_id="foo", model_version_name_or_number_or_id=i ) - mvrm = client.zen_store.get_model_version(model_version_id=mv.id) assert len(mvrm.data_artifact_ids) == 2, f"Failed on {i} run" assert len(mvrm.model_artifact_ids) == 1, f"Failed on {i} run" assert set(mvrm.data_artifact_ids.keys()) == { @@ -777,10 +776,9 @@ def _inner_pipeline(force_disable_cache: bool = False): "_inner_pipeline::_cacheable_step_annotated::cacheable", }, f"Failed on {i} run" - mv = client.get_model_version( + mvrm = client._get_model_version( model_name_or_id="bar", ) - mvrm = client.zen_store.get_model_version(mv.id) assert len(mvrm.data_artifact_ids) == 1, f"Failed on {i} run" assert set(mvrm.data_artifact_ids.keys()) == { @@ -819,10 +817,9 @@ def _inner_pipeline(force_disable_cache: bool = False): ModelVersion(name="bar")._get_or_create_model_version() _inner_pipeline(i != 1) - mv = client.get_model_version( + mvrm = client._get_model_version( model_name_or_id="bar", ) - mvrm = client.zen_store.get_model_version(mv.id) assert len(mvrm.data_artifact_ids) == 1, f"Failed on {i} run" assert set(mvrm.data_artifact_ids.keys()) == { "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable", diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index fbdd36c3aa1..4c36e21fcd4 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -46,7 +46,6 @@ ModelRequestModel, ModelVersionRequestModel, ModelVersionResponseModel, - ModelVersionUpdateModel, PipelineBuildRequestModel, PipelineDeploymentRequestModel, PipelineRequestModel, @@ -1572,14 +1571,7 @@ def test_get_by_latest(clean_client): assert mv2 == mv1 # after second model version, latest should point to it - mv3 = cl.create_model_version( - ModelVersionRequestModel( - user=cl.active_user.id, - workspace=cl.active_workspace.id, - model=mv1.model.id, - name="2.0.0", - ) - ) + mv3 = cl.create_model_version(model_name_or_id=mv1.model.id, name="2.0.0") mv4 = Client().get_model_version( model_name_or_id=mv1.model.id, model_version_name_or_number_or_id=ModelStages.LATEST, @@ -1595,10 +1587,10 @@ def test_get_by_stage(clean_client): mv1 = _create_some_model_version(client=cl) cl.update_model_version( - model_version_id=mv1.id, - model_version_update_model=ModelVersionUpdateModel( - model=mv1.model.id, stage=ModelStages.STAGING, force=True - ), + version_name_or_id=mv1.id, + model_name_or_id=mv1.model.id, + stage=ModelStages.STAGING, + force=True, ) mv2 = cl.get_model_version( diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 5703ded5db5..d64b4b1177c 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -4242,8 +4242,8 @@ def test_latest_found(self): name="yet another one", ) ) - found_latest = zs.get_model_version( - model_version_id=model.id, + found_latest = Client().get_model_version( + model_name_or_id=model.id ) assert latest.id == found_latest.id @@ -4415,7 +4415,7 @@ def test_get_not_found_by_number(self): found = zs.list_model_versions( model_name_or_id=model_version.model.id, - model_version_filter_model=ModelVersionFilterModel(number=1), + model_version_filter_model=ModelVersionFilterModel(number=2), ).items assert len(found) == 0 diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index f63f9838fd3..e07745600b8 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -662,9 +662,7 @@ def __enter__(self): model = client.create_model(name=self.model) if self.create_version: try: - mv = client.zen_store.get_model_version( - self.model, self.model_version - ) + mv = client._get_model_version(self.model, self.model_version) except KeyError: mv = client.zen_store.create_model_version( ModelVersionRequestModel( From bf34c71e80b391b585b0be9177e1352d17b22554 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 16 Nov 2023 01:51:14 +0100 Subject: [PATCH 27/28] Another small fix --- src/zenml/client.py | 1 - src/zenml/model/model_version.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 17b6e401a4a..83b97e55f53 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5631,7 +5631,6 @@ def _get_model_version( RuntimeError: In case method inputs don't adhere to restrictions. KeyError: In case no model version with the identifiers exists. """ - if model_version_name_or_number_or_id is None: model_version_name_or_number_or_id = ModelStages.LATEST diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index 4be47bab2b6..0d5b518372e 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -375,7 +375,7 @@ def _get_model_version(self) -> "ModelVersionResponseModel": model_version_name_or_number_or_id=self.version, ) if not self._id: - self._id = mv._id + self._id = mv.id return mv From 070e37e56de396703395e2904cb449cda0d63bf0 Mon Sep 17 00:00:00 2001 From: Alexej Penner Date: Thu, 16 Nov 2023 09:08:24 +0100 Subject: [PATCH 28/28] Linted --- src/zenml/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 83b97e55f53..cd429d7b0ad 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5644,7 +5644,7 @@ def _get_model_version( model_version_filter_model=ModelVersionFilterModel( number=model_version_name_or_number_or_id, ), - ) + ).items elif isinstance(model_version_name_or_number_or_id, str): if model_version_name_or_number_or_id == ModelStages.LATEST: model_versions = self.zen_store.list_model_versions(