diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index f6afc1bf386..027727bc09a 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -429,7 +429,7 @@ def update_model_version( cli_utils.print_table( [ _model_version_to_print( - Client().zen_store.get_model_version( + Client()._get_model_version( model_name_or_id=model_version.model_id, model_version_name_or_number_or_id=stage, ) @@ -488,9 +488,12 @@ def delete_model_version( return try: - Client().delete_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ) + Client().delete_model_version( + model_version_id=model_version.id, ) except (KeyError, ValueError) as e: cli_utils.error(str(e)) @@ -518,12 +521,15 @@ def _print_artifacts_links_generic( only_model_artifacts: If set, only print model artifacts. **kwargs: Keyword arguments to filter models. """ - model_version = Client().zen_store.get_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) + model_version_response_model = Client().zen_store.get_model_version( + model_version_id=model_version.id + ) type_ = ( "data artifacts" if only_data_artifacts @@ -533,11 +539,18 @@ def _print_artifacts_links_generic( ) if ( - (only_data_artifacts and not model_version.data_artifact_ids) + ( + only_data_artifacts + and not model_version_response_model.data_artifact_ids + ) + or ( + only_endpoint_artifacts + and not model_version_response_model.endpoint_artifact_ids + ) or ( - only_endpoint_artifacts and not model_version.endpoint_artifact_ids + only_model_artifacts + and not model_version_response_model.model_artifact_ids ) - or (only_model_artifacts and not model_version.model_artifact_ids) ): cli_utils.declare(f"No {type_} linked to the model version found.") return @@ -546,9 +559,13 @@ def _print_artifacts_links_generic( f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" ) + model_version = Client().get_model_version( + model_name_or_id=model_name_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, + ) + links = Client().list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_data_artifacts=only_data_artifacts, only_endpoint_artifacts=only_endpoint_artifacts, @@ -674,23 +691,29 @@ def list_model_version_pipeline_runs( Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ - model_version = Client().zen_store.get_model_version( + model_version = Client().get_model_version( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) + model_version_response_model = Client().zen_store.get_model_version( + model_version_id=model_version.id + ) - if not model_version.pipeline_run_ids: + if not model_version_response_model.pipeline_run_ids: cli_utils.declare("No pipeline runs attached to model version found.") return cli_utils.title( - f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:" + f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:" + ) + model_version = Client().get_model_version( + model_name_or_id=model_name_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, ) links = Client().list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel( **kwargs, ), diff --git a/src/zenml/client.py b/src/zenml/client.py index 07a7cab05de..4a938bf008b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -57,6 +57,7 @@ OAuthDeviceStatus, PermissionType, SecretScope, + SorterOps, StackComponentType, StoreType, ) @@ -5584,25 +5585,22 @@ def create_model_version( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version from Model Control Plane. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: Id of the model version to be deleted. """ self.zen_store.delete_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, ) def get_model_version( self, model_name_or_id: Union[str, UUID], model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] + Union[str, int, ModelStages, UUID] ] = None, ) -> "ModelVersion": """Get an existing model version from Model Control Plane. @@ -5614,22 +5612,109 @@ def get_model_version( Returns: The model version of interest. + + Raises: + RuntimeError: In case method inputs don't adhere to restrictions. + KeyError: In case no model version with the identifiers exists. """ - return self.zen_store.get_model_version( + return self._get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id - or ModelStages.LATEST, + model_version_name_or_number_or_id=model_version_name_or_number_or_id, ).to_model_version(suppress_class_validation_warnings=True) - def list_model_versions( + def _get_model_version( self, model_name_or_id: Union[str, UUID], + model_version_name_or_number_or_id: Optional[ + Union[str, int, ModelStages, UUID] + ] = None, + ) -> "ModelVersionResponseModel": + """Get an existing model version from Model Control Plane. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + If skipped - latest version is retrieved. + + Returns: + The model version of interest. + + Raises: + RuntimeError: In case method inputs don't adhere to restrictions. + KeyError: In case no model version with the identifiers exists. + """ + if model_version_name_or_number_or_id is None: + model_version_name_or_number_or_id = ModelStages.LATEST + + if isinstance(model_version_name_or_number_or_id, UUID): + return self.zen_store.get_model_version( + model_version_id=model_version_name_or_number_or_id + ) + elif isinstance(model_version_name_or_number_or_id, int): + model_versions = self.zen_store.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + number=model_version_name_or_number_or_id, + ), + ).items + elif isinstance(model_version_name_or_number_or_id, str): + if model_version_name_or_number_or_id == ModelStages.LATEST: + model_versions = self.zen_store.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + sort_by=f"{SorterOps.DESCENDING}:number" + ), + ).items + + if len(model_versions) > 1: + model_versions = [model_versions[0]] + else: + model_versions = [] + elif model_version_name_or_number_or_id in ModelStages.values(): + model_versions = self.zen_store.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + stage=model_version_name_or_number_or_id + ), + ).items + else: + model_versions = self.zen_store.list_model_versions( + model_name_or_id=model_name_or_id, + model_version_filter_model=ModelVersionFilterModel( + name=model_version_name_or_number_or_id + ), + ).items + else: + raise RuntimeError( + f"The model version identifier " + f"`{model_version_name_or_number_or_id}` is not" + f"of the correct type." + ) + + if len(model_versions) == 1: + return model_versions[0] + elif len(model_versions) == 0: + raise KeyError( + f"No model version found for model " + f"`{model_name_or_id}` with version identifier " + f"`{model_version_name_or_number_or_id}`." + ) + else: + raise RuntimeError( + f"The model version identifier " + f"`{model_version_name_or_number_or_id}` is not" + f"unique for model `{model_name_or_id}`." + ) + + def list_model_versions( + self, sort_by: str = "number", page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, + model_name_or_id: Optional[Union[str, UUID]] = None, name: Optional[str] = None, number: Optional[int] = None, stage: Optional[Union[str, ModelStages]] = None, @@ -5637,6 +5722,12 @@ def list_model_versions( """Get model versions by filter from Model Control Plane. Args: + sort_by: The column to sort by + page: The page of items + size: The maximum size of all pages + logical_operator: Which logical operator to use [and, or] + created: Use to filter by time of creation + updated: Use the last updated date for filtering model_name_or_id: name or id of the model containing the model version. sort_by: The column to sort by page: The page of items @@ -5652,6 +5743,12 @@ def list_model_versions( A page object with all model versions. """ model_version_filter_model = ModelVersionFilterModel( + page=page, + size=size, + sort_by=sort_by, + logical_operator=logical_operator, + created=created, + updated=updated, name=name, number=number, stage=stage, @@ -5714,28 +5811,21 @@ def update_model_version( def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_id: UUID, ) -> Page[ModelVersionArtifactResponseModel]: """Get model version to artifact links by filter in Model Control Plane. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + model_version_id: id of the model version to be retrieved. model_version_artifact_link_filter_model: All filter parameters including pagination params. Returns: A page of all model version to artifact links. """ - mv = self.zen_store.get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, - ) return self.zen_store.list_model_version_artifact_links( - model_name_or_id=mv.model.id, - model_version_name_or_id=mv.id, + model_version_id=model_version_id, model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -5747,28 +5837,21 @@ def list_model_version_artifact_links( def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_id: UUID, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + model_version_id: id of the model version to be retrieved. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. Returns: A page of all model version to pipeline run links. """ - mv = self.zen_store.get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, - ) return self.zen_store.list_model_version_pipeline_run_links( - model_name_or_id=mv.model.id, - model_version_name_or_id=mv.id, + model_version_id=model_version_id, model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index a3dd669863c..d9ef00e210b 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -154,8 +154,7 @@ def _link_to_model_version( # Create the model version artifact link using the ZenML client existing_links = client.list_model_version_artifact_links( - model_name_or_id=model_version.model_id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=client.active_user.id, workspace_id=client.active_workspace.id, @@ -175,9 +174,9 @@ def _link_to_model_version( logger.warning( f"Existing artifact link(s) `{artifact_name}` found and will be deleted." ) + client.zen_store.delete_model_version_artifact_link( - model_name_or_id=model_version.model_id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id=artifact_name, ) else: diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index 49a7fc94e9c..0d5b518372e 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -370,10 +370,14 @@ def _get_model_version(self) -> "ModelVersionResponseModel": from zenml.client import Client zenml_client = Client() - return zenml_client.zen_store.get_model_version( + mv = zenml_client._get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) + if not self._id: + self._id = mv.id + + return mv def _get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 776b6f63041..e05f0fbd62e 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -223,7 +223,7 @@ def to_model_version( """ from zenml.model.model_version import ModelVersion - return ModelVersion( + mv = ModelVersion( name=self.model.name, license=self.model.license, description=self.description, @@ -237,6 +237,9 @@ def to_model_version( was_created_in_this_run=was_created_in_this_run, suppress_class_validation_warnings=suppress_class_validation_warnings, ) + mv._id = self.id + + return mv @property def model_artifacts(self) -> Dict[str, Dict[str, ArtifactResponseModel]]: diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py new file mode 100644 index 00000000000..65332c614ea --- /dev/null +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -0,0 +1,267 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Endpoint definitions for models.""" + +from typing import Union +from uuid import UUID + +from fastapi import APIRouter, Depends, Security + +from zenml.constants import ( + API, + ARTIFACTS, + MODEL_VERSIONS, + RUNS, + VERSION_1, +) +from zenml.enums import PermissionType +from zenml.models import ( + ModelVersionArtifactFilterModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunResponseModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, +) +from zenml.models.page_model import Page +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.utils import ( + handle_exceptions, + make_dependable, + zen_store, +) + +######### +# Models +######### + +router = APIRouter( + prefix=API + VERSION_1 + MODEL_VERSIONS, + tags=["model_versions"], + responses={401: error_response}, +) + +################# +# Model Versions +################# + + +@router.get( + "", + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_versions( + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get model versions according to query filters. + + Args: + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model versions according to query filters. + """ + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) + + +@router.get( + "/{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_model_version( + model_version_id: UUID, + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> ModelVersionResponseModel: + """Get a model version by ID. + + Args: + model_version_id: id of the model version to be retrieved. + + Returns: + The model version with the given name or ID. + """ + return zen_store().get_model_version( + model_version_id=model_version_id, + ) + + +@router.put( + "/{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_model_version( + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + """ + return zen_store().update_model_version( + model_version_id=model_version_id, + model_version_update_model=model_version_update_model, + ) + + +@router.delete( + "/{model_version_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version( + model_version_id: UUID, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_version_id: The name or ID of the model version to delete. + """ + zen_store().delete_model_version(model_version_id) + + +########################## +# Model Version Artifacts +########################## + + +@router.get( + "/{model_version_id}" + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_artifact_links( + model_version_id: UUID, + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. + + Args: + model_version_id: ID of the model version containing links. + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to artifact links according to query filters. + """ + return zen_store().list_model_version_artifact_links( + model_version_id=model_version_id, + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, + ) + + +@router.delete( + "/{model_version_id}" + + ARTIFACTS + + "/{model_version_artifact_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_artifact_link( + model_version_id: UUID, + model_version_artifact_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_version_id: ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + """ + zen_store().delete_model_version_artifact_link( + model_version_id, + model_version_artifact_link_name_or_id, + ) + + +############################## +# Model Version Pipeline Runs +############################## + + +@router.get( + "/{model_version_id}" + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_pipeline_run_links( + model_version_id: UUID, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline run links according to query filters. + + Args: + model_version_id: ID of the model version containing the link. + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, + and filtering + + Returns: + The model version to pipeline run links according to query filters. + """ + return zen_store().list_model_version_pipeline_run_links( + model_version_id=model_version_id, + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, + ) + + +@router.delete( + "/{model_version_id}" + + RUNS + + "/{model_version_pipeline_run_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_pipeline_run_link( + model_version_id: UUID, + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_version_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. + """ + zen_store().delete_model_version_pipeline_run_link( + model_version_id=model_version_id, + model_version_pipeline_run_link_name_or_id=model_version_pipeline_run_link_name_or_id, + ) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 4f224793309..dae908ccdae 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -20,25 +20,17 @@ from zenml.constants import ( API, - ARTIFACTS, - LATEST_MODEL_VERSION_PLACEHOLDER, MODEL_VERSIONS, MODELS, - RUNS, VERSION_1, ) -from zenml.enums import ModelStages, PermissionType +from zenml.enums import PermissionType from zenml.models import ( ModelFilterModel, ModelResponseModel, ModelUpdateModel, - ModelVersionArtifactFilterModel, - ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionPipelineRunFilterModel, - ModelVersionPipelineRunResponseModel, ModelVersionResponseModel, - ModelVersionUpdateModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -171,6 +163,8 @@ def list_model_versions( ) -> Page[ModelVersionResponseModel]: """Get model versions according to query filters. + This endpoint serves the purpose of allowing scoped filtering by model_id. + Args: model_name_or_id: The name or ID of the model to list in. model_version_filter_model: Filter model used for pagination, sorting, @@ -183,223 +177,3 @@ def list_model_versions( model_name_or_id=model_name_or_id, model_version_filter_model=model_version_filter_model, ) - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_number_or_id}", - response_model=ModelVersionResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def get_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[ - str, int, UUID, ModelStages - ] = LATEST_MODEL_VERSION_PLACEHOLDER, - is_number: bool = False, - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> ModelVersionResponseModel: - """Get a model version by name or ID. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. - is_number: If the model_version_name_or_number_or_id is a version number - - Returns: - The model version with the given name or ID. - """ - return zen_store().get_model_version( - model_name_or_id, - model_version_name_or_number_or_id - if not is_number - else int(model_version_name_or_number_or_id), - ) - - -@router.put( - "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}", - response_model=ModelVersionResponseModel, - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def update_model_version( - model_version_id: UUID, - model_version_update_model: ModelVersionUpdateModel, - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> ModelVersionResponseModel: - """Get all model versions by filter. - - Args: - model_version_id: The ID of model version to be updated. - model_version_update_model: The model version to be updated. - - Returns: - An updated model version. - """ - return zen_store().update_model_version( - model_version_id=model_version_id, - model_version_update_model=model_version_update_model, - ) - - -@router.delete( - "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Delete a model by name or ID. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version to delete. - """ - zen_store().delete_model_version( - model_name_or_id, model_version_name_or_id - ) - - -########################## -# Model Version Artifacts -########################## - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + ARTIFACTS, - response_model=Page[ModelVersionArtifactResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_model_version_artifact_links( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( - make_dependable(ModelVersionArtifactFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionArtifactResponseModel]: - """Get model version to artifact links according to query filters. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version containing links. - model_version_artifact_link_filter_model: Filter model used for pagination, sorting, - filtering - - Returns: - The model version to artifact links according to query filters. - """ - return zen_store().list_model_version_artifact_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, - model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, - ) - - -@router.delete( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + ARTIFACTS - + "/{model_version_artifact_link_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version_artifact_link( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_artifact_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a model version link. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. - """ - zen_store().delete_model_version_artifact_link( - model_name_or_id, - model_version_name_or_id, - model_version_artifact_link_name_or_id, - ) - - -############################## -# Model Version Pipeline Runs -############################## - - -@router.get( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + RUNS, - response_model=Page[ModelVersionPipelineRunResponseModel], - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def list_model_version_pipeline_run_links( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( - make_dependable(ModelVersionPipelineRunFilterModel) - ), - _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionPipelineRunResponseModel]: - """Get model version to pipeline run links according to query filters. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, - and filtering - - Returns: - The model version to pipeline run links according to query filters. - """ - return zen_store().list_model_version_pipeline_run_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, - model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, - ) - - -@router.delete( - "/{model_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_name_or_id}" - + RUNS - + "/{model_version_pipeline_run_link_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version_pipeline_run_link( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - model_version_pipeline_run_link_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Deletes a model version link. - - Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. - model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. - """ - zen_store().delete_model_version_pipeline_run_link( - model_name_or_id, - model_version_name_or_id, - model_version_pipeline_run_link_name_or_id, - ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index f3e834d58d7..db463cfdc00 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1295,10 +1295,8 @@ def create_model_version( @router.post( WORKSPACES + "/{workspace_name_or_id}" - + MODELS - + "/{model_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + ARTIFACTS, response_model=ModelVersionArtifactResponseModel, responses={401: error_response, 409: error_response, 422: error_response}, @@ -1306,8 +1304,7 @@ def create_model_version( @handle_exceptions def create_model_version_artifact_link( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[str, UUID], model_version_artifact_link: ModelVersionArtifactRequestModel, auth_context: AuthContext = Security( authorize, scopes=[PermissionType.WRITE] @@ -1316,9 +1313,8 @@ def create_model_version_artifact_link( """Create a new model version to artifact link. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: Name or ID of the model version. model_version_artifact_link: The model version to artifact link to create. auth_context: Authentication context. @@ -1331,6 +1327,13 @@ def create_model_version_artifact_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) + if str(model_version_id) != str(model_version_artifact_link.model_version): + breakpoint() + raise IllegalOperationError( + f"The model version id in your path `{model_version_id}` does not " + f"match the model version specified in the request model " + f"`{model_version_artifact_link.model_version}`" + ) if model_version_artifact_link.workspace != workspace.id: raise IllegalOperationError( @@ -1353,7 +1356,7 @@ def create_model_version_artifact_link( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + ARTIFACTS, response_model=Page[ModelVersionArtifactResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, @@ -1361,8 +1364,7 @@ def create_model_version_artifact_link( @handle_exceptions def list_workspace_model_version_artifact_links( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( make_dependable(ModelVersionArtifactFilterModel) ), @@ -1371,9 +1373,8 @@ def list_workspace_model_version_artifact_links( """Get model version to artifact links according to query filters. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: Name or ID of the model version. model_version_artifact_link_filter_model: Filter model used for pagination, sorting, filtering @@ -1383,8 +1384,7 @@ def list_workspace_model_version_artifact_links( workspace_id = zen_store().get_workspace(workspace_name_or_id).id model_version_artifact_link_filter_model.set_scope_workspace(workspace_id) return zen_store().list_model_version_artifact_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -1392,10 +1392,8 @@ def list_workspace_model_version_artifact_links( @router.post( WORKSPACES + "/{workspace_name_or_id}" - + MODELS - + "/{model_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + RUNS, response_model=ModelVersionPipelineRunResponseModel, responses={401: error_response, 409: error_response, 422: error_response}, @@ -1403,8 +1401,7 @@ def list_workspace_model_version_artifact_links( @handle_exceptions def create_model_version_pipeline_run_link( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, auth_context: AuthContext = Security( authorize, scopes=[PermissionType.WRITE] @@ -1413,9 +1410,8 @@ def create_model_version_pipeline_run_link( """Create a new model version to pipeline run link. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: ID of the model version. model_version_pipeline_run_link: The model version to pipeline run link to create. auth_context: Authentication context. @@ -1429,6 +1425,14 @@ def create_model_version_pipeline_run_link( user. """ workspace = zen_store().get_workspace(workspace_name_or_id) + if str(model_version_id) != str( + model_version_pipeline_run_link.model_version + ): + raise IllegalOperationError( + f"The model version id in your path `{model_version_id}` does not " + f"match the model version specified in the request model " + f"`{model_version_pipeline_run_link.model_version}`" + ) if model_version_pipeline_run_link.workspace != workspace.id: raise IllegalOperationError( @@ -1451,7 +1455,7 @@ def create_model_version_pipeline_run_link( WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS - + "/{model_version_name_or_id}" + + "/{model_version_id}" + RUNS, response_model=Page[ModelVersionPipelineRunResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, @@ -1459,8 +1463,7 @@ def create_model_version_pipeline_run_link( @handle_exceptions def list_workspace_model_version_pipeline_run_links( workspace_name_or_id: Union[str, UUID], - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( make_dependable(ModelVersionPipelineRunFilterModel) ), @@ -1469,9 +1472,8 @@ def list_workspace_model_version_pipeline_run_links( """Get model version to pipeline links according to query filters. Args: - model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version_name_or_id: Name or ID of the model version. + model_version_id: ID of the model version. model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, filtering @@ -1483,7 +1485,6 @@ def list_workspace_model_version_pipeline_run_links( workspace_id ) return zen_store().list_model_version_pipeline_run_links( - model_name_or_id=model_name_or_id, - model_version_name_or_id=model_version_name_or_id, + model_version_id=model_version_id, model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 82df028c83b..86849740c82 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -36,6 +36,7 @@ code_repositories_endpoints, devices_endpoints, flavors_endpoints, + model_versions_endpoints, models_endpoints, pipeline_builds_endpoints, pipeline_deployments_endpoints, @@ -232,6 +233,7 @@ def dashboard(request: Request) -> Any: app.include_router(pipeline_deployments_endpoints.router) app.include_router(code_repositories_endpoints.router) app.include_router(models_endpoints.router) +app.include_router(model_versions_endpoints.router) app.include_router(tags_endpoints.router) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index e827505e80b..bc0593d185f 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -81,7 +81,6 @@ WORKSPACES, ) from zenml.enums import ( - ModelStages, OAuthGrantTypes, SecretsStoreType, StoreType, @@ -2676,53 +2675,40 @@ def create_model_version( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: name or id of the model version to be deleted. """ self._delete_resource( - resource_id=model_version_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + resource_id=model_version_id, + route=f"{MODEL_VERSIONS}", ) def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. Returns: The model version of interest. """ return self._get_resource( - resource_id=model_version_name_or_number_or_id - or ModelStages.LATEST, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + resource_id=model_version_id, + route=MODEL_VERSIONS, response_model=ModelVersionResponseModel, - params={ - "is_number": isinstance( - model_version_name_or_number_or_id, int - ) - }, ) def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -2734,11 +2720,18 @@ def list_model_versions( Returns: A page of all model versions. """ - return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", - response_model=ModelVersionResponseModel, - filter_model=model_version_filter_model, - ) + if model_name_or_id: + return self._list_paginated_resources( + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + filter_model=model_version_filter_model, + ) + else: + return self._list_paginated_resources( + route=MODEL_VERSIONS, + response_model=ModelVersionResponseModel, + filter_model=model_version_filter_model, + ) def update_model_version( self, @@ -2758,7 +2751,7 @@ def update_model_version( return self._update_resource( resource_id=model_version_id, resource_update=model_version_update_model, - route=f"{MODELS}/{model_version_update_model.model}{MODEL_VERSIONS}", + route=MODEL_VERSIONS, response_model=ModelVersionResponseModel, ) @@ -2780,20 +2773,18 @@ def create_model_version_artifact_link( return self._create_workspace_scoped_resource( resource=model_version_artifact_link, response_model=ModelVersionArtifactResponseModel, - route=f"{MODELS}/{model_version_artifact_link.model}{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", ) def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -2801,27 +2792,25 @@ def list_model_version_artifact_links( A page of all model version to artifact links. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}", response_model=ModelVersionArtifactResponseModel, filter_model=model_version_artifact_link_filter_model, ) def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. """ self._delete_resource( resource_id=model_version_artifact_link_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}", ) ############################### @@ -2844,20 +2833,18 @@ def create_model_version_pipeline_run_link( return self._create_workspace_scoped_resource( resource=model_version_pipeline_run_link, response_model=ModelVersionPipelineRunResponseModel, - route=f"{MODELS}/{model_version_pipeline_run_link.model}{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", ) def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -2865,27 +2852,25 @@ def list_model_version_pipeline_run_links( A page of all model version to pipeline run links. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{RUNS}", response_model=ModelVersionPipelineRunResponseModel, filter_model=model_version_pipeline_run_link_filter_model, ) def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. """ self._delete_resource( resource_id=model_version_pipeline_run_link_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", + route=f"{MODEL_VERSIONS}/{model_version_id}{RUNS}", ) # ------------------ diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 45e285cc2d1..6d255c11946 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -6539,18 +6539,13 @@ def create_model_version( return mv def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. Returns: The model version of interest. @@ -6559,54 +6554,24 @@ def get_model_version( KeyError: specified ID or name not found. """ with Session(self.engine) as session: - model = self.get_model(model_name_or_id) - query = select(ModelVersionSchema).where( - ModelVersionSchema.model_id == model.id + model_version = self._get_schema_by_name_or_id( + object_name_or_id=model_version_id, + schema_class=ModelVersionSchema, + schema_name="model_version", + session=session, ) - if model_version_name_or_number_or_id is None: - model_version_name_or_number_or_id = ModelStages.LATEST - if ( - str(model_version_name_or_number_or_id) - == ModelStages.LATEST.value - ): - query = query.order_by(ModelVersionSchema.number.desc()) # type: ignore[attr-defined] - elif model_version_name_or_number_or_id in [ - stage.value for stage in ModelStages - ]: - query = query.where( - ModelVersionSchema.stage - == model_version_name_or_number_or_id - ) - elif isinstance(model_version_name_or_number_or_id, int): - query = query.where( - ModelVersionSchema.number - == model_version_name_or_number_or_id - ) - - else: - try: - UUID(str(model_version_name_or_number_or_id)) - query = query.where( - ModelVersionSchema.id - == model_version_name_or_number_or_id - ) - except ValueError: - query = query.where( - ModelVersionSchema.name - == model_version_name_or_number_or_id - ) - model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to get model version with identifier `{model_version_name_or_number_or_id}`: " - f"No model version with this identifier found." + f"Unable to get model version with ID " + f"`{model_version_id}`: No model version with this " + f"ID found." ) return ModelVersionSchema.to_model(model_version) def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -6619,7 +6584,10 @@ def list_model_versions( A page of all model versions. """ with Session(self.engine) as session: - model_version_filter_model.set_scope_model(model_name_or_id) + if model_name_or_id: + model = self.get_model(model_name_or_id) + model_version_filter_model.set_scope_model(model.id) + query = select(ModelVersionSchema) return self.filter_and_paginate( session=session, @@ -6630,37 +6598,25 @@ def list_model_versions( def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: name or id of the model version to be deleted. Raises: KeyError: specified ID or name not found. """ with Session(self.engine) as session: - model = self.get_model(model_name_or_id) query = select(ModelVersionSchema).where( - ModelVersionSchema.model_id == model.id + ModelVersionSchema.id == model_version_id ) - try: - UUID(str(model_version_name_or_id)) - query = query.where( - ModelVersionSchema.id == model_version_name_or_id - ) - except ValueError: - query = query.where( - ModelVersionSchema.name == model_version_name_or_id - ) model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to delete model version with name `{model_version_name_or_id}`: " - f"No model version with this name found." + f"Unable to delete model version with id `{model_version_id}`: " + f"No model version with this id found." ) session.delete(model_version) session.commit() @@ -6861,15 +6817,13 @@ def create_model_version_artifact_link( def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: Union[UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -6877,12 +6831,8 @@ def list_model_version_artifact_links( A page of all model version to artifact links. """ with Session(self.engine) as session: - # issue: https://github.com/tiangolo/sqlmodel/issues/109 - model_version_artifact_link_filter_model.set_scope_model( - model_name_or_id - ) model_version_artifact_link_filter_model.set_scope_model_version( - model_version_name_or_id + model_version_id ) if model_version_artifact_link_filter_model.only_data_artifacts: query = ( @@ -6946,25 +6896,20 @@ def list_model_version_artifact_links( def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: KeyError: specified ID or name not found. """ with Session(self.engine) as session: - self.get_model(model_name_or_id) - model_version = self.get_model_version( - model_name_or_id, model_version_name_or_id - ) + model_version = self.get_model_version(model_version_id) query = select(ModelVersionArtifactSchema).where( ModelVersionArtifactSchema.model_version_id == model_version.id ) @@ -7048,15 +6993,13 @@ def create_model_version_pipeline_run_link( def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -7064,11 +7007,8 @@ def list_model_version_pipeline_run_links( A page of all model version to pipeline run links. """ with Session(self.engine) as session: - model_version_pipeline_run_link_filter_model.set_scope_model( - model_name_or_id - ) model_version_pipeline_run_link_filter_model.set_scope_model_version( - model_version_name_or_id + model_version_id ) return self.filter_and_paginate( session=session, @@ -7079,24 +7019,21 @@ def list_model_version_pipeline_run_links( def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. Raises: KeyError: specified ID not found. """ with Session(self.engine) as session: - self.get_model(model_name_or_id) model_version = self.get_model_version( - model_name_or_id, model_version_name_or_id + model_version_id=model_version_id ) query = select(ModelVersionPipelineRunSchema).where( ModelVersionPipelineRunSchema.model_version_id diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 29c37810cf8..e0946e1954c 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -16,7 +16,6 @@ from typing import List, Optional, Tuple, Union from uuid import UUID -from zenml.enums import ModelStages from zenml.models import ( APIKeyFilterModel, APIKeyRequestModel, @@ -2021,14 +2020,12 @@ def create_model_version( @abstractmethod def delete_model_version( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, ) -> None: """Deletes a model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be deleted. + model_version_id: id of the model version to be deleted. Raises: KeyError: specified ID or name not found. @@ -2036,18 +2033,14 @@ def delete_model_version( @abstractmethod def get_model_version( - self, - model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + self, model_version_id: UUID ) -> ModelVersionResponseModel: """Get an existing model version. Args: - model_name_or_id: name or id of the model containing the model version. - model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped - latest is retrieved. + model_version_id: name, id, stage or number of the model version to + be retrieved. If skipped - latest is retrieved. + Returns: The model version of interest. @@ -2059,8 +2052,8 @@ def get_model_version( @abstractmethod def list_model_versions( self, - model_name_or_id: Union[str, UUID], model_version_filter_model: ModelVersionFilterModel, + model_name_or_id: Optional[Union[str, UUID]] = None, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. @@ -2116,15 +2109,13 @@ def create_model_version_artifact_link( @abstractmethod def list_model_version_artifact_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, ) -> Page[ModelVersionArtifactResponseModel]: """Get all model version to artifact links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -2135,15 +2126,13 @@ def list_model_version_artifact_links( @abstractmethod def delete_model_version_artifact_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to artifact link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: @@ -2172,15 +2161,13 @@ def create_model_version_pipeline_run_link( @abstractmethod def list_model_version_pipeline_run_links( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: name or ID of the model version containing the link. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. @@ -2191,15 +2178,13 @@ def list_model_version_pipeline_run_links( @abstractmethod def delete_model_version_pipeline_run_link( self, - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], + model_version_id: UUID, model_version_pipeline_run_link_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version to pipeline run link. Args: - model_name_or_id: name or ID of the model containing the model version. - model_version_name_or_id: name or ID of the model version containing the link. + model_version_id: ID of the model version containing the link. model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. Raises: diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index b1273156b91..7ed72a76bc7 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -82,8 +82,7 @@ def test_link_minimalistic(): assert mv.name == MODEL_NAME assert mv.number == 1 and mv.version == "1" links = client.list_model_version_artifact_links( - model_name_or_id=mv.model_id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -142,8 +141,7 @@ def test_link_multiple_named_outputs(): assert mv.name == MODEL_NAME assert mv.number == 1 and mv.version == "1" al = client.list_model_version_artifact_links( - model_name_or_id=mv.model_id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -187,8 +185,7 @@ def test_link_multiple_named_outputs_without_links(): assert mv.number == 1 and mv.version == "1" assert mv.name == MODEL_NAME artifact_links = client.list_model_version_artifact_links( - model_name_or_id=mv.model_id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -256,16 +253,14 @@ def test_link_multiple_named_outputs_with_self_context_and_caching(): multi_named_pipeline_from_self(run_count == 2) al1 = client.list_model_version_artifact_links( - model_name_or_id=mv1.model_id, - model_version_name_or_number_or_id=mv1.id, + model_version_id=mv1.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, ), ) al2 = client.list_model_version_artifact_links( - model_name_or_id=mv2.model_id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -284,8 +279,7 @@ def test_link_multiple_named_outputs_with_self_context_and_caching(): for mv, al in zip([mv1, mv2], [al1, al2]): for al_ in al: client.zen_store.delete_model_version_artifact_link( - model_name_or_id=mv.model_id, - model_version_name_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_name_or_id=al_.id, ) @@ -375,8 +369,7 @@ def test_link_multiple_named_outputs_with_mixed_linkage(): for mv in mvs: artifact_links.append( client.list_model_version_artifact_links( - model_name_or_id=mv.model_id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -439,8 +432,7 @@ def test_link_no_versioning(): simple_pipeline_no_versioning() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -453,8 +445,7 @@ def test_link_no_versioning(): simple_pipeline_no_versioning() al2 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -505,8 +496,7 @@ def test_link_with_versioning(): simple_pipeline_with_versioning() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -519,8 +509,7 @@ def test_link_with_versioning(): simple_pipeline_with_versioning() al2 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -609,8 +598,7 @@ def test_link_with_manual_linkage(pipeline: Callable): pipeline() al1 = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -621,8 +609,7 @@ def test_link_with_manual_linkage(pipeline: Callable): assert al1[0].name == "1" al2 = client.list_model_version_artifact_links( - model_name_or_id=model2.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -711,8 +698,7 @@ def test_link_with_manual_linkage_flexible_config( simple_pipeline_with_manual_linkage_flexible_config(artifact_config) links = client.list_model_version_artifact_links( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv.id, + model_version_id=mv.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( user_id=user, workspace_id=ws, @@ -777,29 +763,30 @@ def _inner_pipeline(force_disable_cache: bool = False): ModelVersion(name="bar")._get_or_create_model_version() _inner_pipeline(i != 1) - mv = client.zen_store.get_model_version( + mvrm = client._get_model_version( model_name_or_id="foo", model_version_name_or_number_or_id=i ) - assert len(mv.data_artifact_ids) == 2, f"Failed on {i} run" - assert len(mv.model_artifact_ids) == 1, f"Failed on {i} run" - assert set(mv.data_artifact_ids.keys()) == { + assert len(mvrm.data_artifact_ids) == 2, f"Failed on {i} run" + assert len(mvrm.model_artifact_ids) == 1, f"Failed on {i} run" + assert set(mvrm.data_artifact_ids.keys()) == { "_inner_pipeline::_non_cacheable_step::output", "_inner_pipeline::_cacheable_step_not_annotated::output", }, f"Failed on {i} run" - assert set(mv.model_artifact_ids.keys()) == { + assert set(mvrm.model_artifact_ids.keys()) == { "_inner_pipeline::_cacheable_step_annotated::cacheable", }, f"Failed on {i} run" - mv = client.zen_store.get_model_version( + mvrm = client._get_model_version( model_name_or_id="bar", ) - assert len(mv.data_artifact_ids) == 1, f"Failed on {i} run" - assert set(mv.data_artifact_ids.keys()) == { + + assert len(mvrm.data_artifact_ids) == 1, f"Failed on {i} run" + assert set(mvrm.data_artifact_ids.keys()) == { "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable", }, f"Failed on {i} run" assert ( len( - mv.data_artifact_ids[ + mvrm.data_artifact_ids[ "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable" ] ) @@ -809,7 +796,7 @@ def _inner_pipeline(force_disable_cache: bool = False): def test_artifacts_linked_from_cache_steps_same_id(): """Test that artifacts are linked from cache steps with same id. - This case appears if cached step is executed inside same model version + This case appears if cached step is executed inside same model version, and we need to silently pass linkage without failing on same id. """ @@ -830,16 +817,16 @@ def _inner_pipeline(force_disable_cache: bool = False): ModelVersion(name="bar")._get_or_create_model_version() _inner_pipeline(i != 1) - mv = client.zen_store.get_model_version( + mvrm = client._get_model_version( model_name_or_id="bar", ) - assert len(mv.data_artifact_ids) == 1, f"Failed on {i} run" - assert set(mv.data_artifact_ids.keys()) == { + assert len(mvrm.data_artifact_ids) == 1, f"Failed on {i} run" + assert set(mvrm.data_artifact_ids.keys()) == { "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable", }, f"Failed on {i} run" assert ( len( - mv.data_artifact_ids[ + mvrm.data_artifact_ids[ "_inner_pipeline::_cacheable_step_custom_model_annotated::cacheable" ] ) diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 0091888815d..129f9df8220 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -43,6 +43,9 @@ from zenml.metadata.metadata_types import MetadataTypeEnum from zenml.models import ( ComponentResponseModel, + ModelRequestModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildRequestModel, PipelineDeploymentRequestModel, PipelineRequestModel, @@ -1324,7 +1327,7 @@ def warm_up_models(): description=TestModelVersion.VERSION_DESC, ) - def test_get_model_version_found(self, warm_up_models): + def test_get_model_version_by_name_found(self, warm_up_models): with model_killer(): client = Client() model_version = client.get_model_version( @@ -1336,6 +1339,55 @@ def test_get_model_version_found(self, warm_up_models): assert model_version.number == 1 assert model_version.description == self.VERSION_DESC + def test_get_model_version_by_id_found(self, warm_up_models): + with model_killer(): + client = Client() + mv = client.get_model_version(self.MODEL_NAME, self.VERSION_NAME) + + model_version = client.get_model_version(self.MODEL_NAME, mv.id) + + assert model_version.name == self.MODEL_NAME + assert model_version.version == self.VERSION_NAME + assert model_version.number == 1 + assert model_version.description == self.VERSION_DESC + + def test_get_model_version_by_index_found(self, warm_up_models): + with model_killer(): + client = Client() + model_version = client.get_model_version(self.MODEL_NAME, 1) + + assert model_version.name == self.MODEL_NAME + assert model_version.version == self.VERSION_NAME + assert model_version.number == 1 + assert model_version.description == self.VERSION_DESC + + def test_get_model_version_by_stage_found(self, warm_up_models): + with model_killer(): + client = Client() + + client.update_model_version( + model_name_or_id=self.MODEL_NAME, + version_name_or_id=self.VERSION_NAME, + stage=ModelStages.STAGING, + force=True, + ) + + model_version = client.get_model_version( + self.MODEL_NAME, ModelStages.STAGING + ) + + assert model_version.name == self.MODEL_NAME + assert model_version.version == self.VERSION_NAME + assert model_version.number == 1 + assert model_version.description == self.VERSION_DESC + + def test_get_model_version_by_stage_not_found(self, warm_up_models): + with model_killer(): + client = Client() + + with pytest.raises(KeyError): + client.get_model_version(self.MODEL_NAME, ModelStages.STAGING) + def test_get_model_version_not_found(self): with model_killer(): client = Client() @@ -1488,3 +1540,82 @@ def test_delete_model_version_not_found(self, warm_up_models): client.delete_model_version( self.MODEL_NAME, self.VERSION_NAME + "@" ) + + +def _create_some_model_version( + client: Client, + model_name: str = "aria_cat_supermodel", + model_version_name: str = "1.0.0", +) -> ModelVersionResponseModel: + model = client.create_model( + model=ModelRequestModel( + name=model_name, + user=client.active_user.id, + workspace=client.active_workspace.id, + ) + ) + return client.create_model_version( + ModelVersionRequestModel( + user=client.active_user.id, + workspace=client.active_workspace.id, + model=model.id, + name=model_version_name, + ) + ) + + +def test_get_by_latest(clean_client): + """Test that model version can be retrieved with latest.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) + + # latest returns the only model + mv2 = Client().get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + assert mv2 == mv1 + + # after second model version, latest should point to it + mv3 = cl.create_model_version(model_name_or_id=mv1.model.id, name="2.0.0") + mv4 = Client().get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + assert mv4 != mv1 + assert mv4 == mv3 + + +def test_get_by_stage(clean_client): + """Test that model version can be retrieved by stage.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) + + cl.update_model_version( + version_name_or_id=mv1.id, + model_name_or_id=mv1.model.id, + stage=ModelStages.STAGING, + force=True, + ) + + mv2 = cl.get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.STAGING, + ) + + assert mv1 == mv2 + + +def test_stage_not_found(clean_client): + """Test that attempting to get model version fails if none at the given stage.""" + + cl = Client() + mv1 = _create_some_model_version(client=cl) + + with pytest.raises(KeyError): + Client().get_model_version( + model_name_or_id=mv1.model.id, + model_version_name_or_number_or_id=ModelStages.STAGING, + ) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index d242342d9b9..d64b4b1177c 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -4045,12 +4045,11 @@ def test_create_no_model(self): def test_get_not_found(self): """Test that get fails if not found.""" - with ModelVersionContext() as model: + with ModelVersionContext(): zs = Client().zen_store with pytest.raises(KeyError): zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id="1.0.0", + model_version_id=uuid4(), ) def test_get_found(self): @@ -4065,10 +4064,12 @@ def test_get_found(self): name="great one", ) ) - mv2 = zs.get_model_version( + mv2 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id="great one", - ) + model_version_filter_model=ModelVersionFilterModel( + name="great one" + ), + ).items[0] assert mv1.id == mv2.id def test_list_empty(self): @@ -4111,19 +4112,18 @@ def test_list_not_empty(self): def test_delete_not_found(self): """Test that delete fails if not found.""" - with ModelVersionContext() as model: + with ModelVersionContext(): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="1.0.0", + model_version_id=uuid4(), ) def test_delete_found(self): """Test that delete works, if model version exists.""" with ModelVersionContext() as model: zs = Client().zen_store - zs.create_model_version( + mv = zs.create_model_version( ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, @@ -4132,14 +4132,15 @@ def test_delete_found(self): ) ) zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="great one", + model_version_id=mv.id, ) - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id="great one", - ) + mvl = zs.list_model_versions( + model_name_or_id=model.id, + model_version_filter_model=ModelVersionFilterModel( + name="great one" + ), + ).items + assert len(mvl) == 0 def test_update_not_found(self): """Test that update fails if not found.""" @@ -4183,15 +4184,19 @@ def test_update_not_forced(self): force=False, ), ) - mv2 = zs.get_model_version( + mv2 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id="staging", - ) + model_version_filter_model=ModelVersionFilterModel( + stage="staging" + ), + ).items[0] assert mv1.id == mv2.id - mv3 = zs.get_model_version( + mv3 = zs.list_model_versions( model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.STAGING, - ) + model_version_filter_model=ModelVersionFilterModel( + stage=ModelStages.STAGING + ), + ).items[0] assert mv1.id == mv3.id def test_in_stage_not_found(self): @@ -4207,21 +4212,14 @@ def test_in_stage_not_found(self): ) ) - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.STAGING, - ) + mvl = zs.list_model_versions( + model_name_or_id=model.id, + model_version_filter_model=ModelVersionFilterModel( + stage=ModelStages.STAGING + ), + ).items - def test_latest_not_found(self): - """Test that get latest fails if not found.""" - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.LATEST, - ) + assert len(mvl) == 0 def test_latest_found(self): """Test that get latest works, if model version exists.""" @@ -4244,8 +4242,8 @@ def test_latest_found(self): name="yet another one", ) ) - found_latest = zs.get_model_version( - model_name_or_id=model.id, + found_latest = Client().get_model_version( + model_name_or_id=model.id ) assert latest.id == found_latest.id @@ -4279,8 +4277,7 @@ def test_update_forced(self): ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "staging" ) @@ -4296,22 +4293,19 @@ def test_update_forced(self): assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "archived" ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, ).stage == "staging" ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv2.id, + model_version_id=mv2.id, ).name == "I changed that..." ) @@ -4327,26 +4321,19 @@ def test_update_public_interface(self): model=model.id, ) ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, - ).stage - is None - ) + + assert mv1.stage is None mv1.set_stage("staging") assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.name, + model_version_id=mv1.id, ).stage == "staging" ) assert ( zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_number_or_id=mv1.id, + model_version_id=mv1.id, ).name == "1" ) @@ -4413,10 +4400,10 @@ def test_get_found_by_number(self): """Test that get works by integer version number.""" with ModelVersionContext(create_version=True) as model_version: zs = Client().zen_store - found = zs.get_model_version( + found = zs.list_model_versions( model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=1, - ) + model_version_filter_model=ModelVersionFilterModel(number=1), + ).items[0] assert found.id == model_version.id assert found.number == 1 assert found.name == model_version.name @@ -4425,18 +4412,13 @@ def test_get_not_found_by_number(self): """Test that get fails by integer version number, if not found and by string version number, cause treated as name.""" with ModelVersionContext(create_version=True) as model_version: zs = Client().zen_store - # no version numbered as 2 - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=2, - ) - # cannot fetch by string number - treated as name - with pytest.raises(KeyError): - zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id="1", - ) + + found = zs.list_model_versions( + model_name_or_id=model_version.model.id, + model_version_filter_model=ModelVersionFilterModel(number=2), + ).items + + assert len(found) == 0 class TestModelVersionArtifactLinks: @@ -4554,8 +4536,7 @@ def test_link_create_overwrite_deleted(self): assert al1.link_version == 1 assert al1.artifact == artifacts[0].id zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id=al1.id, ) al2 = zs.create_model_version_artifact_link( @@ -4644,8 +4625,7 @@ def test_link_create_single_version_of_same_output_name_from_different_steps( ) links = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( pipeline_name="pipeline", name="output", @@ -4673,13 +4653,11 @@ def test_link_delete_found(self): ) ) zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id="link", ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4689,8 +4667,7 @@ def test_link_delete_not_found(self): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version_artifact_link( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_name_or_id="link", ) @@ -4698,8 +4675,7 @@ def test_link_list_empty(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4711,8 +4687,7 @@ def test_link_list_populated(self): ): zs = Client().zen_store mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == 0 @@ -4737,15 +4712,13 @@ def test_link_list_populated(self): ) ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(), ) assert len(mvls) == len(artifacts) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_data_artifacts=True ), @@ -4757,8 +4730,7 @@ def test_link_list_populated(self): ) mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_model_artifacts=True ), @@ -4766,8 +4738,7 @@ def test_link_list_populated(self): assert len(mvls) == 1 and mvls[0].name == "link2" mvls = zs.list_model_version_artifact_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel( only_endpoint_artifacts=True ), @@ -4775,8 +4746,7 @@ def test_link_list_populated(self): assert len(mvls) == 1 and mvls[0].name == "link3" mv = zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, ) assert len(mv.model_artifact_ids) == 1 @@ -4901,12 +4871,9 @@ def test_link_delete_found(self): pipeline_run=prs[0].id, ) ) - zs.delete_model_version_pipeline_run_link( - model_version.model.id, model_version.id, "link" - ) + zs.delete_model_version_pipeline_run_link(model_version.id, "link") mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4916,15 +4883,14 @@ def test_link_delete_not_found(self): zs = Client().zen_store with pytest.raises(KeyError): zs.delete_model_version_pipeline_run_link( - model_version.model.id, model_version.id, "link" + model_version.id, "link" ) def test_link_list_empty(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4936,8 +4902,7 @@ def test_link_list_populated(self): ): zs = Client().zen_store mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 0 @@ -4953,15 +4918,13 @@ def test_link_list_populated(self): ) ) mvls = zs.list_model_version_pipeline_run_links( - model_name_or_id=model_version.model.id, - model_version_name_or_id=model_version.id, + model_version_id=model_version.id, model_version_pipeline_run_link_filter_model=ModelVersionPipelineRunFilterModel(), ) assert len(mvls) == 2 mv = zs.get_model_version( - model_name_or_id=model_version.model.id, - model_version_name_or_number_or_id=model_version.id, + model_version_id=model_version.id, ) assert len(mv.pipeline_run_ids) == 2 diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index f63f9838fd3..e07745600b8 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -662,9 +662,7 @@ def __enter__(self): model = client.create_model(name=self.model) if self.create_version: try: - mv = client.zen_store.get_model_version( - self.model, self.model_version - ) + mv = client._get_model_version(self.model, self.model_version) except KeyError: mv = client.zen_store.create_model_version( ModelVersionRequestModel(