diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 74b706e91f9..e72a4b78668 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -99,6 +99,7 @@ from zenml.config.source import Source from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model + from zenml.models import ArtifactVersionResponse from zenml.types import HookSpecification StepConfigurationUpdateOrDict = Union[ @@ -1080,7 +1081,9 @@ def add_step_invocation( self, step: "BaseStep", input_artifacts: Dict[str, StepArtifact], - external_artifacts: Dict[str, "ExternalArtifact"], + external_artifacts: Dict[ + str, Union["ExternalArtifact", "ArtifactVersionResponse"] + ], model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], client_lazy_loaders: Dict[str, "ClientLazyLoader"], parameters: Dict[str, Any], diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 2a3d324608c..982b16e2529 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -78,6 +78,7 @@ ) from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model + from zenml.models import ArtifactVersionResponse from zenml.types import HookSpecification MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]] @@ -307,7 +308,7 @@ def _parse_call_args( self, *args: Any, **kwargs: Any ) -> Tuple[ Dict[str, "StepArtifact"], - Dict[str, "ExternalArtifact"], + Dict[str, Union["ExternalArtifact", "ArtifactVersionResponse"]], Dict[str, "ModelVersionDataLazyLoader"], Dict[str, "ClientLazyLoader"], Dict[str, Any], @@ -328,6 +329,7 @@ def _parse_call_args( from zenml.artifacts.external_artifact import ExternalArtifact from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.models.v2.core.artifact_version import ( + ArtifactVersionResponse, LazyArtifactVersionResponse, ) from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse @@ -342,7 +344,9 @@ def _parse_call_args( ) from e artifacts = {} - external_artifacts = {} + external_artifacts: Dict[ + str, Union["ExternalArtifact", "ArtifactVersionResponse"] + ] = {} model_artifacts_or_metadata = {} client_lazy_loaders = {} parameters = {} @@ -378,6 +382,8 @@ def _parse_call_args( artifact_version=value.lazy_load_version, metadata_name=None, ) + elif isinstance(value, ArtifactVersionResponse): + external_artifacts[key] = value elif isinstance(value, LazyRunMetadataResponse): model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( model_name=value.lazy_load_model_name, diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 4124caa1a9c..17341d40845 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -13,7 +13,9 @@ # permissions and limitations under the License. """Step invocation class definition.""" -from typing import TYPE_CHECKING, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, Set, Union + +from zenml.models import ArtifactVersionResponse if TYPE_CHECKING: from zenml.artifacts.external_artifact import ExternalArtifact @@ -33,7 +35,9 @@ def __init__( id: str, step: "BaseStep", input_artifacts: Dict[str, "StepArtifact"], - external_artifacts: Dict[str, "ExternalArtifact"], + external_artifacts: Dict[ + str, Union["ExternalArtifact", "ArtifactVersionResponse"] + ], model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], client_lazy_loaders: Dict[str, "ClientLazyLoader"], parameters: Dict[str, Any], @@ -101,9 +105,14 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": external_artifacts: Dict[str, ExternalArtifactConfiguration] = {} for key, artifact in self.external_artifacts.items(): - if artifact.value is not None: - artifact.upload_by_value() - external_artifacts[key] = artifact.config + if isinstance(artifact, ArtifactVersionResponse): + external_artifacts[key] = ExternalArtifactConfiguration( + id=artifact.id + ) + else: + if artifact.value is not None: + artifact.upload_by_value() + external_artifacts[key] = artifact.config return self.step._finalize_configuration( input_artifacts=self.input_artifacts, diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py index 4998eee4dbb..5a5d9fcf9c3 100644 --- a/tests/unit/steps/test_base_step.py +++ b/tests/unit/steps/test_base_step.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. import sys from contextlib import ExitStack as does_not_raise -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pytest from pydantic import BaseModel @@ -790,7 +790,12 @@ def test_configure_pipeline_with_hooks(one_step_pipeline): @step -def step_with_int_input(input_: int) -> int: +def step_with_int_input( + input_: int, expected_value: Optional[int] = None +) -> int: + if expected_value is not None: + assert input_ == expected_value + return input_ @@ -1038,3 +1043,18 @@ def test_pipeline(): with does_not_raise(): test_pipeline() + + +def test_artifact_version_as_step_input(clean_client): + """Test passing an artifact version as step input.""" + from zenml import save_artifact + + artifact_value = 3 + artifact = save_artifact(artifact_value, name="test") + + @pipeline + def test_pipeline(): + step_with_int_input(input_=artifact, expected_value=artifact_value) + + with does_not_raise(): + test_pipeline()