Skip to content

Commit

Permalink
Register and launch single task execution (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jun 22, 2020
1 parent eec85fb commit 58fd9cd
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 14 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins

__version__ = '0.9.2'
__version__ = '0.9.3'
2 changes: 1 addition & 1 deletion flytekit/common/mixins/launchable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def launch(self, project, domain, inputs=None, name=None, notification_overrides
:rtype: T
"""
return self.execute_with_literals(
return self.launch_with_literals(
project,
domain,
self._python_std_input_map_to_literal_map(inputs or {}),
Expand Down
17 changes: 13 additions & 4 deletions flytekit/common/mixins/registerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,19 @@ def some_task()
m = _importlib.import_module(self.instantiated_in)

for k in dir(m):
if getattr(m, k) == self:
self._platform_valid_name = _utils.fqdn(m.__name__, k, entity_type=self.resource_type)
_logging.debug("Auto-assigning name to {}".format(self._platform_valid_name))
return
try:
if getattr(m, k) == self:
self._platform_valid_name = _utils.fqdn(m.__name__, k, entity_type=self.resource_type)
_logging.debug("Auto-assigning name to {}".format(self._platform_valid_name))
return
except ValueError as err:
# Empty pandas dataframes behave weirdly here: evaluating `getattr(m, k) == self` when the attribute is a
# dataframe raises:
# ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
# a.any() or a.all()
# Since dataframes aren't registrable entities to begin with, we swallow the ValueError they raise and
# continue looping through the attributes of m.
_logging.warning("Caught ValueError {} while attempting to auto-assign name".format(err))
pass

_logging.error("Could not auto-assign name")
raise _system_exceptions.FlyteSystemException("Error looking for object while auto-assigning name.")
39 changes: 33 additions & 6 deletions flytekit/common/tasks/presto_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import absolute_import

import six as _six


from google.protobuf.json_format import MessageToDict as _MessageToDict
from flytekit import __version__

Expand All @@ -14,8 +17,7 @@
from flytekit.common import interface as _interface
import datetime as _datetime
from flytekit.models import presto as _presto_models
from flytekit.common.exceptions.user import \
FlyteValueException as _FlyteValueException
from flytekit.common.types import helpers as _type_helpers
from flytekit.common.exceptions import scopes as _exception_scopes


Expand All @@ -37,6 +39,7 @@ def __init__(
discovery_version=None,
retries=1,
timeout=None,
deprecated=None
):
"""
:param Text statement: Presto query specification
Expand All @@ -49,6 +52,8 @@ def __init__(
:param Text discovery_version: String describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param datetime.timedelta timeout:
:param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
receive deprecation warnings.
"""

# Set as class fields which are used down below to configure implicit
Expand All @@ -67,7 +72,7 @@ def __init__(
_literals.RetryStrategy(retries),
interruptible,
discovery_version,
"This is deprecated!"
deprecated
)

presto_query = _presto_models.PrestoQuery(
Expand Down Expand Up @@ -112,16 +117,38 @@ def __init__(
# Set user provided inputs
task_inputs(self)

def _add_implicit_inputs(self, inputs):
    """
    Return the given inputs extended with this task's implicit Presto inputs.

    The routing group, catalog and schema held on this task are stored under the reserved
    ``__implicit_routing_group``, ``__implicit_catalog`` and ``__implicit_schema`` keys. The caller's
    dictionary is copied first, so the mapping passed in is never modified.

    :param dict[Text,Any] inputs: Python standard inputs supplied by the caller.
    :rtype: dict[Text,Any]
    """
    # Shallow copy: adding the implicit keys must not leak into a dictionary the caller still owns.
    with_implicit = dict(inputs)
    with_implicit["__implicit_routing_group"] = self.routing_group
    with_implicit["__implicit_catalog"] = self.catalog
    with_implicit["__implicit_schema"] = self.schema
    return with_implicit

# Override method in order to set the implicit inputs
def __call__(self, *args, **kwargs):
    """
    Call the parent implementation with the implicit Presto inputs added to the keyword arguments.

    :param args: Positional arguments, forwarded unchanged.
    :param kwargs: Keyword inputs; the implicit routing group, catalog and schema entries are added.
    :return: Whatever the parent class's ``__call__`` returns.
    """
    # _add_implicit_inputs is the single place that knows the implicit keys; setting them by hand here
    # as well would only duplicate that knowledge.
    kwargs = self._add_implicit_inputs(kwargs)

    return super(SdkPrestoTask, self).__call__(
        *args, **kwargs
    )

# Override method in order to set the implicit inputs
def _python_std_input_map_to_literal_map(self, inputs):
    """
    Pack Python standard inputs, together with this task's implicit inputs, into a LiteralMap.

    :param dict[Text,Any] inputs: A dictionary of Python standard inputs that will be type-checked and compiled
        to a LiteralMap
    :rtype: flytekit.models.literals.LiteralMap
    """
    all_inputs = self._add_implicit_inputs(inputs)

    # Resolve the SDK type for every input declared on the task interface.
    sdk_types_by_name = {}
    for input_name, variable in _six.iteritems(self.interface.inputs):
        sdk_types_by_name[input_name] = _type_helpers.get_sdk_type_from_literal_type(variable.type)

    return _type_helpers.pack_python_std_map_to_literal_map(all_inputs, sdk_types_by_name)

@_exception_scopes.system_entry_point
def add_inputs(self, inputs):
"""
Expand Down
62 changes: 61 additions & 1 deletion flytekit/common/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

import six as _six

from google.protobuf import json_format as _json_format, struct_pb2 as _struct

import hashlib as _hashlib
import json as _json

from flytekit.common import (
interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases, workflow_execution as _workflow_execution
)
Expand All @@ -14,7 +19,7 @@
from flytekit.engines import loader as _engine_loader
from flytekit.models import common as _common_model, task as _task_model
from flytekit.models.core import workflow as _workflow_model, identifier as _identifier_model
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import user as _user_exceptions, system as _system_exceptions
from flytekit.common.types import helpers as _type_helpers


Expand Down Expand Up @@ -268,6 +273,61 @@ def _python_std_input_map_to_literal_map(self, inputs):
for k, v in _six.iteritems(self.interface.inputs)
})

def _produce_deterministic_version(self, version=None):
    """
    Return the version under which this task should be registered.

    A caller-supplied version is returned unchanged. Otherwise the version is the MD5 hex digest of the
    task body: its type, serialized metadata, serialized interface and custom payload.

    :param Text version: Optional explicit version.
    :rtype: Text
    :raises ValueError: If the task has a container definition whose data config is None.
    """
    if self.container is not None and self.container.data_config is None:
        # Tasks that carry a container definition but no data config cannot be versioned client-side.
        # NOTE(review): the original comment implies raw container tasks are the only container tasks
        # allowed a client-side version -- presumably they always set a data config; confirm.
        raise ValueError("Client-side task versions are not supported for {} task type".format(self.type))

    if version is not None:
        return version

    if self.custom:
        # Round-trip through JSON with sorted keys before parsing into a protobuf Struct.
        custom = _json_format.Parse(_json.dumps(self.custom, sort_keys=True), _struct.Struct())
    else:
        custom = None

    # The task body is the whole task template except the identifier. The identifier is left out because
    # 1) this method computes the version portion of that identifier, and
    # 2) the SDK generates a unique name on every task instantiation, which would defeat the
    #    reproducibility this method is after.
    task_type = self.type
    metadata_bytes = self.metadata.to_flyte_idl().SerializeToString(deterministic=True)
    interface_bytes = self.interface.to_flyte_idl().SerializeToString(deterministic=True)
    task_body = (task_type, metadata_bytes, interface_bytes, custom)

    digest = _hashlib.md5(str(task_body).encode('utf-8'))
    return digest.hexdigest()

@_exception_scopes.system_entry_point
def register_and_launch(self, project, domain, name=None, version=None, inputs=None):
    """
    Register this task under the given project and domain, then launch a single execution of it.

    When no version is supplied, one is derived from the task definition. When no name is supplied
    (None or an empty string), the auto-assigned platform name is used; if auto-assignment fails, the
    version doubles as the name. If registration raises, the task's previous id is restored before the
    error propagates.

    :param Text project: The project in which to register and launch this task.
    :param Text domain: The domain in which to register and launch this task.
    :param Text name: The name to give this task.
    :param Text version: The version in which to register this task
    :param dict[Text, Any] inputs: A dictionary of Python standard inputs that will be type-checked, then compiled
        to a LiteralMap.
    :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution
    """
    self.validate()
    version = self._produce_deterministic_version(version)

    # Treat every falsy name (None or "") as "not supplied". Checking only for None let an empty string
    # skip name generation and then fail on an unbound local variable.
    if not name:
        try:
            self.auto_assign_name()
            name = self._platform_valid_name
        except _system_exceptions.FlyteSystemException:
            # If we're not able to assign a platform valid name, use the deterministically-produced version instead.
            name = version

    id_to_register = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version)
    old_id = self.id
    try:
        self._id = id_to_register
        _engine_loader.get_engine().get_task(self).register(id_to_register)
    except:  # noqa: E722 -- deliberate catch-all: roll the id back on any failure, then re-raise unchanged.
        self._id = old_id
        raise
    return self.launch(project, domain, inputs=inputs)

