Skip to content

Commit

Permalink
Allow artifact response as step input
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Oct 24, 2024
1 parent af525a9 commit b7012ed
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/zenml/pipelines/pipeline_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from zenml.config.source import Source
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.model.model import Model
from zenml.models import ArtifactVersionResponse
from zenml.types import HookSpecification

StepConfigurationUpdateOrDict = Union[
Expand Down Expand Up @@ -1080,7 +1081,9 @@ def add_step_invocation(
self,
step: "BaseStep",
input_artifacts: Dict[str, StepArtifact],
external_artifacts: Dict[str, "ExternalArtifact"],
external_artifacts: Dict[
str, Union["ExternalArtifact", "ArtifactVersionResponse"]
],
model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
client_lazy_loaders: Dict[str, "ClientLazyLoader"],
parameters: Dict[str, Any],
Expand Down
6 changes: 5 additions & 1 deletion src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
)
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.model.model import Model
from zenml.models import ArtifactVersionResponse
from zenml.types import HookSpecification

MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]]
Expand Down Expand Up @@ -307,7 +308,7 @@ def _parse_call_args(
self, *args: Any, **kwargs: Any
) -> Tuple[
Dict[str, "StepArtifact"],
Dict[str, "ExternalArtifact"],
Dict[str, Union["ExternalArtifact", "ArtifactVersionResponse"]],
Dict[str, "ModelVersionDataLazyLoader"],
Dict[str, "ClientLazyLoader"],
Dict[str, Any],
Expand All @@ -328,6 +329,7 @@ def _parse_call_args(
from zenml.artifacts.external_artifact import ExternalArtifact
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.models.v2.core.artifact_version import (
ArtifactVersionResponse,
LazyArtifactVersionResponse,
)
from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
Expand Down Expand Up @@ -378,6 +380,8 @@ def _parse_call_args(
artifact_version=value.lazy_load_version,
metadata_name=None,
)
elif isinstance(value, ArtifactVersionResponse):
external_artifacts[key] = value
elif isinstance(value, LazyRunMetadataResponse):
model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
model_name=value.lazy_load_model_name,
Expand Down
19 changes: 14 additions & 5 deletions src/zenml/steps/step_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# permissions and limitations under the License.
"""Step invocation class definition."""

from typing import TYPE_CHECKING, Any, Dict, Set
from typing import TYPE_CHECKING, Any, Dict, Set, Union

from zenml.models import ArtifactVersionResponse

if TYPE_CHECKING:
from zenml.artifacts.external_artifact import ExternalArtifact
Expand All @@ -33,7 +35,9 @@ def __init__(
id: str,
step: "BaseStep",
input_artifacts: Dict[str, "StepArtifact"],
external_artifacts: Dict[str, "ExternalArtifact"],
external_artifacts: Dict[
str, Union["ExternalArtifact", "ArtifactVersionResponse"]
],
model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
client_lazy_loaders: Dict[str, "ClientLazyLoader"],
parameters: Dict[str, Any],
Expand Down Expand Up @@ -101,9 +105,14 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration":

external_artifacts: Dict[str, ExternalArtifactConfiguration] = {}
for key, artifact in self.external_artifacts.items():
if artifact.value is not None:
artifact.upload_by_value()
external_artifacts[key] = artifact.config
if isinstance(artifact, ArtifactVersionResponse):
external_artifacts[key] = ExternalArtifactConfiguration(
id=artifact.id
)
else:
if artifact.value is not None:
artifact.upload_by_value()
external_artifacts[key] = artifact.config

return self.step._finalize_configuration(
input_artifacts=self.input_artifacts,
Expand Down

0 comments on commit b7012ed

Please sign in to comment.