From b7012ede53f1d268b4b791803428158c805e6c03 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 24 Oct 2024 11:03:50 +0200 Subject: [PATCH] Allow artifact response as step input --- src/zenml/pipelines/pipeline_definition.py | 5 ++++- src/zenml/steps/base_step.py | 6 +++++- src/zenml/steps/step_invocation.py | 19 ++++++++++++++----- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 74b706e91f9..e72a4b78668 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -99,6 +99,7 @@ from zenml.config.source import Source from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model + from zenml.models import ArtifactVersionResponse from zenml.types import HookSpecification StepConfigurationUpdateOrDict = Union[ @@ -1080,7 +1081,9 @@ def add_step_invocation( self, step: "BaseStep", input_artifacts: Dict[str, StepArtifact], - external_artifacts: Dict[str, "ExternalArtifact"], + external_artifacts: Dict[ + str, Union["ExternalArtifact", "ArtifactVersionResponse"] + ], model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], client_lazy_loaders: Dict[str, "ClientLazyLoader"], parameters: Dict[str, Any], diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 2a3d324608c..6348a486570 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -78,6 +78,7 @@ ) from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model + from zenml.models import ArtifactVersionResponse from zenml.types import HookSpecification MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]] @@ -307,7 +308,7 @@ def _parse_call_args( self, *args: Any, **kwargs: Any ) -> Tuple[ Dict[str, "StepArtifact"], - Dict[str, "ExternalArtifact"], + Dict[str, Union["ExternalArtifact", "ArtifactVersionResponse"]], Dict[str, "ModelVersionDataLazyLoader"], Dict[str, "ClientLazyLoader"], Dict[str, Any], @@ -328,6 +329,7 @@ def _parse_call_args( from zenml.artifacts.external_artifact import ExternalArtifact from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.models.v2.core.artifact_version import ( + ArtifactVersionResponse, LazyArtifactVersionResponse, ) from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse @@ -378,6 +380,8 @@ def _parse_call_args( artifact_version=value.lazy_load_version, metadata_name=None, ) + elif isinstance(value, ArtifactVersionResponse): + external_artifacts[key] = value elif isinstance(value, LazyRunMetadataResponse): model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( model_name=value.lazy_load_model_name, diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 4124caa1a9c..17341d40845 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -13,7 +13,9 @@ # permissions and limitations under the License. """Step invocation class definition.""" -from typing import TYPE_CHECKING, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, Set, Union + +from zenml.models import ArtifactVersionResponse if TYPE_CHECKING: from zenml.artifacts.external_artifact import ExternalArtifact @@ -33,7 +35,9 @@ def __init__( id: str, step: "BaseStep", input_artifacts: Dict[str, "StepArtifact"], - external_artifacts: Dict[str, "ExternalArtifact"], + external_artifacts: Dict[ + str, Union["ExternalArtifact", "ArtifactVersionResponse"] + ], model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], client_lazy_loaders: Dict[str, "ClientLazyLoader"], parameters: Dict[str, Any], @@ -101,9 +105,14 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": external_artifacts: Dict[str, ExternalArtifactConfiguration] = {} for key, artifact in self.external_artifacts.items(): - if artifact.value is not None: - artifact.upload_by_value() - external_artifacts[key] = artifact.config + if isinstance(artifact, ArtifactVersionResponse): + external_artifacts[key] = ExternalArtifactConfiguration( + id=artifact.id + ) + else: + if artifact.value is not None: + artifact.upload_by_value() + external_artifacts[key] = artifact.config return self.step._finalize_configuration( input_artifacts=self.input_artifacts,