diff --git a/model/fn-execution/src/main/proto/beam_task_worker.proto b/model/fn-execution/src/main/proto/beam_task_worker.proto new file mode 100644 index 000000000000..c3ab5c7a6fc4 --- /dev/null +++ b/model/fn-execution/src/main/proto/beam_task_worker.proto @@ -0,0 +1,137 @@ +/* + * Protocol Buffers describing a custom bundle processer used by Task Worker. + */ + +syntax = "proto3"; + +package org.apache.beam.model.fn_execution.v1; + +option go_package = "fnexecution_v1"; +option java_package = "org.apache.beam.model.fnexecution.v1"; +//option java_outer_classname = "JobApi"; + +import "beam_fn_api.proto"; +import "endpoints.proto"; +import "google/protobuf/struct.proto"; +import "metrics.proto"; + + +// +// Control Plane +// + +// An API that describes the work that a bundle processor task worker is meant to do. +service TaskControl { + + // Instructions sent by the SDK to the task worker requesting different types + // of work. + rpc Control ( + // A stream of responses to instructions the task worker was asked to be + // performed. + stream TaskInstructionResponse) + returns ( + // A stream of instructions requested of the task worker to be performed. + stream TaskInstructionRequest); +} + +// A request sent by SDK which the task worker is asked to fulfill. +message TaskInstructionRequest { + // (Required) An unique identifier provided by the SDK which represents + // this requests execution. The InstructionResponse MUST have the matching id. + string instruction_id = 1; + + // (Required) A request that the task worker needs to interpret. + oneof request { + CreateRequest create = 1000; + ProcessorProcessBundleRequest process_bundle = 1001; + ShutdownRequest shutdown = 1002; + } +} + +// The response for an associated request the task worker had been asked to fulfill. +message TaskInstructionResponse { + + // (Required) A reference provided by the SDK which represents a requests + // execution. The InstructionResponse MUST have the matching id when + // responding to the SDK. + string instruction_id = 1; + + // An equivalent response type depending on the request this matches. + oneof response { + CreateResponse create = 1000; + ProcessorProcessBundleResponse process_bundle = 1001; + ShutdownResponse shutdown = 1002; + } + + // (Optional) If there's error processing request + string error = 2; +} + +message ChannelCredentials { + string _credentials = 1; +} + +message GrpcClientDataChannelFactory { + ChannelCredentials credentials = 1; + string worker_id = 2; + string transmitter_url = 3; +} + +message CreateRequest { + org.apache.beam.model.fn_execution.v1.ProcessBundleDescriptor process_bundle_descriptor = 1; // (required) + org.apache.beam.model.pipeline.v1.ApiServiceDescriptor state_handler_endpoint = 2; // (required) + GrpcClientDataChannelFactory data_factory = 3; // (required) +} + +message CreateResponse { +} + +message ProcessorProcessBundleRequest { + // (Optional) The cache token that can be used by an SDK to reuse + // cached data returned by the State API across multiple bundles. + string cache_token = 1; +} + +message ProcessorProcessBundleResponse { + repeated org.apache.beam.model.fn_execution.v1.DelayedBundleApplication delayed_applications = 1; + bool require_finalization = 2; +} + +message ShutdownRequest { +} + +message ShutdownResponse { +} + + +// +// Data Plane +// + +service TaskFnData { + // Handles data transferring between TaskWorkerHandler and Task Worker. + rpc Receive( + ReceiveRequest) + returns ( + // A stream of data representing output. 
+ stream Elements.Data); + + + // Used to send data from proxy bundle processor to sdk harness + rpc Send (SendRequest) returns (SendResponse); +} + +message ReceiveRequest { + string instruction_id = 1; + string client_data_endpoint = 2; +} + +message SendRequest { + string instruction_id = 1; + string client_data_endpoint = 2; + Elements.Data data = 3; +} + +message SendResponse { + string error = 1; +} diff --git a/sdks/python/apache_beam/examples/taskworker/k8s.py b/sdks/python/apache_beam/examples/taskworker/k8s.py new file mode 100644 index 000000000000..749c1915a7e2 --- /dev/null +++ b/sdks/python/apache_beam/examples/taskworker/k8s.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Sample pipeline showcasing how to use the kubernetes task worker. +""" + +from __future__ import absolute_import + +import argparse +import logging +from typing import TYPE_CHECKING + +import apache_beam as beam +from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.task.task_worker.kubejob.transforms import KubeTask + +if TYPE_CHECKING: + from typing import Any + from apache_beam.task.task_worker.kubejob.handler import KubePayload + + +def process(element): + # type: (int) -> int + import time + start = time.time() + print('processing...') + # represents an expensive process of some sort + time.sleep(element) + print('processing took {:0.6f} s'.format(time.time() - start)) + return element + + +def _per_element_wrapper(element, payload): + # type: (Any, KubePayload) -> KubePayload + """ + Callback to modify the kubernetes job properties per-element. 
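+
+  For example, with ``element=42`` and an incoming job name of
+  ``beam-kubetask`` (illustrative values only), the submitted Kubernetes job
+  would be named ``beam-kubetask-42``.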
+ """ + payload.job_name += '-{}'.format(element) + return payload + + +def run(argv=None, save_main_session=True): + """ + Run a pipeline that submits two kubernetes jobs, each that simply sleep + """ + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) + + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + + with beam.Pipeline(options=pipeline_options) as pipe: + ( + pipe + | beam.Create([20, 42]) + | 'kubetask' >> KubeTask( + beam.Map(process), + wrapper=_per_element_wrapper) + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/examples/taskworker/local.py b/sdks/python/apache_beam/examples/taskworker/local.py new file mode 100644 index 000000000000..4c33c852d7ca --- /dev/null +++ b/sdks/python/apache_beam/examples/taskworker/local.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Basic pipeline to test using TaskWorkers. +""" + +from __future__ import absolute_import + +import argparse +import logging + +import apache_beam as beam +from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions + +from apache_beam.runners.worker.task_worker.core import BeamTask + + +class TestFn(beam.DoFn): + + def process(self, element, side): + from apache_beam.runners.worker.task_worker.handlers import TaskableValue + + for s in side: + if isinstance(element, TaskableValue): + value = element.value + else: + value = element + print(value + s) + yield value + s + + +def run(argv=None, save_main_session=True): + """ + Run a pipeline that executes each element using the local task worker. + """ + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) + + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). 
+ pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + + with beam.Pipeline(options=pipeline_options) as pipe: + + A = ( + pipe + | 'A' >> beam.Create(range(3)) + ) + + B = ( + pipe + | beam.Create(range(2)) + | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), + wrapper=lambda x, _: x) + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 3e600e2d5b95..6c321642f3ab 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -823,7 +823,8 @@ class BundleProcessor(object): def __init__(self, process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory # type: data_plane.DataChannelFactory + data_channel_factory, # type: data_plane.DataChannelFactory + use_task_worker=True # type: bool ): # type: (...) -> None @@ -834,10 +835,12 @@ def __init__(self, a description of the stage that this ``BundleProcessor``is to execute. state_handler (CachingStateHandler). data_channel_factory (``data_plane.DataChannelFactory``). + use_task_worker : whether to engage the task worker system when processing """ self.process_bundle_descriptor = process_bundle_descriptor self.state_handler = state_handler self.data_channel_factory = data_channel_factory + self.use_task_worker = use_task_worker # There is no guarantee that the runner only set # timer_api_service_descriptor when having timers. So this field cannot be @@ -980,20 +983,9 @@ def process_bundle(self, instruction_id): self.ops[transform_id].add_timer_info(timer_family_id, timer_info) # Process data and timer inputs - for data_channel, expected_inputs in data_channels.items(): - for element in data_channel.input_elements(instruction_id, - expected_inputs): - if isinstance(element, beam_fn_api_pb2.Elements.Timers): - timer_coder_impl = ( - self.timers_info[( - element.transform_id, - element.timer_family_id)].timer_coder_impl) - for timer_data in timer_coder_impl.decode_all(element.timers): - self.ops[element.transform_id].process_timer( - element.timer_family_id, timer_data) - elif isinstance(element, beam_fn_api_pb2.Elements.Data): - input_op_by_transform_id[element.transform_id].process_encoded( - element.data) + delayed_applications, requires_finalization = \ + self.maybe_process_remotely(data_channels, instruction_id, + input_op_by_transform_id) # Finish all operations. for op in self.ops.values(): @@ -1005,11 +997,15 @@ def process_bundle(self, instruction_id): assert timer_info.output_stream is not None timer_info.output_stream.close() - return ([ - self.delayed_bundle_application(op, residual) for op, - residual in execution_context.delayed_applications - ], - self.requires_finalization()) + if requires_finalization is None: + requires_finalization = self.requires_finalization() + + if delayed_applications is None: + delayed_applications = [self.delayed_bundle_application(op, residual) + for op, residual in + execution_context.delayed_applications] + + return delayed_applications, requires_finalization finally: # Ensure any in-flight split attempts complete. 
@@ -1017,6 +1013,80 @@ def process_bundle(self, instruction_id): pass self.state_sampler.stop_if_still_running() + def maybe_process_remotely(self, + data_channels, # type: DefaultDict[DataChannel, list[str]] + instruction_id, # type: str + input_op_by_transform_id, # type: Dict[str, DataInputOperation] + ): + # type: (...) -> Union[Tuple[None, None], Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]] + """Process the current bundle remotely with task workers, if applicable. + + Processes remotely if there are TaskableValues detected from the input of this bundle + and task workers are allowed to be used. + """ + from apache_beam.runners.worker.task_worker.handlers import BundleProcessorTaskHelper + from apache_beam.runners.worker.task_worker.handlers import get_taskable_value + + wrapped_values = collections.defaultdict(list) # type: DefaultDict[str, List[Tuple[Any, bytes]]] + + for data_channel, expected_transforms in data_channels.items(): + for data in data_channel.input_elements( + instruction_id, expected_transforms): + if isinstance(data, beam_fn_api_pb2.Elements.Timers): + timer_coder_impl = ( + self.timers_info[( + data.transform_id, + data.timer_family_id)].timer_coder_impl) + for timer_data in timer_coder_impl.decode_all(data.timers): + self.ops[data.transform_id].process_timer( + data.timer_family_id, timer_data) + elif isinstance(data, beam_fn_api_pb2.Elements.Data): + input_op = input_op_by_transform_id[data.transform_id] + + # process normally if not using task worker + if self.use_task_worker is False: + input_op.process_encoded(data.data) + continue + + # decode inputs to inspect if it is wrapped + input_stream = coder_impl.create_InputStream(data.data) + # TODO: Come up with a better solution here? + # here we maintain two separate input stream, because `.pos` is not + # accessible on the cython version of InputStream object, so we can't + # re-use the same input stream object and edit the pos to move the + # read handle back + raw_input_stream = coder_impl.create_InputStream(data.data) + + while input_stream.size() > 0: + starting_size = input_stream.size() + decoded_value = input_op.windowed_coder_impl.decode_from_stream( + input_stream, True) + cur_size = input_stream.size() + raw_bytes = raw_input_stream.read(starting_size - cur_size) + # make sure these two stream stays in sync in size + assert raw_input_stream.size() == input_stream.size() + if decoded_value.value and get_taskable_value(decoded_value.value): + # save this to process later in ``process_bundle_with_task_workers`` + wrapped_values[data.transform_id].append( + (get_taskable_value(decoded_value.value), raw_bytes)) + else: + # fallback to process it as normal, trigger receivers to process + with input_op.splitting_lock: + if input_op.index == input_op.stop - 1: + return None, None + input_op.index += 1 + input_op.output(decoded_value) + + if wrapped_values: + task_helper = BundleProcessorTaskHelper(instruction_id, wrapped_values) + return task_helper.process_bundle_with_task_workers( + self.state_handler, + self.data_channel_factory, + self.process_bundle_descriptor + ) + else: + return None, None + def finalize_bundle(self): # type: () -> beam_fn_api_pb2.FinalizeBundleResponse for op in self.ops.values(): diff --git a/sdks/python/apache_beam/runners/worker/task_worker/__init__.py b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py @@ -0,0 +1,18 @@ 
+# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/runners/worker/task_worker/core.py b/sdks/python/apache_beam/runners/worker/task_worker/core.py new file mode 100644 index 000000000000..f5346e23ff8c --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/core.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +TaskWorker core user facing PTransforms that allows packaging elements as +TaskableValue for beam task workers. +""" + +import copy +from typing import TYPE_CHECKING + +import apache_beam as beam +from apache_beam.runners.worker.task_worker.handlers import TaskableValue + +if TYPE_CHECKING: + from typing import Any, Optional, Callable, Iterator + from apache_beam.options.pipeline_options import PipelineOptions + from apache_beam.transforms.environments import Environment + + +TASK_WORKER_ENV_TYPES = {'Docker', 'Process'} +TASK_WORKER_SDK_ENTRYPOINT = 'apache_beam.runners.worker.task_worker.task_worker_main' + + +class WrapFn(beam.DoFn): + """ + Wraps the given element into a TaskableValue if there's non-empty task + payload. User can pass in wrapper callable to modify payload per element. + """ + + def process( + self, + element, # type: Any + urn='local', # type: str + wrapper=None, # type: Optional[Callable[[Any, Any], Any]] + env=None, # type: Optional[Environment] + payload=None # type: Optional[Any] + ): + # type: (...) -> Iterator[Any] + """ + Args: + element: any ptransform element + urn: id of the task worker handler + wrapper: optional callable which can be used to modify a payload + per-element. 
+ env : Environment for the task to be run in + payload : Payload containing settings for the task worker handler + """ + # override payload if given a wrapper function, which will vary per + # element + if wrapper: + payload = wrapper(element, copy.deepcopy(payload)) + + if payload: + result = TaskableValue(element, urn, env=env, payload=payload) + else: + result = element + yield result + + +class UnWrapFn(beam.DoFn): + """ + Unwraps the TaskableValue into its original value, so that when constructing + transforms user doesn't need to worry about the element type if it is + taskable or not. + """ + + def process(self, element): + + if isinstance(element, TaskableValue): + yield element.value + else: + yield element + + +class BeamTask(beam.PTransform): + """ + Utility transform that wraps a group of transforms, and makes it a Beam + "Task" that can be delegated to a task worker to run remotely. + + The main structure is like this: + + ( pipe + | Wrap + | Reshuffle + | UnWrap + | User Transform1 + | ... + | User TransformN + | Reshuffle + ) + + The use of reshuffle is to make sure stage fusing doesn't try to fuse the + section we want to run with the inputs of this xform; reason being we need + the start of a stage to get data inputs that are *TaskableValue*, so that + the bundle processor will recognize that and will engage Task Workers. + + We end with a Reshuffle for similar reason, so that the next section of the + pipeline doesn't gets fused with the transforms provided, which would end up + being executed remotely in a remote task worker. + + By default, we use the local task worker, but subclass could specify the + type of task worker to use by specifying the ``urn``, and override the + ``get_payload`` method to return meaningful payloads to that type of task + worker. + """ + + # the urn for the registered task worker handler, default to use local task + # worker + urn = 'local' # type: str + + # the sdk harness entry point + SDK_HARNESS_ENTRY_POINT = TASK_WORKER_SDK_ENTRYPOINT + + def __init__(self, transform, wrapper=None, env=None): + # type: (beam.PTransform, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment]) -> None + self._wrapper = wrapper + self._env = env + self._transform = transform + + def get_payload(self, options): + # type: (PipelineOptions) -> Optional[Any] + """ + Subclass should implement this to generate payload for TaskableValue. + Default to None. + """ + return None + + @staticmethod + def _has_tagged_outputs(xform): + # type: (beam.PTransform) -> bool + """Checks to see if we have tagged output for the given PTransform.""" + if isinstance(xform, beam.core._MultiParDo): + return True + elif isinstance(xform, beam.ptransform._ChainedPTransform) \ + and isinstance(xform._parts[-1], beam.core._MultiParDo): + return True + return False + + def expand(self, pcoll): + # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection + payload = self.get_payload(pcoll.pipeline.options) + result = ( + pcoll + | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, + env=self._env, payload=payload) + | 'StartStage' >> beam.Reshuffle() + | 'UnWrap' >> beam.ParDo(UnWrapFn()) + | self._transform + ) + if self._has_tagged_outputs(self._transform): + # for xforms that ended up with tagged outputs, we don't want to + # add reshuffle, because it will be a stage split point already, + # also adding reshuffle would error since we now have a tuple of + # pcollections. 
+ return result + return result | 'EndStage' >> beam.Reshuffle() diff --git a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py new file mode 100644 index 000000000000..57a62ea542a6 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py @@ -0,0 +1,1255 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import +from __future__ import division + +import collections +import copy +import logging +import queue +import os +import sys +import threading +from builtins import object +from concurrent import futures +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union + +import grpc +from future.utils import raise_ + +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import beam_provision_api_pb2 +from apache_beam.portability.api import beam_provision_api_pb2_grpc +from apache_beam.portability.api import beam_task_worker_pb2 +from apache_beam.portability.api import beam_task_worker_pb2_grpc +from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.portability.fn_api_runner.worker_handlers import BasicProvisionService +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcWorkerHandler +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcStateServicer +from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler +from apache_beam.runners.worker.bundle_processor import BundleProcessor +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory +from apache_beam.runners.worker.data_plane import _GrpcDataChannel +from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream +from apache_beam.runners.worker.data_plane import DataChannelFactory +from apache_beam.runners.worker.statecache import StateCache +from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor + +if TYPE_CHECKING: + from apache_beam.runners.portability.fn_api_runner.fn_runner import ExtendedProvisionInfo + from apache_beam.runners.worker.data_plane import DataChannelFactory + from apache_beam.runners.worker.sdk_worker import 
CachingStateHandler + from apache_beam.transforms.environments import Environment + +ENTRY_POINT_NAME = 'apache_beam_task_workers_plugins' +MAX_TASK_WORKERS = 300 +MAX_TASK_WORKER_RETRY = 10 + +Taskable = Union[ + 'TaskableValue', + List['TaskableValue'], + Tuple['TaskableValue', ...], + Set['TaskableValue']] + + +class TaskableValue(object): + """ + Value that can be distributed to TaskWorkers as tasks. + + Has the original value, urn, env and payload that specifies how the task will + be generated. urn is used for getting the correct type of TaskWorkerHandler to + handle processing the current value. + """ + + def __init__( + self, + value, # type: Any + urn, # type: str + env=None, # type: Optional[Environment] + payload=None # type: Optional[Any] + ): + # type: (...) -> None + """ + Args: + value: The wrapped element + urn: id of the task worker handler + env : Environment for the task to be run in + payload : Payload containing settings for the task worker handler + """ + self.value = value + self.urn = urn + self.env = env + self.payload = payload + + +class TaskWorkerProcessBundleError(Exception): + """ + Error thrown when TaskWorker fails to process_bundle. + + Errors encountered when task worker is processing bundle can be retried, up + till max retries defined by ``MAX_TASK_WORKER_RETRY``. + """ + + +class TaskWorkerTerminatedError(Exception): + """ + Error thrown when TaskWorker terminated before it finished working. + + Custom TaskWorkerHandlers can choose to terminate a task but not + affect the whole bundle by setting ``TaskWorkerHandler.alive`` to False, + which will cause this error to be thrown. + """ + + +class TaskWorkerHandler(GrpcWorkerHandler): + """ + Abstract base class for TaskWorkerHandler for a task worker, + + A TaskWorkerHandler is created for each TaskableValue. + + Subclasses must override ``start_remote`` to modify how remote task worker is + started, and register task properties type when defining a subclass. + """ + + _known_urns = {} # type: Dict[str, Type[TaskWorkerHandler]] + + def __init__( + self, + state, # type: CachingStateHandler + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + environment, # type: Environment + task_payload, # type: Any + credentials=None, # type: Optional[str] + worker_id=None # type: Optional[str] + ): + # type: (...) 
-> None + self._grpc_server = grpc_server + + # we are manually doing init instead of calling GrpcWorkerHandler's init + # because we want to override worker_id and we don't want extra + # ControlConnection to be established + WorkerHandler.__init__(self, grpc_server.control_handler, + grpc_server.data_plane_handler, state, + provision_info) + # override worker_id if provided + if worker_id: + self.worker_id = worker_id + + self.control_address = self._grpc_server.control_address + self.provision_address = self.control_address + self.logging_address = self.port_from_worker( + self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) + # if we are running from subprocess environment, the aritfact staging + # endpoint is the same as control end point, and the env var won't be + # recorded so use control end point and record that to artifact + try: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) + except KeyError: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('CONTROL_API_SERVICE_DESCRIPTOR')) + os.environ['ARTIFACT_API_SERVICE_DESCRIPTOR'] = os.environ[ + 'CONTROL_API_SERVICE_DESCRIPTOR'] + + # modify provision info + modified_provision = copy.copy(provision_info) + modified_provision.control_endpoint.url = self.control_address + modified_provision.logging_endpoint.url = self.logging_address + modified_provision.artifact_endpoint.url = self.artifact_address + with TaskProvisionServicer._lock: + self._grpc_server.provision_handler.provision_by_worker_id[ + self.worker_id] = modified_provision + + self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( + self.worker_id) + + self.environment = environment + self.task_payload = task_payload + self.credentials = credentials + self.alive = True + + @staticmethod + def host_from_worker(): + # type: () -> str + import socket + return socket.getfqdn() + + @staticmethod + def get_port_from_env_var(env_var): + # type: (str) -> str + """Extract the service port for a given environment variable.""" + from google.protobuf import text_format + endpoint = endpoints_pb2.ApiServiceDescriptor() + text_format.Merge(os.environ[env_var], endpoint) + return endpoint.url.split(':')[-1] + + @staticmethod + def load_plugins(): + # type: () -> None + import entrypoints + for name, entry_point in entrypoints.get_group_named( + ENTRY_POINT_NAME).items(): + logging.info('Loading entry point: {}'.format(name)) + entry_point.load() + + @classmethod + def register_urn(cls, urn, constructor=None): + def register(constructor): + cls._known_urns[urn] = constructor + return constructor + if constructor: + return register(constructor) + else: + return register + + @classmethod + def create(cls, + state, # type: CachingStateHandler + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + taskable_value, # type: TaskableValue + credentials=None, # type: Optional[str] + worker_id=None # type: Optional[str] + ): + # type: (...) -> TaskWorkerHandler + constructor = cls._known_urns[taskable_value.urn] + return constructor(state, provision_info, grpc_server, + taskable_value.env, taskable_value.payload, + credentials=credentials, worker_id=worker_id) + + def start_worker(self): + # type: () -> None + self.start_remote() + + def start_remote(self): + # type: () -> None + """Start up a remote TaskWorker to process the current element. 
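+
+    A minimal sketch of one possible implementation (illustrative only, not
+    part of the API) that simply runs the task worker in a local thread::
+
+      def start_remote(self):
+        # type: () -> None
+        import threading
+        from apache_beam.runners.worker.task_worker.handlers import (
+            BundleProcessorTaskWorker)
+        # The spawned worker connects back to this handler's control server
+        # and drives a BundleProcessor for the assigned task.
+        t = threading.Thread(
+            target=BundleProcessorTaskWorker.execute,
+            args=(self.worker_id, self.control_address))
+        t.daemon = True
+        t.start()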
+ + Subclass should implement this.""" + raise NotImplementedError + + def stop_worker(self): + # type: () -> None + # send shutdown request + future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + shutdown=beam_task_worker_pb2.ShutdownRequest())) + response = future.get() + if response.error: + logging.warning('Error stopping worker: {}'.format(self.worker_id)) + + # close control conn after stop worker + self.control_conn.close() + + def _get_future(self, future, interval=0.5): + # type: (ControlFuture, float) -> beam_task_worker_pb2.TaskInstructionResponse + result = None + while self.alive: + result = future.get(timeout=interval) + if result: + break + + # if the handler is not alive, meaning task worker is stopped before + # finishing processing, raise ``TaskWorkerTerminatedError`` + if result is None: + raise TaskWorkerTerminatedError() + + return result + + def execute(self, + data_channel_factory, # type: ProxyGrpcClientDataChannelFactory + process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point of the task execution cycle of a ``TaskWorkerHandler``. + + It will first issue a create request, and wait for the remote bundle + processor to be created; Then it will issue the process bundle request, and + wait for the result. If there's error occurred when processing bundle, + ``TaskWorkerProcessBundleError`` will be raised. + """ + # wait for remote bundle processor to be created + create_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + create=beam_task_worker_pb2.CreateRequest( + process_bundle_descriptor=process_bundle_descriptor, + state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( + url=self._grpc_server.state_address), + data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( + credentials=self.credentials, + worker_id=data_channel_factory.worker_id, + transmitter_url=data_channel_factory.transmitter_url)))) + self._get_future(create_future) + + # process bundle + process_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()) + ) + response = self._get_future(process_future) + if response.error: + # raise here so this task can be retried + raise TaskWorkerProcessBundleError() + else: + delayed_applications = response.process_bundle.delayed_applications + require_finalization = response.process_bundle.require_finalization + + self.stop_worker() + return delayed_applications, require_finalization + + def reset(self): + # type: () -> None + """This is used to retry a failed task.""" + self.control_conn.reset() + + +class TaskGrpcServer(object): + """A collection of grpc servicers that handle communication between a + ``TaskWorker`` and ``TaskWorkerHandler``. + + Contains three servers: + - a control server hosting ``TaskControlService`` amd `TaskProvisionService` + - a data server hosting ``TaskFnDataService`` + - a state server hosting ``TaskStateService`` + + This is shared by all TaskWorkerHandlers generated by one bundle. 
+ """ + + _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 + + def __init__( + self, + state_handler, # type: CachingStateHandler + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + data_channel_factory, # type: DataChannelFactory + instruction_id, # type: str + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + ): + # type: (...) -> None + self.state_handler = state_handler + self.max_workers = max_workers + self.control_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers)) + self.control_port = self.control_server.add_insecure_port('[::]:0') + self.control_address = '%s:%s' % (self.get_host_name(), self.control_port) + + # Options to have no limits (-1) on the size of the messages + # received or sent over the data plane. The actual buffer size + # is controlled in a layer above. + no_max_message_sizes = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + self.data_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.data_port = self.data_server.add_insecure_port('[::]:0') + self.data_address = '%s:%s' % (self.get_host_name(), self.data_port) + + self.state_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.state_port = self.state_server.add_insecure_port('[::]:0') + self.state_address = '%s:%s' % (self.get_host_name(), self.state_port) + + self.control_handler = TaskControlServicer() + beam_task_worker_pb2_grpc.add_TaskControlServicer_to_server( + self.control_handler, self.control_server) + self.provision_handler = TaskProvisionServicer(provision_info=provision_info) + beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server( + self.provision_handler, self.control_server) + + self.data_plane_handler = TaskFnDataServicer(data_store, + data_channel_factory, + instruction_id) + beam_task_worker_pb2_grpc.add_TaskFnDataServicer_to_server( + self.data_plane_handler, self.data_server) + + beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( + TaskStateServicer(self.state_handler, instruction_id, + state_handler._context.bundle_cache_token), + self.state_server) + + logging.info('starting control server on port %s', self.control_port) + logging.info('starting data server on port %s', self.data_port) + logging.info('starting state server on port %s', self.state_port) + self.state_server.start() + self.data_server.start() + self.control_server.start() + + @staticmethod + def get_host_name(): + # type: () -> str + import socket + return socket.getfqdn() + + def close(self): + # type: () -> None + self.control_handler.done() + to_wait = [ + self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + ] + for w in to_wait: + w.wait() + + +# ============= +# Control Plane +# ============= +class TaskWorkerConnection(ControlConnection): + """The control connection between a TaskWorker and a TaskWorkerHandler. + + TaskWorkerHandler push InstructionRequests to _push_queue, and receives + InstructionResponses from TaskControlServicer. 
+ """ + + _lock = threading.Lock() + + def __init__(self): + self._push_queue = queue.Queue() + self._input = None + self._futures_by_id = {} # type: Dict[Any, ControlFuture] + self._read_thread = threading.Thread( + name='bundle_processor_control_read', target=self._read) + self._state = TaskControlServicer.UNSTARTED_STATE + # marks current TaskConnection as in a state of retrying after failure + self._retrying = False + + def _read(self): + # type: () -> None + for data in self._input: + self._futures_by_id.pop(data.WhichOneof('response')).set(data) + + def push(self, + req # type: Union[TaskControlServicer._DONE_MARKER, beam_task_worker_pb2.TaskInstructionRequest] + ): + # type: (...) -> Optional[ControlFuture] + if req == TaskControlServicer._DONE_MARKER: + self._push_queue.put(req) + return None + if not req.instruction_id: + raise RuntimeError( + 'TaskInstructionRequest has to have instruction id!') + future = ControlFuture(req.instruction_id) + self._futures_by_id[req.WhichOneof('request')] = future + self._push_queue.put(req) + return future + + def set_inputs(self, input): + with TaskWorkerConnection._lock: + if self._input and not self._retrying: + raise RuntimeError('input is already set.') + self._input = input + self._read_thread.start() + self._state = TaskControlServicer.STARTED_STATE + self._retrying = False + + def close(self): + # type: () -> None + with TaskWorkerConnection._lock: + if self._state == TaskControlServicer.STARTED_STATE: + self.push(TaskControlServicer._DONE_MARKER) + self._read_thread.join() + self._state = TaskControlServicer.DONE_STATE + + def reset(self): + # type: () -> None + self.close() + self.__init__() + self._retrying = True + + +class TaskControlServicer(beam_task_worker_pb2_grpc.TaskControlServicer): + + _lock = threading.Lock() + + UNSTARTED_STATE = 'unstarted' + STARTED_STATE = 'started' + DONE_STATE = 'done' + + _DONE_MARKER = object() + + def __init__(self): + # type: () -> None + self._state = self.UNSTARTED_STATE + self._connections_by_worker_id = collections.defaultdict( + TaskWorkerConnection) + + def get_conn_by_worker_id(self, worker_id): + # type: (str) -> TaskWorkerConnection + with self._lock: + result = self._connections_by_worker_id[worker_id] + return result + + def Control(self, request_iterator, context): + with self._lock: + if self._state == self.DONE_STATE: + return + else: + self._state = self.STARTED_STATE + worker_id = dict(context.invocation_metadata()).get('worker_id') + if not worker_id: + raise RuntimeError('Connection does not have worker id.') + conn = self.get_conn_by_worker_id(worker_id) + conn.set_inputs(request_iterator) + + while True: + to_push = conn.get_req() + if to_push is self._DONE_MARKER: + return + yield to_push + + def done(self): + # type: () -> None + self._state = self.DONE_STATE + + +# ========== +# Data Plane +# ========== +class ProxyGrpcClientDataChannelFactory(DataChannelFactory): + """A factory for ``ProxyGrpcClientDataChannel``. + + No caching behavior here because we are starting each data channel on + different location. 
+ """ + + def __init__(self, transmitter_url, credentials=None, worker_id=None): + # type: (str, Optional[str], Optional[str]) -> None + # These two are not private attributes because it was used in + # ``TaskWorkerHandler.execute`` when issuing TaskInstructionRequest + self.transmitter_url = transmitter_url + self.worker_id = worker_id + + self._credentials = credentials + + def create_data_channel(self, remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> ProxyGrpcClientDataChannel + url = remote_grpc_port.api_service_descriptor.url + return self.create_data_channel_from_url(url) + + def create_data_channel_from_url(self, url): + # type: (str) -> ProxyGrpcClientDataChannel + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + grpc_channel = GRPCChannelFactory.insecure_channel( + self.transmitter_url, options=channel_options) + else: + grpc_channel = GRPCChannelFactory.secure_channel( + self.transmitter_url, self._credentials, options=channel_options) + return ProxyGrpcClientDataChannel( + url, + beam_task_worker_pb2_grpc.TaskFnDataStub(grpc_channel)) + + def close(self): + # type: () -> None + pass + + +class ProxyGrpcClientDataChannel(_GrpcDataChannel): + """DataChannel wrapping the client side of a TaskFnDataService connection.""" + + def __init__(self, client_url, proxy_stub): + # type: (str, beam_task_worker_pb2_grpc.TaskFnDataStub) -> None + super(ProxyGrpcClientDataChannel, self).__init__() + self.client_url = client_url + self.proxy_stub = proxy_stub + + def input_elements(self, + instruction_id, # type: str + expected_transforms, # type: List[str] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data] + if not expected_transforms: + return + req = beam_task_worker_pb2.ReceiveRequest( + instruction_id=instruction_id, + client_data_endpoint=self.client_url) + done_transforms = [] + abort_callback = abort_callback or (lambda: False) + + for data in self.proxy_stub.Receive(req): + if self._closed: + raise RuntimeError('Channel closed prematurely.') + if abort_callback(): + return + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + if not data.data and data.transform_id in expected_transforms: + done_transforms.append(data.transform_id) + else: + assert data.transform_id not in done_transforms + yield data + if len(done_transforms) >= len(expected_transforms): + return + + def output_stream(self, instruction_id, transform_id): + # type: (str, str) -> SizeBasedBufferingClosableOutputStream + + def _add_to_send_queue(data): + if data: + self.proxy_stub.Send(beam_task_worker_pb2.SendRequest( + instruction_id=instruction_id, + data=beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + data=data), + client_data_endpoint=self.client_url + )) + + def close_callback(data): + _add_to_send_queue(data) + # no need to send empty bytes to signal end of processing here, because + # when the whole bundle finishes, the bundle processor original output + # stream will send that to runner + + return SizeBasedBufferingClosableOutputStream( + close_callback, flush_callback=_add_to_send_queue) + + +class TaskFnDataServicer(beam_task_worker_pb2_grpc.TaskFnDataServicer): + """Implementation of BeamFnDataTransmitServicer for any number of clients.""" + + def __init__(self, + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + orig_data_channel_factory, # type: 
DataChannelFactory + instruction_id # type: str + ): + # type: (...) -> None + self.data_store = data_store + self.orig_data_channel_factory = orig_data_channel_factory + self.orig_instruction_id = instruction_id + self._orig_data_channel = None # type: Optional[ProxyGrpcClientDataChannel] + + def _get_orig_data_channel(self, url): + # type: (str) -> ProxyGrpcClientDataChannel + remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort( + api_service_descriptor=endpoints_pb2.ApiServiceDescriptor(url=url)) + # the data channel is cached by url + return self.orig_data_channel_factory.create_data_channel(remote_grpc_port) + + def Receive(self, request, context=None): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data]] + data = self.data_store[request.instruction_id] + for elem in data: + yield elem + + def Send(self, request, context=None): + # type: (...) -> beam_task_worker_pb2.SendResponse + if self._orig_data_channel is None: + self._orig_data_channel = self._get_orig_data_channel( + request.client_data_endpoint) + # We need to replace the instruction_id here with the original instruction + # id, not the current one (which is the task worker id) + request.data.instruction_id = self.orig_instruction_id + if request.data.data: + # only send when there's data, because it is signaling the runner side + # worker handler that element of this has ended if it is empty, and we + # want to send that when every task worker handler is finished + self._orig_data_channel._to_send.put(request.data) + return beam_task_worker_pb2.SendResponse() + + +# ===== +# State +# ===== +class TaskStateServicer(GrpcStateServicer): + + def __init__(self, state, instruction_id, cache_token): + # type: (CachingStateHandler, str, Optional[str]) -> None + self.instruction_id = instruction_id + self.cache_token = cache_token + super(TaskStateServicer, self).__init__(state) + + def State(self, request_stream, context=None): + # type: (...) 
-> Iterator[beam_fn_api_pb2.StateResponse] + # CachingStateHandler and GrpcStateHandler context is thread local, so we + # need to set it here for each TaskWorker + self._state._context.process_instruction_id = self.instruction_id + self._state._context.cache_token = self.cache_token + self._state._underlying._context.process_instruction_id = self.instruction_id + + # FIXME: This is not currently properly supporting state caching (currently + # state caching behavior only happens within python SDK, so runners like + # the FlinkRunner won't create the state cache anyways for now) + for request in request_stream: + request_type = request.WhichOneof('request') + + if request_type == 'get': + data, continuation_token = self._state._underlying.get_raw( + request.state_key, request.get.continuation_token) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + get=beam_fn_api_pb2.StateGetResponse( + data=data, continuation_token=continuation_token)) + elif request_type == 'append': + self._state._underlying.append_raw(request.state_key, + request.append.data) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + append=beam_fn_api_pb2.StateAppendResponse()) + elif request_type == 'clear': + self._state._underlying.clear(request.state_key) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + clear=beam_fn_api_pb2.StateClearResponse()) + else: + raise NotImplementedError('Unknown state request: %s' % request_type) + + +# ===== +# Provision +# ===== +class TaskProvisionServicer(BasicProvisionService): + """ + Provide provision info for remote task workers, provision is static because + for each bundle the provision info is static. + """ + + _lock = threading.Lock() + + def __init__(self, provision_info=None): + # type: (Optional[beam_provision_api_pb2.ProvisionInfo]) -> None + self._provision_info = provision_info + self.provision_by_worker_id = dict() + + def GetProvisionInfo(self, request, context=None): + # type: (...) -> beam_provision_api_pb2.GetProvisionInfoResponse + # if request comes from task worker that can be found, return the modified + # provision info + if context: + worker_id = dict(context.invocation_metadata())['worker_id'] + provision_info = self.provision_by_worker_id.get(worker_id, + self._provision_info) + else: + # fallback to the generic sdk worker version of provision info if not + # found from a cached task worker + provision_info = self._provision_info + + return beam_provision_api_pb2.GetProvisionInfoResponse(info=provision_info) + + +class BundleProcessorTaskWorker(object): + """ + The remote task worker that communicates with the SDK worker to do the + actual work of processing bundles. + + The BundleProcessor will detect inputs and see if there is TaskableValue, and + if there is and the BundleProcessor is set to "use task worker", then + BundleProcessorTaskHelper will create a TaskWorkerHandler that this class + communicates with. + + This class creates a BundleProcessor and receives TaskInstructionRequests and + sends back respective responses via the grpc channels connected to the control + endpoint; + """ + + REQUEST_PREFIX = '_request_' + _lock = threading.Lock() + + def __init__(self, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Initialize a BundleProcessorTaskWorker. Lives remotely. + + It will create a BundleProcessor with the provide information and process + the requests using the BundleProcessor created. 
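+
+    A remote process normally enters through the ``execute`` classmethod
+    rather than constructing this class directly, e.g.
+    ``BundleProcessorTaskWorker.execute('worker_0', 'host:12345')``
+    (illustrative worker id and control address).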
+ + Args: + worker_id: the worker id of current task worker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self.worker_id = worker_id + self._credentials = credentials + self._responses = queue.Queue() + self._alive = None # type: Optional[bool] + self._bundle_processor = None # type: Optional[BundleProcessor] + self._exc_info = None + self.stub = self._create_stub(server_url) + + @classmethod + def execute(cls, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Instantiate a BundleProcessorTaskWorker and start running. + + If there's error, it will be raised here so it can be reflected to user. + + Args: + worker_id: worker id for the BundleProcessorTaskWorker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self = cls(worker_id, server_url, credentials=credentials) + self.run() + + # raise the error here, so user knows there's a failure and could retry + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + + def _create_stub(self, server_url): + # type: (str) -> beam_task_worker_pb2_grpc.TaskControlStub + """Create the TaskControl client.""" + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + channel = GRPCChannelFactory.insecure_channel( + server_url, + options=channel_options) + else: + channel = GRPCChannelFactory.secure_channel(server_url, + self._credentials, + options=channel_options) + + # add instruction_id to grpc channel + channel = grpc.intercept_channel( + channel, + WorkerIdInterceptor(self.worker_id)) + + return beam_task_worker_pb2_grpc.TaskControlStub(channel) + + def do_instruction(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Process the requests with the corresponding method.""" + request_type = request.WhichOneof('request') + if request_type: + return getattr(self, self.REQUEST_PREFIX + request_type)( + getattr(request, request_type)) + else: + raise NotImplementedError + + def run(self): + # type: () -> None + """Start the full running life cycle for a task worker. + + It send TaskWorkerInstructionResponse to TaskWorkerHandler, and wait for + TaskWorkerInstructionRequest. This service is bidirectional. + """ + no_more_work = object() + self._alive = True + + def get_responses(): + while True: + response = self._responses.get() + if response is no_more_work: + return + if response: + yield response + + try: + for request in self.stub.Control(get_responses()): + self._responses.put(self.do_instruction(request)) + finally: + self._alive = False + + self._responses.put(no_more_work) + logging.info('Done consuming work.') + + def _request_create(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Create a BundleProcessor based on the request. + + Should be the first request received from the handler. 
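+
+    The request carries the ``ProcessBundleDescriptor`` together with the
+    state endpoint and data-channel settings; from these a ``BundleProcessor``
+    is built with ``use_task_worker=False`` so that the remote side does not
+    try to split the work into further tasks.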
+ """ + from apache_beam.runners.worker.sdk_worker import \ + GrpcStateHandlerFactory + + credentials = None + if request.data_factory.credentials._credentials: + credentials = grpc.ChannelCredentials( + request.data_factory.credentials._credentials) + logging.debug('Credentials: {!r}'.format(credentials)) + + worker_id = request.data_factory.worker_id + transmitter_url = request.data_factory.transmitter_url + state_handler_endpoint = request.state_handler_endpoint + # FIXME: Add support for Caching later + state_factory = GrpcStateHandlerFactory(StateCache(0), credentials) + state_handler = state_factory.create_state_handler( + state_handler_endpoint) + data_channel_factory = ProxyGrpcClientDataChannelFactory( + transmitter_url, credentials, worker_id + ) + + self._bundle_processor = BundleProcessor( + request.process_bundle_descriptor, + state_handler, + data_channel_factory, + use_task_worker=False + ) + return beam_task_worker_pb2.TaskInstructionResponse( + create=beam_task_worker_pb2.CreateResponse(), + instruction_id=self.worker_id + ) + + def _request_process_bundle(self, + request # type: beam_task_worker_pb2.TaskInstructionRequest + ): + # type: (...) -> beam_task_worker_pb2.TaskInstructionResponse + """Process bundle using the bundle processor based on the request.""" + error = None + + try: + # FIXME: Update this to use the cache_tokens properly + with self._bundle_processor.state_handler._underlying.process_instruction_id( + self.worker_id): + delayed_applications, require_finalization = \ + self._bundle_processor.process_bundle(self.worker_id) + except Exception as e: + # we want to propagate the error back to the TaskWorkerHandler, so that + # it will raise `TaskWorkerProcessBundleError` which allows for requeue + # behavior (up until MAX_TASK_WORKER_RETRY number of retries) + error = e.message + self._exc_info = sys.exc_info() + delayed_applications = [] + require_finalization = False + + return beam_task_worker_pb2.TaskInstructionResponse( + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( + delayed_applications=delayed_applications, + require_finalization=require_finalization), + instruction_id=self.worker_id, + error=error + ) + + def _request_shutdown(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Shutdown the bundleprocessor.""" + error = None + try: + # shut down state handler here because it is not created by the state + # handler factory thus won't be closed automatically + self._bundle_processor.state_handler.done() + self._bundle_processor.shutdown() + except Exception as e: + error = e.message + finally: + return beam_task_worker_pb2.TaskInstructionResponse( + shutdown=beam_task_worker_pb2.ShutdownResponse(), + instruction_id=self.worker_id, + error=error + ) + + +class BundleProcessorTaskHelper(object): + """ + A helper object that is used by a BundleProcessor while processing bundle. + + Delegates TaskableValues to TaskWorkers, if enabled. + + It can process TaskableValue using TaskWorkers if inspected, and kept the + default behavior if specified to not use TaskWorker, or there's no + TaskableValue found in this bundle. + + To utilize TaskWorkers, BundleProcessorTaskHelper will split up the input + bundle into tasks based on the wrapped TaskableValue's payload, and create a + TaskWorkerHandler for each task. 
+ """ + + def __init__(self, instruction_id, wrapped_values): + # type: (str, DefaultDict[str, List[Tuple[Any, bytes]]]) -> None + """Initialize a BundleProcessorTaskHelper object. + + Args: + instruction_id: the instruction_id of the bundle that the + BundleProcessor is processing + wrapped_values: The mapping of transform id to raw and encoded data. + """ + self.instruction_id = instruction_id + self.wrapped_values = wrapped_values + + def split_taskable_values(self): + # type: () -> Tuple[DefaultDict[str, List[Any]], DefaultDict[str, List[beam_fn_api_pb2.Elements.Data]]] + """Split TaskableValues into tasks and pair it with worker. + + Also put the raw bytes along with worker id for data dispatching by data + plane handler. + """ + # TODO: Come up with solution on how this can be dynamically changed + # could use window + splitted = collections.defaultdict(list) + data_store = collections.defaultdict(list) + worker_count = 0 + for ptransform_id, values in self.wrapped_values.items(): + for decoded, raw in values: + worker_id = 'worker_{}'.format(worker_count) + splitted[worker_id].append(decoded) + data_store[worker_id].append(beam_fn_api_pb2.Elements.Data( + transform_id=ptransform_id, + data=raw, + instruction_id=worker_id + )) + worker_count += 1 + + return splitted, data_store + + def _start_task_grpc_server( + self, + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + state_handler, # type: CachingStateHandler + data_channel_factory, # type: DataChannelFactory + provision_info, # type: beam_provision_api_pb2.ProvisionInfo + ): + # type:(...) -> TaskGrpcServer + """Start up TaskGrpcServer. + + Args: + max_workers: number of max worker + data_store: stored data of worker id and the raw and decoded values for + the worker to process as inputs + state_handler: state handler of current BundleProcessor + data_channel_factory: data channel factory of current BundleProcessor + """ + return TaskGrpcServer(state_handler, max_workers, data_store, + data_channel_factory, self.instruction_id, + provision_info) + + @staticmethod + def get_default_task_env(process_bundle_descriptor): + # type:(beam_fn_api_pb2.ProcessBundleDescriptor) -> Optional[Environment] + """Get the current running beam Environment class. + + Used as the default for the task worker. + + Args: + process_bundle_descriptor: the ProcessBundleDescriptor proto + """ + from apache_beam.runners.portability.fn_api_runner.translations import \ + PAR_DO_URNS + from apache_beam.transforms.environments import Environment + + # find a ParDo xform in this stage + pardo = None + for _, xform in process_bundle_descriptor.transforms.items(): + if xform.spec.urn in PAR_DO_URNS: + pardo = xform + break + + if pardo is None: + # don't set the default task env if no ParDo is found + # FIXME: Use the pipeline default env here? 
+ return None + + env_proto = process_bundle_descriptor.environments.get( + pardo.environment_id) + + return Environment.from_runner_api(env_proto, None) + + def get_sdk_worker_provision_info(self, server_url): + # type:(str) -> beam_provision_api_pb2.ProvisionInfo + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + channel = GRPCChannelFactory.insecure_channel(server_url, + options=channel_options) + + worker_id = os.environ['WORKER_ID'] + # add sdk worker id to grpc channel + channel = grpc.intercept_channel(channel, WorkerIdInterceptor(worker_id)) + + provision_stub = beam_provision_api_pb2_grpc.ProvisionServiceStub(channel) + response = provision_stub.GetProvisionInfo( + beam_provision_api_pb2.GetProvisionInfoRequest()) + channel.close() + return response.info + + def process_bundle_with_task_workers( + self, + state_handler, # type: CachingStateHandler + data_channel_factory, # type: DataChannelFactory + process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point for the task worker system. + + Starts up a group of TaskWorkerHandlers, dispatches tasks and waits for them + to finish. + + Fails if any TaskWorker exceeds the maximum number of retries. + + Args: + state_handler: state handler of the current BundleProcessor + data_channel_factory: data channel factory of the current BundleProcessor + process_bundle_descriptor: a description of the stage that this + ``BundleProcessor`` is to execute. + """ + from google.protobuf import text_format + + default_env = self.get_default_task_env(process_bundle_descriptor) + # start up grpc server + splitted_elements, data_store = self.split_taskable_values() + num_task_workers = len(splitted_elements) + if num_task_workers > MAX_TASK_WORKERS: + logging.warning( + 'Number of elements exceeds MAX_TASK_WORKERS ({})'.format( + MAX_TASK_WORKERS)) + num_task_workers = MAX_TASK_WORKERS + + # get the sdk worker provision info first + try: + provision_port = TaskWorkerHandler.get_port_from_env_var( + 'PROVISION_API_SERVICE_DESCRIPTOR') + except KeyError: + # if we are in a subprocess environment then there won't be any provision + # service, so use the default provision info + provision_info = beam_provision_api_pb2.ProvisionInfo() + + else: + provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( + TaskWorkerHandler.host_from_worker(), provision_port)) + + server = self._start_task_grpc_server(num_task_workers, data_store, + state_handler, data_channel_factory, + provision_info) + # modify the provision api service descriptor to use the new address that + # will be used from here on (the control address) + os.environ['PROVISION_API_SERVICE_DESCRIPTOR'] = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url=server.control_address)) + + # create TaskWorkerHandlers + task_worker_handlers = [] + for worker_id, elem in splitted_elements.items(): + taskable_value = get_taskable_value(elem) + # fall back to the default env if the taskable value has none + if taskable_value.env is None and default_env: + taskable_value.env = default_env + + task_worker_handler = TaskWorkerHandler.create( + state_handler, provision_info, server, taskable_value, + credentials=data_channel_factory._credentials, worker_id=worker_id) + task_worker_handlers.append(task_worker_handler) + task_worker_handler.start_worker() + + def _execute(handler): + """ + This is the method that runs in the thread pool, representing a working + TaskWorkerHandler.
+ """ + worker_data_channel_factory = ProxyGrpcClientDataChannelFactory( + server.data_address, + credentials=data_channel_factory._credentials, + worker_id=task_worker_handler.worker_id) + counter = 0 + while True: + try: + counter += 1 + return handler.execute(worker_data_channel_factory, + process_bundle_descriptor) + except TaskWorkerProcessBundleError as e: + if counter >= MAX_TASK_WORKER_RETRY: + logging.error('Task Worker has exceeded max retries!') + handler.stop_worker() + raise + # retry if task worker failed to process bundle + handler.reset() + continue + except TaskWorkerTerminatedError as e: + # This error is thrown only when TaskWorkerHandler is terminated + # before it finished processing + logging.warning('TaskWorker terminated prematurely.') + raise + + # start actual processing of splitted bundle + merged_delayed_applications = [] + bundle_require_finalization = False + with futures.ThreadPoolExecutor(max_workers=num_task_workers) as executor: + try: + for delayed_applications, require_finalization in executor.map( + _execute, task_worker_handlers): + + if delayed_applications is None: + raise RuntimeError('Task Worker failed to process task.') + merged_delayed_applications.extend(delayed_applications) + # if any elem requires finalization, set it to True + if not bundle_require_finalization and require_finalization: + bundle_require_finalization = True + except TaskWorkerProcessBundleError: + raise RuntimeError('Task Worker failed to process task.') + except TaskWorkerTerminatedError: + # This error is thrown only when TaskWorkerHandler is terminated before + # it finished processing, this is only possible if + # `TaskWorkerHandler.alive` is manually set to False by custom user + # defined monitoring function, which user would trigger if they want + # to terminate the task; + # In that case, we want to continue on and not hold the whole bundle + # by the tasks that user manually terminated. + pass + + return merged_delayed_applications, bundle_require_finalization + + +@TaskWorkerHandler.register_urn('local') +class LocalTaskWorkerHandler(TaskWorkerHandler): + """TaskWorkerHandler that starts up task worker locally.""" + + # FIXME: create a class-level thread pool to restrict the number of threads + + def start_remote(self): + # type: () -> None + """start a task worker local to the task worker handler.""" + obj = BundleProcessorTaskWorker(self.worker_id, self.control_address, + self.credentials) + run_thread = threading.Thread(target=obj.run) + run_thread.daemon = True + run_thread.start() + + +TaskWorkerHandler.load_plugins() + + +def get_taskable_value(decoded_value): + # type: (Any) -> Optional[TaskableValue] + """Check whether the given value contains taskable value. 
+ + If taskable, return the TaskableValue + + Args: + decoded_value: decoded value from raw input stream + """ + # FIXME: Come up with a solution that's not so specific + if isinstance(decoded_value, (list, tuple, set)): + for val in decoded_value: + result = get_taskable_value(val) + if result: + return result + elif isinstance(decoded_value, TaskableValue): + return decoded_value + return None diff --git a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py new file mode 100644 index 000000000000..5d8604c9a66f --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""TaskWorker entry point.""" + +from __future__ import absolute_import + +import os +import logging +import sys +import traceback + +from apache_beam.runners.worker.sdk_worker_main import _load_main_session +from apache_beam.runners.worker.task_worker.handlers import BundleProcessorTaskWorker +from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler + +# This module is experimental. No backwards-compatibility guarantees. + + +def main(unused_argv): + """Main entry point for Starting up TaskWorker.""" + + # TODO: Do we want to set up the logging service here? 
+ + # below section is the same as in sdk_worker_main, for loading pickled main + # session if there's any + if 'SEMI_PERSISTENT_DIRECTORY' in os.environ: + semi_persistent_directory = os.environ['SEMI_PERSISTENT_DIRECTORY'] + else: + semi_persistent_directory = None + + logging.info('semi_persistent_directory: %s', semi_persistent_directory) + + try: + _load_main_session(semi_persistent_directory) + except Exception: # pylint: disable=broad-except + exception_details = traceback.format_exc() + logging.error( + 'Could not load main session: %s', exception_details, exc_info=True) + + worker_id = os.environ['TASK_WORKER_ID'] + control_address = os.environ['TASK_WORKER_CONTROL_ADDRESS'] + if 'TASK_WORKER_CREDENTIALS' in os.environ: + credentials = os.environ['TASK_WORKER_CREDENTIALS'] + else: + credentials = None + + TaskWorkerHandler.load_plugins() + + # exception should be handled already by task workers + BundleProcessorTaskWorker.execute(worker_id, control_address, + credentials=credentials) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py new file mode 100644 index 000000000000..8df68a6171e8 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py @@ -0,0 +1,297 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for apache_beam.runners.worker.task_worker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import threading +import unittest + +from google.protobuf import text_format + +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_provision_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_task_worker_pb2 +from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.worker.task_worker import handlers +from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory +from apache_beam.runners.worker.sdk_worker import CachingStateHandler + + +# -- utilities for testing, mocking up test objects +class _MockBundleProcessorTaskWorker(handlers.BundleProcessorTaskWorker): + """ + A mocked version of BundleProcessorTaskWorker, responsible for recording the + requests it received, and provide response to each request type by a passed + in dictionary as user desired. 
+ """ + def __init__(self, worker_id, server_url, credentials=None, + requestRecorder=None, responsesByRequestType=None): + super(_MockBundleProcessorTaskWorker, self).__init__( + worker_id, + server_url, + credentials=credentials + ) + self.requestRecorder = requestRecorder or [] + self.responsesByRequestType = responsesByRequestType or {} + + def do_instruction(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + request_type = request.WhichOneof('request') + self.requestRecorder.append(request) + return self.responsesByRequestType.get(request_type) + + +@handlers.TaskWorkerHandler.register_urn('unittest') +class _MockTaskWorkerHandler(handlers.TaskWorkerHandler): + """ + Register a mocked version of task handler only used for "unittest"; will start + a ``_MockBundleProcessorTaskWorker`` for each discovered TaskableValue. + + Main difference is that it returns the started task worker object, for easier + testing. + """ + + def __init__(self, + state, # type: CachingStateHandler + provision_info, + # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + environment, # type: Environment + task_payload, # type: Any + credentials=None, # type: Optional[str] + worker_id=None, # type: Optional[str] + responseByRequestType=None + ): + provision_info = beam_provision_api_pb2.ProvisionInfo() + super(_MockTaskWorkerHandler, self).__init__(state, provision_info, + grpc_server, environment, + task_payload, + credentials=credentials, + worker_id=worker_id) + self.responseByRequestType = responseByRequestType + + def start_remote(self): + # type: () -> _MockBundleProcessorTaskWorker + """starts a task worker local to the task worker handler.""" + obj = _MockBundleProcessorTaskWorker( + self.worker_id, + self.control_address, + self.credentials, + responsesByRequestType=self.responseByRequestType) + run_thread = threading.Thread(target=obj.run) + run_thread.daemon = True + run_thread.start() + + return obj + + def start_worker(self): + # type: () -> _MockBundleProcessorTaskWorker + return self.start_remote() + + +class _MockTaskGrpcServer(handlers.TaskGrpcServer): + """ + Mocked version of TaskGrpcServer, using mocked version of data channel factory + and cache handler. + """ + + def __init__(self, instruction_id, max_workers=1, data_store=None): + dummy_state_handler = _MockCachingStateHandler(None, None) + dummy_data_channel_factory = GrpcClientDataChannelFactory() + provision_info = beam_provision_api_pb2.ProvisionInfo() + super(_MockTaskGrpcServer, self).__init__(dummy_state_handler, max_workers, + data_store or {}, + dummy_data_channel_factory, + instruction_id, provision_info) + self.control_address = 'localhost:{}'.format(self.control_port) + control_descriptor = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url=self.control_address)) + print(control_descriptor) + os.environ['CONTROL_API_SERVICE_DESCRIPTOR'] = control_descriptor + + +class _MockCachingStateHandler(CachingStateHandler): + """ + Mocked CachingStateHandler, mainly for patching the thread local variable + ``_context`` and create the ``cache_token`` attribute on it. 
+ """ + + def __init__(self, underlying_state, global_state_cache): + self._underlying = underlying_state + self._state_cache = global_state_cache + self._context = threading.local() + self._context.bundle_cache_token = '' + + +class _MockDataInputOperation(object): + """ + A mocked version of DataInputOperation, responsible for recording and decoding + data for testing. + """ + + def __init__(self, coder): + self.coder = coder + self.decoded = [] + self.splitting_lock = threading.Lock() + self.windowed_coder_impl = self.coder.get_impl() + + with self.splitting_lock: + self.index = -1 + self.stop = float('inf') + + def output(self, decoded_value): + self.decoded.append(decoded_value) + + +def prep_responses_by_request_type(worker_id, delayed_applications=(), + require_finalization=False, process_error=None, + shutdown_error=None): + return { + 'create': beam_task_worker_pb2.TaskInstructionResponse( + create=beam_task_worker_pb2.CreateResponse(), + instruction_id=worker_id + ), + 'process_bundle': beam_task_worker_pb2.TaskInstructionResponse( + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( + delayed_applications=list(delayed_applications), + require_finalization=require_finalization), + instruction_id=worker_id, + error=process_error + ), + 'shutdown': beam_task_worker_pb2.TaskInstructionResponse( + shutdown=beam_task_worker_pb2.ShutdownResponse(), + instruction_id=worker_id, + error=shutdown_error + ) + } + + +def prep_bundle_processor_descriptor(bundle_id): + return beam_fn_api_pb2.ProcessBundleDescriptor( + id='test_bundle_{}'.format(bundle_id), + transforms={ + str(bundle_id): beam_runner_api_pb2.PTransform(unique_name=str(bundle_id)) + }) + + +class TaskWorkerHandlerTest(unittest.TestCase): + + def setUp(self): + # put endpoints environment variable in for testing + logging_descriptor = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url='localhost:10000')) + os.environ['LOGGING_API_SERVICE_DESCRIPTOR'] = logging_descriptor + + @staticmethod + def _get_task_worker_handler(worker_id, resp_by_type, instruction_id, + max_workers=1, data_store=None): + server = _MockTaskGrpcServer(instruction_id, max_workers=max_workers, + data_store=data_store) + return _MockTaskWorkerHandler(server.state_handler, None, server, None, + None, + worker_id=worker_id, + responseByRequestType=resp_by_type) + + def test_execute_success(self): + """ + Test when a TaskWorkerHandler successfully executed one life cycle. 
+ """ + dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) + + worker_id = 'test_task_worker_1' + instruction_id = 'test_instruction_1' + resp_by_type = prep_responses_by_request_type(worker_id) + + test_handler = self._get_task_worker_handler(worker_id, resp_by_type, + instruction_id) + + proxy_data_channel_factory = handlers.ProxyGrpcClientDataChannelFactory( + test_handler._grpc_server.data_address + ) + + test_worker = test_handler.start_worker() + + try: + delayed, requests = test_handler.execute(proxy_data_channel_factory, + dummy_process_bundle_descriptor) + self.assertEquals(len(delayed), 0) + self.assertEquals(requests, False) + finally: + test_handler._grpc_server.close() + + # check that the requests we received are as expected + expected = [ + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + create=beam_task_worker_pb2.CreateRequest( + process_bundle_descriptor=dummy_process_bundle_descriptor, + state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( + url=test_handler._grpc_server.state_address), + data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( + transmitter_url=proxy_data_channel_factory.transmitter_url, + worker_id=proxy_data_channel_factory.worker_id, + credentials=proxy_data_channel_factory._credentials + ))), + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()), + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + shutdown=beam_task_worker_pb2.ShutdownRequest()) + ] + + self.assertEquals(test_worker.requestRecorder, expected) + + def test_execute_failure(self): + """ + Test when a TaskWorkerHandler fails to process a bundle. + """ + dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) + + worker_id = 'test_task_worker_1' + instruction_id = 'test_instruction_1' + resp_by_type = prep_responses_by_request_type(worker_id, process_error='error') + + test_handler = self._get_task_worker_handler(worker_id, resp_by_type, + instruction_id) + proxy_data_channel_factory = handlers.ProxyGrpcClientDataChannelFactory( + test_handler._grpc_server.data_address + ) + + test_handler.start_worker() + + try: + with self.assertRaises(handlers.TaskWorkerProcessBundleError): + print(test_handler.execute) + test_handler.execute( + proxy_data_channel_factory, + dummy_process_bundle_descriptor) + + test_handler.stop_worker() + finally: + test_handler._grpc_server.close() + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/task/__init__.py b/sdks/python/apache_beam/task/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/task/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/__init__.py b/sdks/python/apache_beam/task/task_worker/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/kubejob/__init__.py b/sdks/python/apache_beam/task/task_worker/kubejob/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/kubejob/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/kubejob/handler.py b/sdks/python/apache_beam/task/task_worker/kubejob/handler.py new file mode 100644 index 000000000000..a14716b2908e --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/kubejob/handler.py @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Kubernetes task worker implementation. 
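Like the local handler, this handler registers itself under a urn via ``TaskWorkerHandler.register_urn`` and is exposed through the ``apache_beam_task_workers_plugins`` entry point declared in setup.py. A custom handler would be wired up roughly like this (a sketch; the urn and class name are made up)::

    from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler

    @TaskWorkerHandler.register_urn('my_backend')
    class MyTaskWorkerHandler(TaskWorkerHandler):

      def start_remote(self):
        # type: () -> None
        # launch the remote process and point it at self.control_address
        ...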
+""" + +from __future__ import absolute_import + +import copy +import threading +import time +from typing import TYPE_CHECKING +from typing import NamedTuple + +from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler + +# This module will be imported by the task worker handler plugin system +# regardless of whether the kubernetes API is installed. It must be safe to +# import whether it will be used or not. +try: + import kubernetes.client as client + import kubernetes.config as config + from kubernetes.client.rest import ApiException +except ImportError: + client = None + config = None + ApiException = None + +if TYPE_CHECKING: + from typing import List + +__all__ = [ + 'KubeTaskWorkerHandler', +] + + +class KubePayload(object): + """ + Object for holding attributes for a kubernetes job. + """ + + def __init__(self, job, namespace='default'): + # type: (client.V1Job, str) -> None + self.job = job + self.namespace = namespace + + @property + def job_name(self): + return self.job.metadata.name + + @job_name.setter + def job_name(self, value): + self.job.metadata.name = value + self.job.spec.template.metadata.labels['app'] = value + for container in self.job.spec.template.spec.containers: + container.name = value + + +class KubeJobManager(object): + """ + Monitors jobs submitted by the KubeTaskWorkerHandler. Responsible for + notifying `watch`ed handlers if their Kubernetes job is deleted. + """ + + _thread = None # type: threading.Thread + _lock = threading.Lock() + + def __init__(self): + self._handlers = [] # type: List[KubeTaskWorkerHandler] + + def is_started(self): + """ + Return if the manager is currently running or not. + """ + return self._thread is not None + + def _start(self): + self._thread = threading.Thread(target=self._run) + self._thread.daemon = True + self._thread.start() + + def _run(self, interval=5.0): + while True: + for handler in self._handlers: + # If the handler thinks it's alive but it's not actually, change its + # alive state. + if handler.alive and not handler.job_exists(): + handler.alive = False + time.sleep(interval) + + def watch(self, handler): + # type: (KubeTaskWorkerHandler) -> None + """ + Monitor the passed handler checking periodically that the job is still + running. + """ + if not self.is_started(): + self._start() + self._handlers.append(handler) + + +@TaskWorkerHandler.register_urn('k8s') +class KubeTaskWorkerHandler(TaskWorkerHandler): + """ + The Kubernetes task handler. + """ + + _lock = threading.Lock() + _monitor = None # type: KubeJobManager + + api = None # type: client.BatchV1Api + + @property + def monitor(self): + # type: () -> KubeJobManager + if KubeTaskWorkerHandler._monitor is None: + with KubeJobManager._lock: + KubeTaskWorkerHandler._monitor = KubeJobManager() + return KubeTaskWorkerHandler._monitor + + def job_exists(self): + # type: () -> bool + """ + Return whether or not the Kubernetes job exists. + """ + try: + self.api.read_namespaced_job_status( + self.task_payload.job.metadata.name, self.task_payload.namespace) + except ApiException: + return False + return True + + def submit_job(self, payload): + # type: (KubePayload) -> client.V1Job + """ + Submit a Kubernetes job. 
+ """ + # Patch some handler specific env variables into the job + job = copy.deepcopy(payload.job) # type: client.V1Job + + env = [ + client.V1EnvVar(name='TASK_WORKER_ID', value=self.worker_id), + client.V1EnvVar(name='TASK_WORKER_CONTROL_ADDRESS', + value=self.control_address), + ] + if self.credentials: + env.extend([ + client.V1EnvVar(name='TASK_WORKER_CREDENTIALS', + value=self.credentials), + ]) + + for container in job.spec.template.spec.containers: + if container.env is None: + container.env = [] + container.env.extend(env) + + return self.api.create_namespaced_job( + body=job, + namespace=payload.namespace) + + def delete_job(self): + # type: () -> client.V1Status + """ + Delete the kubernetes job. + """ + return self.api.delete_namespaced_job( + name=self.task_payload.job.metadata.name, + namespace=self.task_payload.namespace, + body=client.V1DeleteOptions( + propagation_policy='Foreground', + grace_period_seconds=5)) + + def start_remote(self): + # type: () -> None + with KubeTaskWorkerHandler._lock: + config.load_kube_config() + self.api = client.BatchV1Api() + + self.submit_job(self.task_payload) + self.monitor.watch(self) diff --git a/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py b/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py new file mode 100644 index 000000000000..17a59a5c4566 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Transforms for kubernetes task workers. +""" + +from __future__ import absolute_import + +import re +from typing import TYPE_CHECKING + +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms.environments import DockerEnvironment +from apache_beam.runners.worker.task_worker.core import BeamTask + +try: + import kubernetes.client +except ImportError: + raise ImportError( + 'Kubernetes task worker is not supported by this environment ' + '(could not import kubernetes API).') + +if TYPE_CHECKING: + from apache_beam.task.task_worker.kubejob.handler import KubePayload + +__all__ = [ + 'KubeJobOptions', + 'KubeTask', +] + + +class KubeJobOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--namespace', + type=str, + default='default', + help='The namespace to submit kubernetes task worker jobs to.') + + +class KubeTask(BeamTask): + """ + Kubernetes task worker transform. + """ + + urn = 'k8s' + + K8S_JOB_NAME_RE = re.compile(r'[a-z][a-zA-Z0-9-.]*') + + def get_payload(self, options): + # type: (PipelineOptions) -> KubePayload + """ + Get a task payload for configuring a kubernetes job. 
+ """ + from apache_beam.task.task_worker.kubejob.handler import KubePayload + + name = self.label + name = name[0].lower() + name[1:] + if not self.K8S_JOB_NAME_RE.match(name): + raise ValueError( + 'Kubernetes job name must start with a lowercase letter and use ' + 'only - or . special characters') + + image = DockerEnvironment.get_container_image_from_options(options) + + container = kubernetes.client.V1Container( + name=name, + image=image, + command=['python', '-m', self.SDK_HARNESS_ENTRY_POINT], + env=[]) + + template = kubernetes.client.V1PodTemplateSpec( + metadata=kubernetes.client.V1ObjectMeta( + labels={'app': name}), + spec=kubernetes.client.V1PodSpec( + restart_policy='Never', + containers=[container])) + + spec = kubernetes.client.V1JobSpec( + template=template, + backoff_limit=4) + + job = kubernetes.client.V1Job( + api_version='batch/v1', + kind='Job', + metadata=kubernetes.client.V1ObjectMeta(name=name), + spec=spec) + + return KubePayload( + job, + namespace=options.view_as(KubeJobOptions).namespace) diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py index 9d48e892371d..5cf522bd9a99 100644 --- a/sdks/python/apache_beam/transforms/environments.py +++ b/sdks/python/apache_beam/transforms/environments.py @@ -45,6 +45,7 @@ from google.protobuf import message from apache_beam import coders +from apache_beam.options.pipeline_options import PortableOptions from apache_beam.options.pipeline_options import SetupOptions from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -55,7 +56,7 @@ from apache_beam.utils import proto_utils if TYPE_CHECKING: - from apache_beam.options.pipeline_options import PortableOptions + from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners.pipeline_context import PipelineContext __all__ = [ @@ -276,17 +277,25 @@ def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.DockerPayloa artifacts=artifacts) @classmethod - def from_options(cls, options): - # type: (PortableOptions) -> DockerEnvironment + def get_container_image_from_options(cls, options): + # type: (PipelineOptions) -> str if options.view_as(SetupOptions).prebuild_sdk_container_engine: - prebuilt_container_image = SdkContainerImageBuilder.build_container_image( - options) - return cls.from_container_image( - container_image=prebuilt_container_image, - artifacts=python_sdk_dependencies(options)) + return SdkContainerImageBuilder.build_container_image(options) + portable_options = options.view_as(PortableOptions) + result = portable_options.lookup_environment_option( + 'docker_container_image') + if not result: + result = portable_options.environment_config + if not result: + result = cls.default_docker_image() + return result + + @classmethod + def from_options(cls, options): + # type: (PipelineOptions) -> DockerEnvironment + container_image = cls.get_container_image_from_options(options) return cls.from_container_image( - container_image=options.lookup_environment_option( - 'docker_container_image') or options.environment_config, + container_image=container_image, artifacts=python_sdk_dependencies(options)) @classmethod diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index ace661cd3bba..2d64e68ae240 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -50,17 +50,21 @@ var ( // Contract: https://s.apache.org/beam-fn-api-container-contract. 
- workerPool = flag.Bool("worker_pool", false, "Run as worker pool (optional).") - id = flag.String("id", "", "Local identifier (required).") - loggingEndpoint = flag.String("logging_endpoint", "", "Logging endpoint (required).") - artifactEndpoint = flag.String("artifact_endpoint", "", "Artifact endpoint (required).") - provisionEndpoint = flag.String("provision_endpoint", "", "Provision endpoint (required).") - controlEndpoint = flag.String("control_endpoint", "", "Control endpoint (required).") - semiPersistDir = flag.String("semi_persist_dir", "/tmp", "Local semi-persistent directory (optional).") + workerPool = flag.Bool("worker_pool", false, "Run as worker pool (optional).") + id = flag.String("id", "", "Local identifier (required).") + loggingEndpoint = flag.String("logging_endpoint", "", "Logging endpoint (required).") + artifactEndpoint = flag.String("artifact_endpoint", "", "Artifact endpoint (required).") + provisionEndpoint = flag.String("provision_endpoint", "", "Provision endpoint (required).") + controlEndpoint = flag.String("control_endpoint", "", "Control endpoint (required).") + semiPersistDir = flag.String("semi_persist_dir", "/tmp", "Local semi-persistent directory (optional).") + // this allows us to override entry point so it can be used by task worker + sdkHarnessEntrypoint = flag.String("sdk_harness_entry_point", "apache_beam.runners.worker.sdk_worker_main", + "Entry point for the python process (optional).") + // this allows us to switch the python executable if needed (for example for use in DCC application) + pythonExec = flag.String("python_executable", "python", "Python executable to use (optional).") ) const ( - sdkHarnessEntrypoint = "apache_beam.runners.worker.sdk_worker_main" // Please keep these names in sync with stager.py workflowFile = "workflow.tar.gz" requirementsFile = "requirements.txt" @@ -179,6 +183,10 @@ func main() { os.Setenv("SEMI_PERSISTENT_DIRECTORY", *semiPersistDir) os.Setenv("LOGGING_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *loggingEndpoint})) os.Setenv("CONTROL_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *controlEndpoint})) + // we need to record the other endpoints here because task worker need to use boot to setup the environment for + // the sdk harness as well + os.Setenv("ARTIFACT_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *artifactEndpoint})) + os.Setenv("PROVISION_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *provisionEndpoint})) if info.GetStatusEndpoint() != nil { os.Setenv("STATUS_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(info.GetStatusEndpoint())) @@ -186,11 +194,11 @@ func main() { args := []string{ "-m", - sdkHarnessEntrypoint, + *sdkHarnessEntrypoint, } - log.Printf("Executing: python %v", strings.Join(args, " ")) + log.Printf("Executing: %s %v", *pythonExec, strings.Join(args, " ")) - log.Fatalf("Python exited: %v", execx.Execute("python", args...)) + log.Fatalf("Python exited: %v", execx.Execute(*pythonExec, args...)) } // setup wheel specs according to installed python version diff --git a/sdks/python/setup.py b/sdks/python/setup.py index b792ddff28ca..c00f1603b7d7 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -141,6 +141,7 @@ def get_version(): # server, therefore list of allowed versions is very narrow. # See: https://github.com/uqfoundation/dill/issues/341. 
'dill>=0.3.1.1,<0.3.2', + 'entrypoints>=0.3', 'fastavro>=0.21.4,<2', 'funcsigs>=1.0.2,<2; python_version < "3.0"', 'future>=0.18.2,<1.0.0', @@ -312,7 +313,8 @@ def run(self): 'interactive': INTERACTIVE_BEAM, 'interactive_test': INTERACTIVE_BEAM_TEST, 'aws': AWS_REQUIREMENTS, - 'azure': AZURE_REQUIREMENTS + 'azure': AZURE_REQUIREMENTS, + 'taskworker_k8s': ['kubernetes>=11.0.0'], }, zip_safe=False, # PyPI package information. @@ -333,6 +335,9 @@ def run(self): entry_points={ 'nose.plugins.0.10': [ 'beam_test_plugin = test_config:BeamTestPlugin', + ], + 'apache_beam_task_workers_plugins': [ + 'k8s_task_worker = apache_beam.task.task_worker.kubejob.handler:KubeTaskWorkerHandler', ]}, cmdclass={ 'build_py': generate_protos_first(build_py),