Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict Python Version Mismatch between Pickled Object and Remote Environment #2848

Merged
merged 8 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import re
from abc import ABC
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2
Expand Down Expand Up @@ -282,6 +283,32 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
default_task_resolver = DefaultTaskResolver()


@dataclass
class PickledEntityMetadata:
    """
    Metadata stored alongside pickled entities, carrying version information.

    Attributes:
        python_version: The Python version string (e.g. "3.12.0") of the interpreter
            that created the pickle. The notebook task resolver compares only the
            major and minor components against the running interpreter.
    """

    python_version: str


@dataclass
class PickledEntity:
    """
    Represents the structure of the pickled object stored in the .pkl file for interactive mode.

    Attributes:
        metadata: Metadata about the pickled entities, including the Python version
            that produced the pickle.
        entities: Dictionary mapping entity names to their PythonAutoContainerTask instances.
    """

    metadata: PickledEntityMetadata
    entities: Dict[str, PythonAutoContainerTask]


class DefaultNotebookTaskResolver(TrackedInstance, TaskResolverMixin):
"""
This resolver is used when the task is defined in a notebook. It is used to load the task from the notebook.
Expand All @@ -294,12 +321,41 @@ def name(self) -> str:
def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
    """
    Load a task from the gzip-compressed pickle file at PICKLE_FILE_PATH.

    The pickle is expected to hold a PickledEntity. Its recorded Python version must
    match the running interpreter on major and minor before any entity is returned.

    :param loader_args: Loader arguments; the second element is the entity name to look up.
    :return: The task stored under the requested entity name.
    :raises RuntimeError: If unpickling fails with a TypeError, if the payload is not a
        PickledEntity, if the recorded version cannot be parsed, or if the recorded
        major.minor version differs from the running interpreter.
    :raises ValueError: If the entity name is not present in the pickled entities.
    """
    _, entity_name, *_ = loader_args
    import gzip
    import sys

    import cloudpickle

    current_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    # NOTE(review): unpickling executes arbitrary code from the file; this presumably relies on
    # the pickle being the artifact uploaded at registration time - confirm the trust boundary.
    try:
        with gzip.open(PICKLE_FILE_PATH, "r") as f:
            loaded_data = cloudpickle.load(f)
    except TypeError as e:
        # Chain the original error so the underlying unpickling failure stays visible.
        raise RuntimeError(
            "The Python version is different from the version used to create the pickle file. "
            f"Current Python version: {current_version}. "
            "Please try using the same Python version to create the pickle file or use another "
            "container image with a matching version."
        ) from e

    # Verify the loaded data is of the correct type before touching its attributes.
    if not isinstance(loaded_data, PickledEntity):
        raise RuntimeError(
            "The loaded data is not of the correct type. Please ensure that the pickle file is not corrupted."
        )
    pickled_object: PickledEntity = loaded_data

    recorded_version = pickled_object.metadata.python_version
    try:
        # Only major and minor matter for compatibility; micro releases are ignored.
        pickled_major, pickled_minor = (int(part) for part in recorded_version.split(".")[:2])
    except (AttributeError, ValueError) as e:
        raise RuntimeError(
            f"The Python version recorded in the pickle file is malformed: {recorded_version!r}. "
            "Please ensure that the pickle file is not corrupted."
        ) from e

    if (sys.version_info.major, sys.version_info.minor) != (pickled_major, pickled_minor):
        raise RuntimeError(
            "The Python version used to create the pickle file is different from the current Python version. "
            f"Current Python version: {current_version}. "
            f"Python version used to create the pickle file: {recorded_version}. "
            "Please try using the same Python version to create the pickle file or use another "
            "container image with a matching version."
        )

    if entity_name not in pickled_object.entities:
        raise ValueError(f"Entity {entity_name} not found in the pickled object")
    return pickled_object.entities[entity_name]

def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore
n, _, _, _ = extract_task_module(task)
Expand Down
5 changes: 4 additions & 1 deletion flytekit/remote/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def outputs(self) -> Optional[LiteralsResolver]:
"Please wait until the execution has completed before requesting the outputs."
)
if self.error:
raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
raise user_exceptions.FlyteAssertion(
"Outputs could not be found because the execution ended in failure. Error message: "
f"{self.error.message}"
)

return self._outputs

Expand Down
19 changes: 14 additions & 5 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from flytekit.core.node import Node as CoreNode
from flytekit.core.python_auto_container import (
PICKLE_FILE_PATH,
PickledEntity,
PickledEntityMetadata,
PythonAutoContainerTask,
default_notebook_task_resolver,
)
Expand Down Expand Up @@ -202,14 +204,21 @@ def _get_git_repo_url(source_path: str):

def _get_pickled_target_dict(
root_entity: typing.Union[WorkflowBase, PythonTask],
) -> typing.Tuple[bytes, typing.Dict[str, PythonAutoContainerTask]]:
) -> typing.Tuple[bytes, PickledEntity]:
"""
Get the pickled target dictionary for the entity.
:param root_entity: The entity to get the pickled target for.
:return: hashed bytes and the PickledEntity holding the Python version metadata and the collected entities.
"""
import sys

queue: typing.List[typing.Union[WorkflowBase, PythonTask, CoreNode]] = [root_entity]
pickled_target_dict = {}
pickled_target_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
),
entities={},
)
while queue:
entity = queue.pop()
if isinstance(entity, PythonFunctionTask):
Expand All @@ -222,10 +231,10 @@ def _get_pickled_target_dict(
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity._run_task.name] = entity._run_task
pickled_target_dict.entities[entity._run_task.name] = entity._run_task
else:
entity.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity.name] = entity
pickled_target_dict.entities[entity.name] = entity
elif isinstance(entity, WorkflowBase):
for task in entity.nodes:
queue.append(task)
Expand Down Expand Up @@ -2643,7 +2652,7 @@ def download(
def _pickle_and_upload_entity(
self,
entity: typing.Union[PythonTask, WorkflowBase],
pickled_dict: typing.Optional[typing.Dict[str, PythonAutoContainerTask]] = None,
pickled_dict: typing.Optional[PickledEntity] = None,
) -> FastSerializationSettings:
"""
Pickle the entity to the specified location. This is useful for debugging and for sharing entities across
Expand Down
28 changes: 26 additions & 2 deletions tests/flytekit/unit/core/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import cloudpickle
import mock
import pytest
import sys

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PICKLE_FILE_PATH
from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PickledEntity, PickledEntityMetadata
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -123,10 +124,33 @@ def t1(a: str, b: str) -> str:

assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]

pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1}
pickled_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
),
entities={
"tests.flytekit.unit.core.test_resolver.t1": t1,
},
)

custom_pickled_object = cloudpickle.dumps(pickled_dict)
mock_gzip_open.return_value.read.return_value = custom_pickled_object
mock_cloudpickle.return_value = pickled_dict

t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
assert t == t1

mismatched_pickled_dict = PickledEntity(
metadata=PickledEntityMetadata(
python_version=f"{sys.version_info.major}.{sys.version_info.minor - 1}.{sys.version_info.micro}"
),
entities={
"tests.flytekit.unit.core.test_resolver.t1": t1,
},
)
mismatched_custom_pickled_object = cloudpickle.dumps(mismatched_pickled_dict)
mock_gzip_open.return_value.read.return_value = mismatched_custom_pickled_object
mock_cloudpickle.return_value = mismatched_pickled_dict

with pytest.raises(RuntimeError):
c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
25 changes: 17 additions & 8 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import shutil
import subprocess
import sys
import tempfile
import typing
import uuid
Expand Down Expand Up @@ -707,11 +708,15 @@ def w() -> int:
return t2(a=t1())

_, target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 2
assert t1.name in target_dict
assert t2.name in target_dict
assert target_dict[t1.name] == t1
assert target_dict[t2.name] == t2
assert (
target_dict.metadata.python_version
== f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
assert len(target_dict.entities) == 2
assert t1.name in target_dict.entities
assert t2.name in target_dict.entities
assert target_dict.entities[t1.name] == t1
assert target_dict.entities[t2.name] == t2

def test_get_pickled_target_dict_with_map_task():
@task
Expand All @@ -723,9 +728,13 @@ def w() -> int:
return map_task(partial(t1, y=2))(x=[1, 2, 3])

_, target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 1
assert t1.name in target_dict
assert target_dict[t1.name] == t1
assert (
target_dict.metadata.python_version
== f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
assert len(target_dict.entities) == 1
assert t1.name in target_dict.entities
assert target_dict.entities[t1.name] == t1

def test_get_pickled_target_dict_with_dynamic():
@task
Expand Down
Loading