@_exception_scopes.system_entry_point
def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None,
label_overrides=None, annotation_overrides=None):
Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __str__(self):
return self.verbose_string()

def __hash__(self):
return hash(self.to_flyte_idl().SerializeToString())
return hash(self.to_flyte_idl().SerializeToString(deterministic=True))

def short_string(self):
"""
Expand Down
34 changes: 34 additions & 0 deletions tests/flytekit/unit/common_tests/tasks/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import python_task, inputs, outputs
from flyteidl.admin import task_pb2 as _admin_task_pb2
from flytekit.common.tasks.presto_task import SdkPrestoTask
from flytekit.sdk.types import Types


@_patch("flytekit.engines.loader.get_engine")
Expand Down Expand Up @@ -68,3 +70,35 @@ def test_task_serialization():
assert isinstance(s, _admin_task_pb2.TaskSpec)
assert s.template.id.name == 'tests.flytekit.unit.common_tests.tasks.test_task.my_task'
assert s.template.container.image == 'myflyteimage:v123'


# Two-column schema ("a": String, "b": Integer) passed as output_schema to the Presto tasks built in
# test_task_produce_deterministic_version.
schema = Types.Schema([("a", Types.String), ("b", Types.Integer)])


def test_task_produce_deterministic_version():
    """
    Identical Presto task definitions must hash to the same version, a changed statement must hash to a
    different one, and the sample task must be rejected.
    """
    base_statement = "SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 10"
    changed_statement = "SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 100000"

    def build_presto_task(statement):
        # Everything except the statement is held constant across the tasks under test.
        return SdkPrestoTask(
            task_inputs=inputs(ds=Types.String, rg=Types.String),
            statement=statement,
            output_schema=schema,
            routing_group="{{ .Inputs.rg }}",
        )

    containerless_task = build_presto_task(base_statement)
    identical_containerless_task = build_presto_task(base_statement)
    different_containerless_task = build_presto_task(changed_statement)

    same_left = containerless_task._produce_deterministic_version()
    same_right = identical_containerless_task._produce_deterministic_version()
    assert same_left == same_right

    diff_left = containerless_task._produce_deterministic_version()
    diff_right = different_containerless_task._produce_deterministic_version()
    assert diff_left != diff_right

    with _pytest.raises(Exception):
        get_sample_task()._produce_deterministic_version()

0 comments on commit 58fd9cd

Please sign in to comment.