From c1774b1420ac1b71954da5451ee51aa6089ca2ee Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 19 Aug 2020 12:55:26 -0700 Subject: [PATCH 01/15] Update `boot.go` to record artifact and provision endpoints in env var --- sdks/python/container/boot.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index ace661cd3bba..a4b86563c497 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -179,6 +179,10 @@ func main() { os.Setenv("SEMI_PERSISTENT_DIRECTORY", *semiPersistDir) os.Setenv("LOGGING_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *loggingEndpoint})) os.Setenv("CONTROL_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *controlEndpoint})) + // we need to record the other endpoints here because task worker need to use boot to setup the environment for + // the sdk harness as well + os.Setenv("ARTIFACT_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *artifactEndpoint})) + os.Setenv("PROVISION_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(&pipepb.ApiServiceDescriptor{Url: *provisionEndpoint})) if info.GetStatusEndpoint() != nil { os.Setenv("STATUS_API_SERVICE_DESCRIPTOR", proto.MarshalTextString(info.GetStatusEndpoint())) From 21a58fbf7bd9a4893913c0f769f16b38a848be5c Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Fri, 14 Feb 2020 17:47:04 -0800 Subject: [PATCH 02/15] Update `boot.go` to take custom python executable --- sdks/python/container/boot.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index a4b86563c497..1962d2bec3b5 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -57,6 +57,8 @@ var ( provisionEndpoint = flag.String("provision_endpoint", "", "Provision endpoint (required).") controlEndpoint = flag.String("control_endpoint", "", "Control endpoint (required).") semiPersistDir = flag.String("semi_persist_dir", "/tmp", "Local semi-persistent directory (optional).") + // this allows us to switch the python executable if needed (for example for use in DCC application) + pythonExec = flag.String("python_executable", "python", "Python executable to use (optional).") ) const ( @@ -192,9 +194,9 @@ func main() { "-m", sdkHarnessEntrypoint, } - log.Printf("Executing: python %v", strings.Join(args, " ")) + log.Printf("Executing: %s %v", *pythonExec, strings.Join(args, " ")) - log.Fatalf("Python exited: %v", execx.Execute("python", args...)) + log.Fatalf("Python exited: %v", execx.Execute(*pythonExec, args...)) } // setup wheel specs according to installed python version From 5b97563bb8cc6fd1e947badff63d78bdd36da9b6 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Fri, 14 Feb 2020 17:48:52 -0800 Subject: [PATCH 03/15] Update `boot.go` to take custom entry point --- sdks/python/container/boot.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index 1962d2bec3b5..2d64e68ae240 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -50,19 +50,21 @@ var ( // Contract: https://s.apache.org/beam-fn-api-container-contract. 
- workerPool = flag.Bool("worker_pool", false, "Run as worker pool (optional).") - id = flag.String("id", "", "Local identifier (required).") - loggingEndpoint = flag.String("logging_endpoint", "", "Logging endpoint (required).") - artifactEndpoint = flag.String("artifact_endpoint", "", "Artifact endpoint (required).") - provisionEndpoint = flag.String("provision_endpoint", "", "Provision endpoint (required).") - controlEndpoint = flag.String("control_endpoint", "", "Control endpoint (required).") - semiPersistDir = flag.String("semi_persist_dir", "/tmp", "Local semi-persistent directory (optional).") + workerPool = flag.Bool("worker_pool", false, "Run as worker pool (optional).") + id = flag.String("id", "", "Local identifier (required).") + loggingEndpoint = flag.String("logging_endpoint", "", "Logging endpoint (required).") + artifactEndpoint = flag.String("artifact_endpoint", "", "Artifact endpoint (required).") + provisionEndpoint = flag.String("provision_endpoint", "", "Provision endpoint (required).") + controlEndpoint = flag.String("control_endpoint", "", "Control endpoint (required).") + semiPersistDir = flag.String("semi_persist_dir", "/tmp", "Local semi-persistent directory (optional).") + // this allows us to override entry point so it can be used by task worker + sdkHarnessEntrypoint = flag.String("sdk_harness_entry_point", "apache_beam.runners.worker.sdk_worker_main", + "Entry point for the python process (optional).") // this allows us to switch the python executable if needed (for example for use in DCC application) - pythonExec = flag.String("python_executable", "python", "Python executable to use (optional).") + pythonExec = flag.String("python_executable", "python", "Python executable to use (optional).") ) const ( - sdkHarnessEntrypoint = "apache_beam.runners.worker.sdk_worker_main" // Please keep these names in sync with stager.py workflowFile = "workflow.tar.gz" requirementsFile = "requirements.txt" @@ -192,7 +194,7 @@ func main() { args := []string{ "-m", - sdkHarnessEntrypoint, + *sdkHarnessEntrypoint, } log.Printf("Executing: %s %v", *pythonExec, strings.Join(args, " ")) From b611cdb7fce43b99e82e379551a0733fe7e75d9d Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 14 Oct 2020 12:29:19 -0700 Subject: [PATCH 04/15] [Task Worker] Added task worker system to BundleProcessor. Rework the task work plugin --- .../src/main/proto/beam_task_worker.proto | 137 ++ .../runners/worker/bundle_processor.py | 97 +- .../apache_beam/runners/worker/task_worker.py | 1154 +++++++++++++++++ .../runners/worker/task_worker_main.py | 69 + .../runners/worker/task_worker_test.py | 409 ++++++ 5 files changed, 1857 insertions(+), 9 deletions(-) create mode 100644 model/fn-execution/src/main/proto/beam_task_worker.proto create mode 100644 sdks/python/apache_beam/runners/worker/task_worker.py create mode 100644 sdks/python/apache_beam/runners/worker/task_worker_main.py create mode 100644 sdks/python/apache_beam/runners/worker/task_worker_test.py diff --git a/model/fn-execution/src/main/proto/beam_task_worker.proto b/model/fn-execution/src/main/proto/beam_task_worker.proto new file mode 100644 index 000000000000..c3ab5c7a6fc4 --- /dev/null +++ b/model/fn-execution/src/main/proto/beam_task_worker.proto @@ -0,0 +1,137 @@ +/* + * Protocol Buffers describing a custom bundle processer used by Task Worker. 
+ */ + +syntax = "proto3"; + +package org.apache.beam.model.fn_execution.v1; + +option go_package = "fnexecution_v1"; +option java_package = "org.apache.beam.model.fnexecution.v1"; +//option java_outer_classname = "JobApi"; + +import "beam_fn_api.proto"; +import "endpoints.proto"; +import "google/protobuf/struct.proto"; +import "metrics.proto"; + + +// +// Control Plane +// + +// An API that describes the work that a bundle processor task worker is meant to do. +service TaskControl { + + // Instructions sent by the SDK to the task worker requesting different types + // of work. + rpc Control ( + // A stream of responses to instructions the task worker was asked to be + // performed. + stream TaskInstructionResponse) + returns ( + // A stream of instructions requested of the task worker to be performed. + stream TaskInstructionRequest); +} + +// A request sent by SDK which the task worker is asked to fulfill. +message TaskInstructionRequest { + // (Required) An unique identifier provided by the SDK which represents + // this requests execution. The InstructionResponse MUST have the matching id. + string instruction_id = 1; + + // (Required) A request that the task worker needs to interpret. + oneof request { + CreateRequest create = 1000; + ProcessorProcessBundleRequest process_bundle = 1001; + ShutdownRequest shutdown = 1002; + } +} + +// The response for an associated request the task worker had been asked to fulfill. +message TaskInstructionResponse { + + // (Required) A reference provided by the SDK which represents a requests + // execution. The InstructionResponse MUST have the matching id when + // responding to the SDK. + string instruction_id = 1; + + // An equivalent response type depending on the request this matches. + oneof response { + CreateResponse create = 1000; + ProcessorProcessBundleResponse process_bundle = 1001; + ShutdownResponse shutdown = 1002; + } + + // (Optional) If there's error processing request + string error = 2; +} + +message ChannelCredentials { + string _credentials = 1; +} + +message GrpcClientDataChannelFactory { + ChannelCredentials credentials = 1; + string worker_id = 2; + string transmitter_url = 3; +} + +message CreateRequest { + org.apache.beam.model.fn_execution.v1.ProcessBundleDescriptor process_bundle_descriptor = 1; // (required) + org.apache.beam.model.pipeline.v1.ApiServiceDescriptor state_handler_endpoint = 2; // (required) + GrpcClientDataChannelFactory data_factory = 3; // (required) +} + +message CreateResponse { +} + +message ProcessorProcessBundleRequest { + // (Optional) The cache token that can be used by an SDK to reuse + // cached data returned by the State API across multiple bundles. + string cache_token = 1; +} + +message ProcessorProcessBundleResponse { + repeated org.apache.beam.model.fn_execution.v1.DelayedBundleApplication delayed_applications = 1; + bool require_finalization = 2; +} + +message ShutdownRequest { +} + +message ShutdownResponse { +} + + +// +// Data Plane +// + +service TaskFnData { + // Handles data transferring between TaskWorkerHandler and Task Worker. + rpc Receive( + ReceiveRequest) + returns ( + // A stream of data representing output. 
+ stream Elements.Data); + + + // Used to send data from proxy bundle processor to sdk harness + rpc Send (SendRequest) returns (SendResponse); +} + +message ReceiveRequest { + string instruction_id = 1; + string client_data_endpoint = 2; +} + +message SendRequest { + string instruction_id = 1; + string client_data_endpoint = 2; + Elements.Data data = 3; +} + +message SendResponse { + string error = 1; +} diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 3e600e2d5b95..2671d8b3f4f7 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -928,8 +928,8 @@ def reset(self): for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id): - # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle(self, instruction_id, use_task_worker=True): + # type: (str, bool) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] expected_input_ops = [] # type: List[DataInputOperation] @@ -979,6 +979,8 @@ def process_bundle(self, instruction_id): timer_info.output_stream = output_stream self.ops[transform_id].add_timer_info(timer_family_id, timer_info) + task_helper = BundleProcessorTaskHelper(instruction_id, + use_task_worker=use_task_worker) # Process data and timer inputs for data_channel, expected_inputs in data_channels.items(): for element in data_channel.input_elements(instruction_id, @@ -992,8 +994,14 @@ def process_bundle(self, instruction_id): self.ops[element.transform_id].process_timer( element.timer_family_id, timer_data) elif isinstance(element, beam_fn_api_pb2.Elements.Data): - input_op_by_transform_id[element.transform_id].process_encoded( - element.data) + input_op = input_op_by_transform_id[element.transform_id] + # decode inputs to inspect if it is wrapped + task_helper.process_encoded(input_op, element) + + delayed_applications, requires_finalization = \ + self.maybe_process_remotely(data_channels, instruction_id, + input_op_by_transform_id, + use_task_worker=use_task_worker) # Finish all operations. for op in self.ops.values(): @@ -1005,11 +1013,15 @@ def process_bundle(self, instruction_id): assert timer_info.output_stream is not None timer_info.output_stream.close() - return ([ - self.delayed_bundle_application(op, residual) for op, - residual in execution_context.delayed_applications - ], - self.requires_finalization()) + if requires_finalization is None: + requires_finalization = self.requires_finalization() + + if delayed_applications is None: + delayed_applications = [self.delayed_bundle_application(op, residual) + for op, residual in + execution_context.delayed_applications] + + return delayed_applications, requires_finalization finally: # Ensure any in-flight split attempts complete. @@ -1017,6 +1029,73 @@ def process_bundle(self, instruction_id): pass self.state_sampler.stop_if_still_running() + def maybe_process_remotely(self, + data_channels, # type: DefaultDict[DataChannel, list[str]] + instruction_id, # type: str + input_op_by_transform_id, # type: Dict[str, DataInputOperation] + use_task_worker=True # type: bool + ): + # type: (...) -> Union[Tuple[None, None], Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]] + """Process the current bundle remotely with task workers, if applicable. 
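+
+    For illustration only: an element typically becomes taskable when it is
+    wrapped upstream as a ``TaskableValue`` (possibly nested inside a list,
+    tuple or set), for example with the 'local' urn registered by
+    ``LocalTaskWorkerHandler``::
+
+      from apache_beam.runners.worker.task_worker import TaskableValue
+
+      wrapped = TaskableValue(element, 'local')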
+ + Processes remotely if ``wrapped_values`` is not None (meaning there are + TaskableValue detected from input of this bundle) and task worker is + allowed to be used. + """ + from apache_beam.runners.worker.task_worker import BundleProcessorTaskHelper + from apache_beam.runners.worker.task_worker import get_taskable_value + + wrapped_values = collections.defaultdict(list) # type: DefaultDict[str, List[Tuple[Any, bytes]]] + + for data_channel, expected_transforms in data_channels.items(): + for data in data_channel.input_elements( + instruction_id, expected_transforms): + input_op = input_op_by_transform_id[data.transform_id] + + # process normally if not using task worker + if use_task_worker is False: + input_op.process_encoded(data.data) + continue + + # decode inputs to inspect if it is wrapped + input_stream = coder_impl.create_InputStream(data.data) + # TODO: Come up with a better solution here? + # here we maintain two separate input stream, because `.pos` is not + # accessible on the cython version of InputStream object, so we can't re + # -use the same input stream object and edit the pos to move the read + # handle back + raw_input_stream = coder_impl.create_InputStream(data.data) + + while input_stream.size() > 0: + starting_size = input_stream.size() + decoded_value = input_op.windowed_coder_impl.decode_from_stream( + input_stream, True) + cur_size = input_stream.size() + raw_bytes = raw_input_stream.read(starting_size - cur_size) + # make sure these two stream stays in sync in size + assert raw_input_stream.size() == input_stream.size() + if decoded_value.value and get_taskable_value(decoded_value.value): + # save this to process later in ``process_bundle_with_task_workers`` + wrapped_values[data.transform_id].append( + (get_taskable_value(decoded_value.value), raw_bytes)) + else: + # fallback to process it as normal, trigger receivers to process + with input_op.splitting_lock: + if input_op.index == input_op.stop - 1: + return + input_op.index += 1 + input_op.output(decoded_value) + + if wrapped_values: + task_helper = BundleProcessorTaskHelper(instruction_id, wrapped_values) + return task_helper.process_bundle_with_task_workers( + self.state_handler, + self.data_channel_factory, + self.process_bundle_descriptor + ) + else: + return None, None + def finalize_bundle(self): # type: () -> beam_fn_api_pb2.FinalizeBundleResponse for op in self.ops.values(): diff --git a/sdks/python/apache_beam/runners/worker/task_worker.py b/sdks/python/apache_beam/runners/worker/task_worker.py new file mode 100644 index 000000000000..5136e6b851b2 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker.py @@ -0,0 +1,1154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import absolute_import +from __future__ import division + +import collections +import logging +import queue +import os +import sys +import threading +from builtins import object +from concurrent import futures +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union + +import grpc +from future.utils import raise_ + +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import beam_task_worker_pb2 +from apache_beam.portability.api import beam_task_worker_pb2_grpc +from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.portability.fn_api_runner import ControlConnection +from apache_beam.runners.portability.fn_api_runner import ControlFuture +from apache_beam.runners.portability.fn_api_runner import FnApiRunner +from apache_beam.runners.portability.fn_api_runner import GrpcWorkerHandler +from apache_beam.runners.portability.fn_api_runner import WorkerHandler +from apache_beam.runners.worker.bundle_processor import BundleProcessor +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory +from apache_beam.runners.worker.data_plane import _GrpcDataChannel +from apache_beam.runners.worker.data_plane import ClosableOutputStream +from apache_beam.runners.worker.data_plane import DataChannelFactory +from apache_beam.runners.worker.statecache import StateCache +from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor + +if TYPE_CHECKING: + from apache_beam.portability.api import beam_provision_api_pb2 + from apache_beam.runners.portability.fn_api_runner import ExtendedProvisionInfo + from apache_beam.runners.worker.data_plane import DataChannelFactory + from apache_beam.runners.worker.sdk_worker import CachingStateHandler + from apache_beam.transforms.environments import Environment + +ENTRY_POINT_NAME = 'apache_beam_task_workers_plugins' +MAX_TASK_WORKERS = 300 +MAX_TASK_WORKER_RETRY = 10 + +Taskable = Union[ + 'TaskableValue', + List['TaskableValue'], + Tuple['TaskableValue', ...], + Set['TaskableValue']] + + +class TaskableValue(object): + """ + Value that can be distributed to TaskWorkers as tasks. + + Has the original value, and TaskProperties that specifies how the task will be + generated.""" + + def __init__(self, + value, # type: Any + urn, # type: str + env=None, # type: Optional[Environment] + payload=None # type: Optional[Any] + ): + # type: (...) -> None + """ + Args: + value: The wrapped element + urn: id of the task worker handler + env : Environment for the task to be run in + payload : Payload containing settings for the task worker handler + """ + self.value = value + self.urn = urn + self.env = env + self.payload = payload + + +class TaskWorkerProcessBundleError(Exception): + """ + Error thrown when TaskWorker fails to process_bundle. + + Errors encountered when task worker is processing bundle can be retried, up + till max retries defined by ``MAX_TASK_WORKER_RETRY``. + """ + + +class TaskWorkerTerminatedError(Exception): + """ + Error thrown when TaskWorker terminated before it finished working. 
+ + Custom TaskWorkerHandlers can choose to terminate a task but not + affect the whole bundle by setting ``TaskWorkerHandler.alive`` to False, + which will cause this error to be thrown. + """ + + +class TaskWorkerHandler(GrpcWorkerHandler): + """ + Abstract base class for TaskWorkerHandler for a task worker, + + A TaskWorkerHandler is created for each TaskableValue. + + Subclasses must override ``start_remote`` to modify how remote task worker is + started, and register task properties type when defining a subclass. + """ + + _known_urns = {} # type: Dict[str, Type[TaskWorkerHandler]] + + def __init__(self, + state, # type: CachingStateHandler + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + environment, # type: Environment + task_payload, # type: Any + credentials=None, # type: Optional[str] + worker_id=None # type: Optional[str] + ): + # type: (...) -> None + self._grpc_server = grpc_server + + # we are manually doing init instead of calling GrpcWorkerHandler's init + # because we want to override worker_id and we don't want extra + # ControlConnection to be established + WorkerHandler.__init__(self, grpc_server.control_handler, + grpc_server.data_plane_handler, state, + provision_info) + # override worker_id if provided + if worker_id: + self.worker_id = worker_id + + self.control_address = self.port_from_worker(self._grpc_server.control_port) + self.logging_address = self.port_from_worker( + self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) + self.provision_address = self.port_from_worker( + self.get_port_from_env_var('PROVISION_API_SERVICE_DESCRIPTOR')) + + self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( + self.worker_id) + + self.environment = environment + self.task_payload = task_payload + self.credentials = credentials + self.alive = True + + def host_from_worker(self): + # type: () -> str + import socket + return socket.getfqdn() + + def get_port_from_env_var(self, env_var): + # type: (str) -> str + """Extract the service port for a given environment variable.""" + from google.protobuf import text_format + endpoint = endpoints_pb2.ApiServiceDescriptor() + text_format.Merge(os.environ[env_var], endpoint) + return endpoint.url.split(':')[-1] + + @staticmethod + def load_plugins(): + # type: () -> None + import entrypoints + for name, entry_point in entrypoints.get_group_named( + 'apache_beam_task_workers_plugins').iteritems(): + logging.info('Loading entry point: {}'.format(name)) + entry_point.load() + + @classmethod + def register_urn(cls, urn, constructor=None): + def register(constructor): + cls._known_urns[urn] = constructor + return constructor + if constructor: + return register(constructor) + else: + return register + + @classmethod + def create(cls, + state, # type: CachingStateHandler + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + taskable_value, # type: TaskableValue + credentials=None, # type: Optional[str] + worker_id=None # type: Optional[str] + ): + # type: (...) 
-> TaskWorkerHandler + constructor = cls._known_urns[taskable_value.urn] + return constructor(state, provision_info, grpc_server, + taskable_value.env, taskable_value.payload, + credentials=credentials, worker_id=worker_id) + + def start_worker(self): + # type: () -> None + self.start_remote() + + def start_remote(self): + # type: () -> None + """Start up a remote TaskWorker to process the current element. + + Subclass should implement this.""" + raise NotImplementedError + + def stop_worker(self): + # type: () -> None + # send shutdown request + future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + shutdown=beam_task_worker_pb2.ShutdownRequest())) + response = future.get() + if response.error: + logging.warning('Error stopping worker: {}'.format(self.worker_id)) + + # close control conn after stop worker + self.control_conn.close() + + def _get_future(self, future, interval=0.5): + # type: (ControlFuture, float) -> beam_task_worker_pb2.TaskInstructionResponse + result = None + while self.alive: + result = future.get(timeout=interval) + if result: + break + + # if the handler is not alive, meaning task worker is stopped before + # finishing processing, raise ``TaskWorkerTerminatedError`` + if result is None: + raise TaskWorkerTerminatedError() + + return result + + def execute(self, + data_channel_factory, # type: ProxyGrpcClientDataChannelFactory + process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point of the task execution cycle of a ``TaskWorkerHandler``. + + It will first issue a create request, and wait for the remote bundle + processor to be created; Then it will issue the process bundle request, and + wait for the result. If there's error occurred when processing bundle, + ``TaskWorkerProcessBundleError`` will be raised. + """ + # wait for remote bundle processor to be created + create_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + create=beam_task_worker_pb2.CreateRequest( + process_bundle_descriptor=process_bundle_descriptor, + state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( + url=self._grpc_server.state_address), + data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( + credentials=self.credentials, + worker_id=data_channel_factory.worker_id, + transmitter_url=data_channel_factory.transmitter_url)))) + self._get_future(create_future) + + # process bundle + process_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()) + ) + response = self._get_future(process_future) + if response.error: + # raise here so this task can be retried + raise TaskWorkerProcessBundleError() + else: + delayed_applications = response.process_bundle.delayed_applications + require_finalization = response.process_bundle.require_finalization + + self.stop_worker() + return delayed_applications, require_finalization + + def reset(self): + # type: () -> None + """This is used to retry a failed task.""" + self.control_conn.reset() + + +class TaskGrpcServer(object): + """ + A collection of grpc servicers that handle communication between a + ``TaskWorker`` and ``TaskWorkerHandler``. 
+ + Contains three servers: + - a control server hosting ``TaskControlService`` + - a data server hosting ``TaskFnDataService`` + - a state server hosting ``TaskStateService`` + + This is shared by all TaskWorkerHandlers generated by one bundle. + """ + + _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 + + def __init__(self, + state_handler, # type: CachingStateHandler + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + data_channel_factory, # type: DataChannelFactory + instruction_id # type: str + ): + # type: (...) -> None + self.state_handler = state_handler + self.max_workers = max_workers + self.control_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers)) + self.control_port = self.control_server.add_insecure_port('[::]:0') + self.control_address = '%s:%s' % (self.get_host_name(), self.control_port) + + # Options to have no limits (-1) on the size of the messages + # received or sent over the data plane. The actual buffer size + # is controlled in a layer above. + no_max_message_sizes = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + self.data_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.data_port = self.data_server.add_insecure_port('[::]:0') + self.data_address = '%s:%s' % (self.get_host_name(), self.data_port) + + self.state_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.state_port = self.state_server.add_insecure_port('[::]:0') + self.state_address = '%s:%s' % (self.get_host_name(), self.state_port) + + self.control_handler = TaskControlServicer() + beam_task_worker_pb2_grpc.add_TaskControlServicer_to_server( + self.control_handler, self.control_server) + # TODO: When we add provision / staging service, it needs to be added to the + # control server too + + self.data_plane_handler = TaskFnDataServicer(data_store, + data_channel_factory, + instruction_id) + beam_task_worker_pb2_grpc.add_TaskFnDataServicer_to_server( + self.data_plane_handler, self.data_server) + + beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( + TaskStateServicer(self.state_handler, instruction_id, + state_handler._context.cache_token), + self.state_server) + + logging.info('starting control server on port %s', self.control_port) + logging.info('starting data server on port %s', self.data_port) + logging.info('starting state server on port %s', self.state_port) + self.state_server.start() + self.data_server.start() + self.control_server.start() + + @staticmethod + def get_host_name(): + # type: () -> str + import socket + return socket.getfqdn() + + def close(self): + # type: () -> None + self.control_handler.done() + to_wait = [ + self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + ] + for w in to_wait: + w.wait() + + +# ============= +# Control Plane +# ============= +class TaskWorkerConnection(ControlConnection): + """The control connection between a TaskWorker and a TaskWorkerHandler. + + TaskWorkerHandler push InstructionRequests to _push_queue, and receives + InstructionResponses from TaskControlServicer. 
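+
+  A rough usage sketch (mirroring ``TaskWorkerHandler.stop_worker``)::
+
+    conn = grpc_server.control_handler.get_conn_by_worker_id(worker_id)
+    future = conn.push(
+        beam_task_worker_pb2.TaskInstructionRequest(
+            instruction_id=worker_id,
+            shutdown=beam_task_worker_pb2.ShutdownRequest()))
+    response = future.get()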
+ """ + + _lock = threading.Lock() + + def __init__(self): + self._push_queue = queue.Queue() + self._input = None + self._futures_by_id = {} # type: Dict[Any, ControlFuture] + self._read_thread = threading.Thread( + name='bundle_processor_control_read', target=self._read) + self._state = TaskControlServicer.UNSTARTED_STATE + # marks current TaskConnection as in a state of retrying after failure + self._retrying = False + + def _read(self): + # type: () -> None + for data in self._input: + self._futures_by_id.pop(data.WhichOneof('response')).set(data) + + def push(self, + req # type: Union[TaskControlServicer._DONE_MARKER, beam_task_worker_pb2.TaskInstructionRequest] + ): + # type: (...) -> Optional[ControlFuture] + if req == TaskControlServicer._DONE_MARKER: + self._push_queue.put(req) + return None + if not req.instruction_id: + raise RuntimeError( + 'TaskInstructionRequest has to have instruction id!') + future = ControlFuture(req.instruction_id) + self._futures_by_id[req.WhichOneof('request')] = future + self._push_queue.put(req) + return future + + def set_inputs(self, input): + with TaskWorkerConnection._lock: + if self._input and not self._retrying: + raise RuntimeError('input is already set.') + self._input = input + self._read_thread.start() + self._state = TaskControlServicer.STARTED_STATE + self._retrying = False + + def close(self): + # type: () -> None + with TaskWorkerConnection._lock: + if self._state == TaskControlServicer.STARTED_STATE: + self.push(TaskControlServicer._DONE_MARKER) + self._read_thread.join() + self._state = TaskControlServicer.DONE_STATE + + def reset(self): + # type: () -> None + self.close() + self.__init__() + self._retrying = True + + +class TaskControlServicer(beam_task_worker_pb2_grpc.TaskControlServicer): + + _lock = threading.Lock() + + UNSTARTED_STATE = 'unstarted' + STARTED_STATE = 'started' + DONE_STATE = 'done' + + _DONE_MARKER = object() + + def __init__(self): + # type: () -> None + self._state = self.UNSTARTED_STATE + self._connections_by_worker_id = collections.defaultdict( + TaskWorkerConnection) + + def get_conn_by_worker_id(self, worker_id): + # type: (str) -> TaskWorkerConnection + with self._lock: + result = self._connections_by_worker_id[worker_id] + return result + + def Control(self, request_iterator, context): + with self._lock: + if self._state == self.DONE_STATE: + return + else: + self._state = self.STARTED_STATE + worker_id = dict(context.invocation_metadata()).get('worker_id') + if not worker_id: + raise RuntimeError('Connection does not have worker id.') + conn = self.get_conn_by_worker_id(worker_id) + conn.set_inputs(request_iterator) + + while True: + to_push = conn.get_req() + if to_push is self._DONE_MARKER: + return + yield to_push + + def done(self): + # type: () -> None + self._state = self.DONE_STATE + + +# ========== +# Data Plane +# ========== +class ProxyGrpcClientDataChannelFactory(DataChannelFactory): + """A factory for ``ProxyGrpcClientDataChannel``. 
+ + No caching behavior here because we are starting each data channel on + different location.""" + + def __init__(self, transmitter_url, credentials=None, worker_id=None): + # type: (str, Optional[str], Optional[str]) -> None + # These two are not private attributes because it was used in + # ``TaskWorkerHandler.execute`` when issuing TaskInstructionRequest + self.transmitter_url = transmitter_url + self.worker_id = worker_id + + self._credentials = credentials + + def create_data_channel(self, remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> ProxyGrpcClientDataChannel + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + grpc_channel = GRPCChannelFactory.insecure_channel( + self.transmitter_url, options=channel_options) + else: + grpc_channel = GRPCChannelFactory.secure_channel( + self.transmitter_url, self._credentials, options=channel_options) + return ProxyGrpcClientDataChannel( + remote_grpc_port.api_service_descriptor.url, + beam_task_worker_pb2_grpc.TaskFnDataStub(grpc_channel)) + + def close(self): + # type: () -> None + pass + + +class ProxyGrpcClientDataChannel(_GrpcDataChannel): + """DataChannel wrapping the client side of a TaskFnDataService connection.""" + + def __init__(self, client_url, proxy_stub): + # type: (str, beam_task_worker_pb2_grpc.TaskFnDataStub) -> None + super(ProxyGrpcClientDataChannel, self).__init__() + self.client_url = client_url + self.proxy_stub = proxy_stub + + def input_elements(self, + instruction_id, # type: str + expected_transforms, # type: List[str] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data] + req = beam_task_worker_pb2.ReceiveRequest( + instruction_id=instruction_id, + client_data_endpoint=self.client_url) + done_transforms = [] + abort_callback = abort_callback or (lambda: False) + + for data in self.proxy_stub.Receive(req): + if self._closed: + raise RuntimeError('Channel closed prematurely.') + if abort_callback(): + return + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + if not data.data and data.transform_id in expected_transforms: + done_transforms.append(data.transform_id) + else: + assert data.transform_id not in done_transforms + yield data + if len(done_transforms) >= len(expected_transforms): + return + + def output_stream(self, instruction_id, transform_id): + # type: (str, str) -> ClosableOutputStream + + def _add_to_send_queue(data): + if data: + self.proxy_stub.Send(beam_task_worker_pb2.SendRequest( + instruction_id=instruction_id, + data=beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + data=data), + client_data_endpoint=self.client_url + )) + + def close_callback(data): + _add_to_send_queue(data) + # no need to send empty bytes to signal end of processing here, because + # when the whole bundle finishes, the bundle processor original output + # stream will send that to runner + + return ClosableOutputStream( + close_callback, flush_callback=_add_to_send_queue) + + +class TaskFnDataServicer(beam_task_worker_pb2_grpc.TaskFnDataServicer): + """Implementation of BeamFnDataTransmitServicer for any number of clients.""" + + def __init__(self, + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + orig_data_channel_factory, # type: DataChannelFactory + instruction_id # type: str + ): + # type: (...) 
-> None + self.data_store = data_store + self.orig_data_channel_factory = orig_data_channel_factory + self.orig_instruction_id = instruction_id + self._orig_data_channel = None # type: Optional[ProxyGrpcClientDataChannel] + + def _get_orig_data_channel(self, url): + # type: (str) -> ProxyGrpcClientDataChannel + remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort( + api_service_descriptor=endpoints_pb2.ApiServiceDescriptor(url=url)) + # the data channel is cached by url + return self.orig_data_channel_factory.create_data_channel(remote_grpc_port) + + def Receive(self, request, context=None): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data]] + data = self.data_store[request.instruction_id] + for elem in data: + yield elem + + def Send(self, request, context=None): + # type: (...) -> beam_task_worker_pb2.SendResponse + if self._orig_data_channel is None: + self._orig_data_channel = self._get_orig_data_channel( + request.client_data_endpoint) + # We need to replace the instruction_id here with the original instruction + # id, not the current one (which is the task worker id) + request.data.instruction_id = self.orig_instruction_id + if request.data.data: + # only send when there's data, because it is signaling the runner side + # worker handler that element of this has ended if it is empty, and we + # want to send that when every task worker handler is finished + self._orig_data_channel._to_send.put(request.data) + return beam_task_worker_pb2.SendResponse() + + +# ===== +# State +# ===== +class TaskStateServicer(FnApiRunner.GrpcStateServicer): + + def __init__(self, state, instruction_id, cache_token): + # type: (CachingStateHandler, str, Optional[str]) -> None + self.instruction_id = instruction_id + self.cache_token = cache_token + super(TaskStateServicer, self).__init__(state) + + def State(self, request_stream, context=None): + # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse] + # CachingStateHandler and GrpcStateHandler context is thread local, so we + # need to set it here for each TaskWorker + self._state._context.process_instruction_id = self.instruction_id + self._state._context.cache_token = self.cache_token + self._state._underlying._context.process_instruction_id = self.instruction_id + + # FIXME: This is not currently properly supporting state caching (currently + # state caching behavior only happens within python SDK, so runners like + # the FlinkRunner won't create the state cache anyways for now) + for request in request_stream: + request_type = request.WhichOneof('request') + + if request_type == 'get': + data, continuation_token = self._state._underlying.get_raw( + request.state_key, request.get.continuation_token) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + get=beam_fn_api_pb2.StateGetResponse( + data=data, continuation_token=continuation_token)) + elif request_type == 'append': + self._state._underlying.append_raw(request.state_key, + request.append.data) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + append=beam_fn_api_pb2.StateAppendResponse()) + elif request_type == 'clear': + self._state._underlying.clear(request.state_key) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + clear=beam_fn_api_pb2.StateClearResponse()) + else: + raise NotImplementedError('Unknown state request: %s' % request_type) + + +class BundleProcessorTaskWorker(object): + """ + The remote task worker that communicates with the SDK worker to do the + actual work of processing bundles. 
+ + The BundleProcessor will detect inputs and see if there is TaskableValue, and + if there is and the BundleProcessor is set to "use task worker", then + BundleProcessorTaskHelper will create a TaskWorkerHandler that this class + communicates with. + + This class creates a BundleProcessor and receives TaskInstructionRequests and + sends back respective responses via the grpc channels connected to the control + endpoint; + """ + + REQUEST_PREFIX = '_request_' + _lock = threading.Lock() + + def __init__(self, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Initialize a BundleProcessorTaskWorker. Lives remotely. + + It will create a BundleProcessor with the provide information and process + the requests using the BundleProcessor created. + + Args: + worker_id: the worker id of current task worker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self.worker_id = worker_id + self._credentials = credentials + self._responses = queue.Queue() + self._alive = None # type: Optional[bool] + self._bundle_processor = None # type: Optional[BundleProcessor] + self._exc_info = None + self.stub = self._create_stub(server_url) + + @classmethod + def execute(cls, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Instantiate a BundleProcessorTaskWorker and start running. + + If there's error, it will be raised here so it can be reflected to user. + + Args: + worker_id: worker id for the BundleProcessorTaskWorker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self = cls(worker_id, server_url, credentials=credentials) + self.run() + + # raise the error here, so user knows there's a failure and could retry + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + + def _create_stub(self, server_url): + # type: (str) -> beam_task_worker_pb2_grpc.TaskControlStub + """Create the TaskControl client.""" + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + channel = GRPCChannelFactory.insecure_channel( + server_url, + options=channel_options) + else: + channel = GRPCChannelFactory.secure_channel(server_url, + self._credentials, + options=channel_options) + + # add instruction_id to grpc channel + channel = grpc.intercept_channel( + channel, + WorkerIdInterceptor(self.worker_id)) + + return beam_task_worker_pb2_grpc.TaskControlStub(channel) + + def do_instruction(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Process the requests with the corresponding method.""" + request_type = request.WhichOneof('request') + if request_type: + return getattr(self, self.REQUEST_PREFIX + request_type)( + getattr(request, request_type)) + else: + raise NotImplementedError + + def run(self): + # type: () -> None + """Start the full running life cycle for a task worker. + + It send TaskWorkerInstructionResponse to TaskWorkerHandler, and wait for + TaskWorkerInstructionRequest. This service is bidirectional. 
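+
+    A remote process normally does not call this directly; it goes through
+    the ``execute`` classmethod instead, as ``task_worker_main`` does::
+
+      BundleProcessorTaskWorker.execute(
+          worker_id, control_address, credentials=credentials)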
+ """ + no_more_work = object() + self._alive = True + + def get_responses(): + while True: + response = self._responses.get() + if response is no_more_work: + return + if response: + yield response + + try: + for request in self.stub.Control(get_responses()): + self._responses.put(self.do_instruction(request)) + finally: + self._alive = False + + self._responses.put(no_more_work) + logging.info('Done consuming work.') + + def _request_create(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Create a BundleProcessor based on the request. + + Should be the first request received from the handler. + """ + from apache_beam.runners.worker.sdk_worker import \ + GrpcStateHandlerFactory + + credentials = None + if request.data_factory.credentials._credentials: + credentials = grpc.ChannelCredentials( + request.data_factory.credentials._credentials) + logging.debug('Credentials: {!r}'.format(credentials)) + + worker_id = request.data_factory.worker_id + transmitter_url = request.data_factory.transmitter_url + state_handler_endpoint = request.state_handler_endpoint + # FIXME: Add support for Caching later + state_factory = GrpcStateHandlerFactory(StateCache(0), credentials) + state_handler = state_factory.create_state_handler( + state_handler_endpoint) + data_channel_factory = ProxyGrpcClientDataChannelFactory( + transmitter_url, credentials, worker_id + ) + + self._bundle_processor = BundleProcessor( + request.process_bundle_descriptor, + state_handler, + data_channel_factory + ) + return beam_task_worker_pb2.TaskInstructionResponse( + create=beam_task_worker_pb2.CreateResponse(), + instruction_id=self.worker_id + ) + + def _request_process_bundle(self, + request # type: beam_task_worker_pb2.TaskInstructionRequest + ): + # type: (...) 
-> beam_task_worker_pb2.TaskInstructionResponse + """Process bundle using the bundle processor based on the request.""" + error = None + + try: + # FIXME: Update this to use the cache_tokens properly + with self._bundle_processor.state_handler._underlying.process_instruction_id( + self.worker_id): + delayed_applications, require_finalization = \ + self._bundle_processor.process_bundle(self.worker_id, + use_task_worker=False) + except Exception as e: + # we want to propagate the error back to the TaskWorkerHandler, so that + # it will raise `TaskWorkerProcessBundleError` which allows for requeue + # behavior (up until MAX_TASK_WORKER_RETRY number of retries) + error = e.message + self._exc_info = sys.exc_info() + delayed_applications = [] + require_finalization = False + + return beam_task_worker_pb2.TaskInstructionResponse( + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( + delayed_applications=delayed_applications, + require_finalization=require_finalization), + instruction_id=self.worker_id, + error=error + ) + + def _request_shutdown(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Shutdown the bundleprocessor.""" + error = None + try: + # shut down state handler here because it is not created by the state + # handler factory thus won't be closed automatically + self._bundle_processor.state_handler.done() + self._bundle_processor.shutdown() + except Exception as e: + error = e.message + finally: + return beam_task_worker_pb2.TaskInstructionResponse( + shutdown=beam_task_worker_pb2.ShutdownResponse(), + instruction_id=self.worker_id, + error=error + ) + + +class BundleProcessorTaskHelper(object): + """ + A helper object that is used by a BundleProcessor while processing bundle. + + Delegates TaskableValues to TaskWorkers, if enabled. + + It can process TaskableValue using TaskWorkers if inspected, and kept the + default behavior if specified to not use TaskWorker, or there's no + TaskableValue found in this bundle. + + To utilize TaskWorkers, BundleProcessorTaskHelper will split up the input + bundle into tasks based on the wrapped TaskableValue's payload, and create a + TaskWorkerHandler for each task. + """ + + def __init__(self, instruction_id, wrapped_values): + # type: (str, DefaultDict[str, List[Tuple[Any, bytes]]]) -> None + """Initialize a BundleProcessorTaskHelper object. + + Args: + instruction_id: the instruction_id of the bundle that the + BundleProcessor is processing + wrapped_values: The mapping of transform id to raw and encoded data. + + """ + self.instruction_id = instruction_id + self.wrapped_values = wrapped_values + + def split_taskable_values(self): + # type: () -> Tuple[DefaultDict[str, List[Any]], DefaultDict[str, List[beam_fn_api_pb2.Elements.Data]]] + """Split TaskableValues into tasks and pair it with worker. + + Also put the raw bytes along with worker id for data dispatching by data + plane handler. 
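+
+    Sketch of the expected shapes (worker ids are generated per task)::
+
+      splitted   == {'worker_0': [taskable_0], 'worker_1': [taskable_1], ...}
+      data_store == {'worker_0': [Elements.Data(...)], 'worker_1': [...], ...}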
+ """ + # TODO: Come up with solution on how this can be dynamically changed + # could use window + splitted = collections.defaultdict(list) + data_store = collections.defaultdict(list) + worker_count = 0 + for ptransform_id, values in self.wrapped_values.iteritems(): + for decoded, raw in values: + worker_id = 'worker_{}'.format(worker_count) + splitted[worker_id].append(decoded) + data_store[worker_id].append(beam_fn_api_pb2.Elements.Data( + transform_id=ptransform_id, + data=raw, + instruction_id=worker_id + )) + worker_count += 1 + + return splitted, data_store + + def _start_task_grpc_server(self, + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + state_handler, # type: CachingStateHandler + data_channel_factory # type: DataChannelFactory + ): + # type:(...) -> TaskGrpcServer + """Start up TaskGrpcServer. + + Args: + max_workers: number of max worker + data_store: stored data of worker id and the raw and decoded values for + the worker to process as inputs + state_handler: state handler of current BundleProcessor + data_channel_factory: data channel factory of current BundleProcessor + """ + return TaskGrpcServer(state_handler, max_workers, data_store, + data_channel_factory, self.instruction_id) + + @staticmethod + def get_default_task_env(process_bundle_descriptor): + # type:(beam_fn_api_pb2.ProcessBundleDescriptor) -> Optional[Environment] + """Get the current running beam Environment class. + + Used as the default for the task worker. + + Args: + process_bundle_descriptor: the ProcessBundleDescriptor proto + """ + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.portability.fn_api_runner_transforms import \ + PAR_DO_URNS + from apache_beam.transforms.environments import Environment + from apache_beam.utils import proto_utils + + # find a ParDo xform in this stage + pardo = None + for _, xform in process_bundle_descriptor.transforms.iteritems(): + if xform.spec.urn in PAR_DO_URNS: + pardo = xform + break + + if pardo is None: + # don't set the default task env if no ParDo is found + # FIXME: Use the pipeline default env here? + return None + + pardo_payload = proto_utils.parse_Bytes( + pardo.spec.payload, + beam_runner_api_pb2.ParDoPayload) + env_proto = process_bundle_descriptor.environments.get( + pardo_payload.do_fn.environment_id) + + return Environment.from_runner_api(env_proto, None) + + def process_bundle_with_task_workers(self, + state_handler, # type: CachingStateHandler + data_channel_factory, # type: DataChannelFactory + process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point for task worker system. + + Starts up a group of TaskWorkerHandlers, dispatches tasks and waits for them + to finish. + + Fails if any TaskWorker exceeds maximum retries. + + Args: + state_handler: state handler of current BundleProcessor + data_channel_factory: data channel factory of current BundleProcessor + process_bundle_descriptor: a description of the stage that this + ``BundleProcessor``is to execute. 
+ """ + default_env = self.get_default_task_env(process_bundle_descriptor) + # start up grpc server + splitted_elements, data_store = self.split_taskable_values() + num_task_workers = len(splitted_elements.items()) + if num_task_workers > MAX_TASK_WORKERS: + logging.warning( + 'Number of element exceeded MAX_TASK_WORKERS ({})'.format( + MAX_TASK_WORKERS)) + num_task_workers = MAX_TASK_WORKERS + server = self._start_task_grpc_server(num_task_workers, data_store, + state_handler, data_channel_factory) + + # create TaskWorkerHandlers + task_worker_handlers = [] + # FIXME: leaving out provision info for now, it should come from + # Environment + provision_info = None + for worker_id, elem in splitted_elements.iteritems(): + taskable_value = get_taskable_value(elem) + # set the env to default env if there is + if taskable_value.env is None and default_env: + taskable_value.env = default_env + + task_worker_handler = TaskWorkerHandler.create( + state_handler, provision_info, server, taskable_value, + credentials=data_channel_factory._credentials, worker_id=worker_id) + task_worker_handlers.append(task_worker_handler) + task_worker_handler.start_worker() + + def _execute(handler): + """ + This is the method that runs in the thread pool representing a working + TaskHandler. + """ + worker_data_channel_factory = ProxyGrpcClientDataChannelFactory( + server.data_address, + credentials=data_channel_factory._credentials, + worker_id=task_worker_handler.worker_id) + counter = 0 + while True: + try: + counter += 1 + return handler.execute(worker_data_channel_factory, + process_bundle_descriptor) + except TaskWorkerProcessBundleError as e: + if counter >= MAX_TASK_WORKER_RETRY: + logging.error('Task Worker has exceeded max retries!') + handler.stop_worker() + raise + # retry if task worker failed to process bundle + handler.reset() + continue + except TaskWorkerTerminatedError as e: + # This error is thrown only when TaskWorkerHandler is terminated + # before it finished processing + logging.warning('TaskWorker terminated prematurely.') + raise + + # start actual processing of splitted bundle + merged_delayed_applications = [] + bundle_require_finalization = False + with futures.ThreadPoolExecutor(max_workers=num_task_workers) as executor: + try: + for delayed_applications, require_finalization in executor.map( + _execute, task_worker_handlers): + + if delayed_applications is None: + raise RuntimeError('Task Worker failed to process task.') + merged_delayed_applications.extend(delayed_applications) + # if any elem requires finalization, set it to True + if not bundle_require_finalization and require_finalization: + bundle_require_finalization = True + except TaskWorkerProcessBundleError: + raise RuntimeError('Task Worker failed to process task.') + except TaskWorkerTerminatedError: + # This error is thrown only when TaskWorkerHandler is terminated before + # it finished processing, this is only possible if + # `TaskWorkerHandler.alive` is manually set to False by custom user + # defined monitoring function, which user would trigger if they want + # to terminate the task; + # In that case, we want to continue on and not hold the whole bundle + # by the tasks that user manually terminated. 
+ pass + + return merged_delayed_applications, bundle_require_finalization + + +@TaskWorkerHandler.register_urn('local') +class LocalTaskWorkerHandler(TaskWorkerHandler): + """TaskWorkerHandler that starts up task worker locally.""" + + # FIXME: create a class-level thread pool to restrict the number of threads + + def start_remote(self): + # type: () -> None + """start a task worker local to the task worker handler.""" + obj = BundleProcessorTaskWorker(self.worker_id, self.control_address, + self.credentials) + run_thread = threading.Thread(target=obj.run) + run_thread.daemon = True + run_thread.start() + + +TaskWorkerHandler.load_plugins() + + +def get_taskable_value(decoded_value): + # type: (Any) -> Optional[TaskableValue] + """Check whether the given value contains taskable value. + + If taskable, return the TaskableValue + + Args: + decoded_value: decoded value from raw input stream + """ + # FIXME: Come up with a solution that's not so specific + if isinstance(decoded_value, (list, tuple, set)): + for val in decoded_value: + result = get_taskable_value(val) + if result: + return result + elif isinstance(decoded_value, TaskableValue): + return decoded_value + return None diff --git a/sdks/python/apache_beam/runners/worker/task_worker_main.py b/sdks/python/apache_beam/runners/worker/task_worker_main.py new file mode 100644 index 000000000000..6cbb2ea405a1 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker_main.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""TaskWorker entry point.""" + +from __future__ import absolute_import + +import os +import logging +import sys +import traceback + +from apache_beam.runners.worker.sdk_worker_main import _load_main_session +from apache_beam.runners.worker.task_worker import BundleProcessorTaskWorker +from apache_beam.runners.worker.task_worker import TaskWorkerHandler + +# This module is experimental. No backwards-compatibility guarantees. + + +def main(unused_argv): + """Main entry point for Starting up TaskWorker.""" + + # TODO: Do we want to set up the logging service here? 
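+
+  # Note: whatever launches this process (e.g. a custom TaskWorkerHandler) is
+  # expected to have exported TASK_WORKER_ID, TASK_WORKER_CONTROL_ADDRESS and,
+  # optionally, TASK_WORKER_CREDENTIALS; they are read from the environment
+  # below.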
+ + # below section is the same as in sdk_worker_main, for loading pickled main + # session if there's any + if 'SEMI_PERSISTENT_DIRECTORY' in os.environ: + semi_persistent_directory = os.environ['SEMI_PERSISTENT_DIRECTORY'] + else: + semi_persistent_directory = None + + logging.info('semi_persistent_directory: %s', semi_persistent_directory) + + try: + _load_main_session(semi_persistent_directory) + except Exception: # pylint: disable=broad-except + exception_details = traceback.format_exc() + logging.error( + 'Could not load main session: %s', exception_details, exc_info=True) + + worker_id = os.environ['TASK_WORKER_ID'] + control_address = os.environ['TASK_WORKER_CONTROL_ADDRESS'] + if 'TASK_WORKER_CREDENTIALS' in os.environ: + credentials = os.environ['TASK_WORKER_CREDENTIALS'] + else: + credentials = None + + TaskWorkerHandler.load_plugins() + + # exception should be handled already by task workers + BundleProcessorTaskWorker.execute(worker_id, control_address, + credentials=credentials) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/sdks/python/apache_beam/runners/worker/task_worker_test.py b/sdks/python/apache_beam/runners/worker/task_worker_test.py new file mode 100644 index 000000000000..0b61aeb11e1d --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker_test.py @@ -0,0 +1,409 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for apache_beam.runners.worker.task_worker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import mock +import threading +import unittest +from builtins import range +from collections import defaultdict + +import grpc +from future.utils import raise_ + +from apache_beam.coders import coders, coder_impl +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_task_worker_pb2 +from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.worker import task_worker +from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory +from apache_beam.runners.worker.sdk_worker import CachingStateHandler +from apache_beam.transforms import window + + +# -- utilities for testing, mocking up test objects +class _MockBundleProcessorTaskWorker(task_worker.BundleProcessorTaskWorker): + """ + A mocked version of BundleProcessorTaskWorker, responsible for recording the + requests it received, and provide response to each request type by a passed + in dictionary as user desired. 
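+
+ A minimal sketch of how the tests below use it (``prep_responses_by_request_type``
+ is the helper defined later in this module; the address is just a placeholder):
+
+   responses = prep_responses_by_request_type('worker_0')
+   worker = _MockBundleProcessorTaskWorker(
+       'worker_0', 'localhost:12345', responsesByRequestType=responses)
+   # every request passed to do_instruction() is recorded in
+   # worker.requestRecorder and answered from responsesByRequestType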
+ """ + def __init__(self, worker_id, server_url, credentials=None, + requestRecorder=None, responsesByRequestType=None): + super(_MockBundleProcessorTaskWorker, self).__init__( + worker_id, + server_url, + credentials=credentials + ) + self.requestRecorder = requestRecorder or [] + self.responsesByRequestType = responsesByRequestType or {} + + def do_instruction(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + request_type = request.WhichOneof('request') + self.requestRecorder.append(request) + return self.responsesByRequestType.get(request_type) + + +@task_worker.TaskWorkerHandler.register_urn('unittest') +class _MockTaskWorkerHandler(task_worker.TaskWorkerHandler): + """ + Register a mocked version of task handler only used for "unittest"; will start + a ``_MockBundleProcessorTaskWorker`` for each discovered TaskableValue. + + Main difference is that it returns the started task worker object, for easier + testing. + """ + + def __init__(self, + state, # type: CachingStateHandler + provision_info, + # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + environment, # type: Environment + task_payload, # type: Any + credentials=None, # type: Optional[str] + worker_id=None, # type: Optional[str] + responseByRequestType=None + ): + super(_MockTaskWorkerHandler, self).__init__(state, provision_info, + grpc_server, environment, + task_payload, + credentials=credentials, + worker_id=worker_id) + self.responseByRequestType = responseByRequestType + + def start_remote(self): + # type: () -> _MockBundleProcessorTaskWorker + """starts a task worker local to the task worker handler.""" + obj = _MockBundleProcessorTaskWorker( + self.worker_id, + self.control_address, + self.credentials, + responsesByRequestType=self.responseByRequestType) + run_thread = threading.Thread(target=obj.run) + run_thread.daemon = True + run_thread.start() + + return obj + + def start_worker(self): + # type: () -> _MockBundleProcessorTaskWorker + return self.start_remote() + + +class _MockTaskGrpcServer(task_worker.TaskGrpcServer): + """ + Mocked version of TaskGrpcServer, using mocked version of data channel factory + and cache handler. + """ + + def __init__(self, instruction_id, max_workers=1, data_store=None): + dummy_state_handler = _MockCachingStateHandler(None, None) + dummy_data_channel_factory = GrpcClientDataChannelFactory() + + super(_MockTaskGrpcServer, self).__init__(dummy_state_handler, max_workers, + data_store or {}, + dummy_data_channel_factory, + instruction_id) + + +class _MockCachingStateHandler(CachingStateHandler): + """ + Mocked CachingStateHandler, mainly for patching the thread local variable + ``_context`` and create the ``cache_token`` attribute on it. + """ + + def __init__(self, underlying_state, global_state_cache): + self._underlying = underlying_state + self._state_cache = global_state_cache + self._context = threading.local() + + self._context.cache_token = '' + + +class _MockDataInputOperation(object): + """ + A mocked version of DataInputOperation, responsible for recording and decoding + data for testing. 
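+
+ It only mimics the parts of the real DataInputOperation that these tests
+ exercise: ``windowed_coder_impl``, ``splitting_lock``, ``index``/``stop`` and
+ ``output()``, which records every decoded value in ``self.decoded``.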
+ """ + + def __init__(self, coder): + self.coder = coder + self.decoded = [] + self.splitting_lock = threading.Lock() + self.windowed_coder_impl = self.coder.get_impl() + + with self.splitting_lock: + self.index = -1 + self.stop = float('inf') + + def output(self, decoded_value): + self.decoded.append(decoded_value) + + +def prep_responses_by_request_type(worker_id, delayed_applications=(), + require_finalization=False, process_error=None, + shutdown_error=None): + return { + 'create': beam_task_worker_pb2.TaskInstructionResponse( + create=beam_task_worker_pb2.CreateResponse(), + instruction_id=worker_id + ), + 'process_bundle': beam_task_worker_pb2.TaskInstructionResponse( + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( + delayed_applications=list(delayed_applications), + require_finalization=require_finalization), + instruction_id=worker_id, + error=process_error + ), + 'shutdown': beam_task_worker_pb2.TaskInstructionResponse( + shutdown=beam_task_worker_pb2.ShutdownResponse(), + instruction_id=worker_id, + error=shutdown_error + ) + } + + +def prep_bundle_processor_descriptor(bundle_id): + return beam_fn_api_pb2.ProcessBundleDescriptor( + id='test_bundle_{}'.format(bundle_id), + transforms={ + str(bundle_id): beam_runner_api_pb2.PTransform(unique_name=str(bundle_id)) + }) + + +class TaskWorkerHandlerTest(unittest.TestCase): + + @staticmethod + def _get_task_worker_handler(worker_id, resp_by_type, instruction_id, + max_workers=1, data_store=None): + server = _MockTaskGrpcServer(instruction_id, max_workers=max_workers, + data_store=data_store) + return _MockTaskWorkerHandler(server.state_handler, None, server, None, + None, + worker_id=worker_id, + responseByRequestType=resp_by_type) + + def test_execute_success(self): + """ + Test when a TaskWorkerHandler successfully executed one life cycle. 
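+
+ A successful run should issue exactly three instructions to the task worker,
+ in order: ``create``, ``process_bundle`` and ``shutdown``; the recorded
+ requests are compared against that expected sequence below.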
+ """ + dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) + + worker_id = 'test_task_worker_1' + instruction_id = 'test_instruction_1' + resp_by_type = prep_responses_by_request_type(worker_id) + + test_handler = self._get_task_worker_handler(worker_id, resp_by_type, + instruction_id) + + proxy_data_channel_factory = task_worker.ProxyGrpcClientDataChannelFactory( + test_handler._grpc_server.data_address + ) + + test_worker = test_handler.start_worker() + + try: + delayed, requests = test_handler.execute(proxy_data_channel_factory, + dummy_process_bundle_descriptor) + self.assertEquals(len(delayed), 0) + self.assertEquals(requests, False) + finally: + test_handler._grpc_server.close() + + # check that the requests we received are as expected + expected = [ + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + create=beam_task_worker_pb2.CreateRequest( + process_bundle_descriptor=dummy_process_bundle_descriptor, + state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( + url=test_handler._grpc_server.state_address), + data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( + transmitter_url=proxy_data_channel_factory.transmitter_url, + worker_id=proxy_data_channel_factory.worker_id, + credentials=proxy_data_channel_factory._credentials + ))), + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()), + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=worker_id, + shutdown=beam_task_worker_pb2.ShutdownRequest()) + ] + + self.assertEquals(test_worker.requestRecorder, expected) + + def test_execute_failure(self): + """ + Test when a TaskWorkerHandler fails to process a bundle. + """ + dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) + + worker_id = 'test_task_worker_1' + instruction_id = 'test_instruction_1' + resp_by_type = prep_responses_by_request_type(worker_id, process_error='error') + + test_handler = self._get_task_worker_handler(worker_id, resp_by_type, + instruction_id) + proxy_data_channel_factory = task_worker.ProxyGrpcClientDataChannelFactory( + test_handler._grpc_server.data_address + ) + + test_handler.start_worker() + + try: + with self.assertRaises(task_worker.TaskWorkerProcessBundleError): + print(test_handler.execute) + test_handler.execute( + proxy_data_channel_factory, + dummy_process_bundle_descriptor) + + test_handler.stop_worker() + finally: + test_handler._grpc_server.close() + + +class BundleProcessorTaskHelperTest(unittest.TestCase): + + @staticmethod + def _get_test_int_coder(): + return coders.WindowedValueCoder(coders.VarIntCoder(), + coders.GlobalWindowCoder()) + + @staticmethod + def _get_test_pickle_coder(): + return coders.WindowedValueCoder(coders.FastPrimitivesCoder(), + coders.GlobalWindowCoder()) + + @staticmethod + def _prep_elements(elements): + return [window.GlobalWindows.windowed_value(elem) for elem in elements] + + @staticmethod + def _prep_encoded_data(coder, elements, instruction_id, transform_id): + temp_out = coder_impl.create_OutputStream() + raw_bytes = [] + + for elem in elements: + encoded = coder.encode(elem) + raw_bytes.append(encoded) + coder.get_impl().encode_to_stream(elem, temp_out, True) + + data = beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + data=temp_out.get() + ) + return raw_bytes, data + + def test_data_split_with_task_worker(self): + """ + Test that input data is split correctly by BundleProcessorTaskHelper. 
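+
+ Since every element here is a ``TaskableValue``, nothing should be emitted
+ through the mocked DataInputOperation (``decoded`` stays empty); instead each
+ (decoded value, raw bytes) pair is stashed in ``wrapped_values``, keyed by
+ transform id, for later dispatch to task workers.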
+ """ + test_coder = self._get_test_pickle_coder() + mocked_op = _MockDataInputOperation(test_coder) + + test_elems = self._prep_elements( + [task_worker.TaskableValue(i, 'unittest') for i in range(5)]) + instruction_id = 'test_instruction_1' + transform_id = 'test_transform_1' + + test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) + raw_bytes, data = self._prep_encoded_data(test_coder, test_elems, + instruction_id, transform_id) + + test_task_helper.process_encoded(mocked_op, data) + self.assertEquals(mocked_op.decoded, []) + expected_wrapped_values = defaultdict(list) + for decode, raw in zip(test_elems, raw_bytes): + expected_wrapped_values[transform_id].append((decode.value, raw)) + self.assertItemsEqual(test_task_helper.wrapped_values, expected_wrapped_values) + + def test_process_normally_without_task_worker(self): + """ + Test that when input data doesn't consists of TaskableValue, it is processed + not using task worker but normally via DataInputOperation's process. + """ + test_coder = self._get_test_int_coder() + mocked_op = _MockDataInputOperation(test_coder) + + test_elems = self._prep_elements(range(5)) + instruction_id = 'test_instruction_2' + transform_id = 'test_transform_1' + + test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) + _, data = self._prep_encoded_data(test_coder, test_elems, instruction_id, + transform_id) + + test_task_helper.process_encoded(mocked_op, data) + + # when processed normally, it will use the DataInputOperation to process, + # so it will be recorded in the `decoded` list + self.assertEquals(mocked_op.decoded, test_elems) + + @mock.patch.object(task_worker, 'MAX_TASK_WORKER_RETRY', 2) + @mock.patch('__main__._MockTaskWorkerHandler.execute', + side_effect=task_worker.TaskWorkerProcessBundleError('test')) + @mock.patch('__main__._MockTaskWorkerHandler.start_worker') + @mock.patch('__main__._MockTaskWorkerHandler.stop_worker') + def test_exceed_max_retries(self, unused_mock_stop, unused_mock_start, + mock_execute): + """ + Test the scenario when task worker fails exceed max retries. 
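+
+ ``execute`` is patched to always raise ``TaskWorkerProcessBundleError`` and
+ ``MAX_TASK_WORKER_RETRY`` is patched down to 2, so with two taskable elements
+ the bundle should fail with ``RuntimeError`` after
+ num(elements) * MAX_TASK_WORKER_RETRY = 4 calls to ``execute``.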
+ """ + test_coder = self._get_test_pickle_coder() + mocked_op = _MockDataInputOperation(test_coder) + + test_coder = self._get_test_pickle_coder() + + test_elems = self._prep_elements( + [task_worker.TaskableValue(i, 'unittest') for i in range(2)] + ) + instruction_id = 'test_instruction_3' + transform_id = 'test_transform_1' + + test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) + _, data = self._prep_encoded_data(test_coder, test_elems, instruction_id, + transform_id) + test_task_helper.process_encoded(mocked_op, data) + + dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) + dummy_data_channel_factory = GrpcClientDataChannelFactory() + dummy_state_handler = _MockCachingStateHandler(None, None) + + with self.assertRaises(RuntimeError): + test_task_helper.process_bundle_with_task_workers( + dummy_state_handler, + dummy_data_channel_factory, + dummy_process_bundle_descriptor + ) + + # num(elems) * MAX_TASK_WORKER_RETRY + self.assertEquals(mock_execute.call_count, 4) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() From bdb3d13556b330ffb74bec41112d22521664e7e9 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 19 Aug 2020 15:38:55 -0700 Subject: [PATCH 05/15] Rebase modification for bundle processor --- .../runners/worker/bundle_processor.py | 97 +++++++++---------- .../apache_beam/runners/worker/task_worker.py | 24 ++--- 2 files changed, 54 insertions(+), 67 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 2671d8b3f4f7..aa3ae016004c 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -979,25 +979,7 @@ def process_bundle(self, instruction_id, use_task_worker=True): timer_info.output_stream = output_stream self.ops[transform_id].add_timer_info(timer_family_id, timer_info) - task_helper = BundleProcessorTaskHelper(instruction_id, - use_task_worker=use_task_worker) # Process data and timer inputs - for data_channel, expected_inputs in data_channels.items(): - for element in data_channel.input_elements(instruction_id, - expected_inputs): - if isinstance(element, beam_fn_api_pb2.Elements.Timers): - timer_coder_impl = ( - self.timers_info[( - element.transform_id, - element.timer_family_id)].timer_coder_impl) - for timer_data in timer_coder_impl.decode_all(element.timers): - self.ops[element.transform_id].process_timer( - element.timer_family_id, timer_data) - elif isinstance(element, beam_fn_api_pb2.Elements.Data): - input_op = input_op_by_transform_id[element.transform_id] - # decode inputs to inspect if it is wrapped - task_helper.process_encoded(input_op, element) - delayed_applications, requires_finalization = \ self.maybe_process_remotely(data_channels, instruction_id, input_op_by_transform_id, @@ -1050,41 +1032,50 @@ def maybe_process_remotely(self, for data_channel, expected_transforms in data_channels.items(): for data in data_channel.input_elements( instruction_id, expected_transforms): - input_op = input_op_by_transform_id[data.transform_id] - - # process normally if not using task worker - if use_task_worker is False: - input_op.process_encoded(data.data) - continue - - # decode inputs to inspect if it is wrapped - input_stream = coder_impl.create_InputStream(data.data) - # TODO: Come up with a better solution here? 
- # here we maintain two separate input stream, because `.pos` is not - # accessible on the cython version of InputStream object, so we can't re - # -use the same input stream object and edit the pos to move the read - # handle back - raw_input_stream = coder_impl.create_InputStream(data.data) - - while input_stream.size() > 0: - starting_size = input_stream.size() - decoded_value = input_op.windowed_coder_impl.decode_from_stream( - input_stream, True) - cur_size = input_stream.size() - raw_bytes = raw_input_stream.read(starting_size - cur_size) - # make sure these two stream stays in sync in size - assert raw_input_stream.size() == input_stream.size() - if decoded_value.value and get_taskable_value(decoded_value.value): - # save this to process later in ``process_bundle_with_task_workers`` - wrapped_values[data.transform_id].append( - (get_taskable_value(decoded_value.value), raw_bytes)) - else: - # fallback to process it as normal, trigger receivers to process - with input_op.splitting_lock: - if input_op.index == input_op.stop - 1: - return - input_op.index += 1 - input_op.output(decoded_value) + if isinstance(data, beam_fn_api_pb2.Elements.Timers): + timer_coder_impl = ( + self.timers_info[( + data.transform_id, + data.timer_family_id)].timer_coder_impl) + for timer_data in timer_coder_impl.decode_all(data.timers): + self.ops[data.transform_id].process_timer( + data.timer_family_id, timer_data) + elif isinstance(data, beam_fn_api_pb2.Elements.Data): + input_op = input_op_by_transform_id[data.transform_id] + + # process normally if not using task worker + if use_task_worker is False: + input_op.process_encoded(data.data) + continue + + # decode inputs to inspect if it is wrapped + input_stream = coder_impl.create_InputStream(data.data) + # TODO: Come up with a better solution here? 
+ # here we maintain two separate input stream, because `.pos` is not + # accessible on the cython version of InputStream object, so we can't + # re-use the same input stream object and edit the pos to move the + # read handle back + raw_input_stream = coder_impl.create_InputStream(data.data) + + while input_stream.size() > 0: + starting_size = input_stream.size() + decoded_value = input_op.windowed_coder_impl.decode_from_stream( + input_stream, True) + cur_size = input_stream.size() + raw_bytes = raw_input_stream.read(starting_size - cur_size) + # make sure these two stream stays in sync in size + assert raw_input_stream.size() == input_stream.size() + if decoded_value.value and get_taskable_value(decoded_value.value): + # save this to process later in ``process_bundle_with_task_workers`` + wrapped_values[data.transform_id].append( + (get_taskable_value(decoded_value.value), raw_bytes)) + else: + # fallback to process it as normal, trigger receivers to process + with input_op.splitting_lock: + if input_op.index == input_op.stop - 1: + return + input_op.index += 1 + input_op.output(decoded_value) if wrapped_values: task_helper = BundleProcessorTaskHelper(instruction_id, wrapped_values) diff --git a/sdks/python/apache_beam/runners/worker/task_worker.py b/sdks/python/apache_beam/runners/worker/task_worker.py index 5136e6b851b2..2453e8179e88 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker.py +++ b/sdks/python/apache_beam/runners/worker/task_worker.py @@ -48,11 +48,12 @@ from apache_beam.portability.api import beam_task_worker_pb2 from apache_beam.portability.api import beam_task_worker_pb2_grpc from apache_beam.portability.api import endpoints_pb2 -from apache_beam.runners.portability.fn_api_runner import ControlConnection -from apache_beam.runners.portability.fn_api_runner import ControlFuture +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture from apache_beam.runners.portability.fn_api_runner import FnApiRunner -from apache_beam.runners.portability.fn_api_runner import GrpcWorkerHandler -from apache_beam.runners.portability.fn_api_runner import WorkerHandler +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcWorkerHandler +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcStateServicer +from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler from apache_beam.runners.worker.bundle_processor import BundleProcessor from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.data_plane import _GrpcDataChannel @@ -193,7 +194,7 @@ def load_plugins(): # type: () -> None import entrypoints for name, entry_point in entrypoints.get_group_named( - 'apache_beam_task_workers_plugins').iteritems(): + ENTRY_POINT_NAME).iteritems(): logging.info('Loading entry point: {}'.format(name)) entry_point.load() @@ -372,7 +373,7 @@ def __init__(self, beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( TaskStateServicer(self.state_handler, instruction_id, - state_handler._context.cache_token), + state_handler._context.bundle_cache_token), self.state_server) logging.info('starting control server on port %s', self.control_port) @@ -655,7 +656,7 @@ def Send(self, request, context=None): # ===== # State # ===== -class TaskStateServicer(FnApiRunner.GrpcStateServicer): +class TaskStateServicer(GrpcStateServicer): def __init__(self, state, 
instruction_id, cache_token): # type: (CachingStateHandler, str, Optional[str]) -> None @@ -983,11 +984,9 @@ def get_default_task_env(process_bundle_descriptor): Args: process_bundle_descriptor: the ProcessBundleDescriptor proto """ - from apache_beam.portability.api import beam_runner_api_pb2 - from apache_beam.runners.portability.fn_api_runner_transforms import \ + from apache_beam.runners.portability.fn_api_runner.translations import \ PAR_DO_URNS from apache_beam.transforms.environments import Environment - from apache_beam.utils import proto_utils # find a ParDo xform in this stage pardo = None @@ -1001,11 +1000,8 @@ def get_default_task_env(process_bundle_descriptor): # FIXME: Use the pipeline default env here? return None - pardo_payload = proto_utils.parse_Bytes( - pardo.spec.payload, - beam_runner_api_pb2.ParDoPayload) env_proto = process_bundle_descriptor.environments.get( - pardo_payload.do_fn.environment_id) + pardo.environment_id) return Environment.from_runner_api(env_proto, None) From 13713b5de236ae0aa99691b35d7dbcf3806058d2 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Tue, 1 Sep 2020 18:21:48 -0700 Subject: [PATCH 06/15] Add provision service to task worker, micellaneous updates Update provision descritpor, data channel input_elements --- .../apache_beam/runners/worker/task_worker.py | 122 +++++++++++++++--- 1 file changed, 103 insertions(+), 19 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/task_worker.py b/sdks/python/apache_beam/runners/worker/task_worker.py index 2453e8179e88..e84a6c9c1ddc 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker.py +++ b/sdks/python/apache_beam/runners/worker/task_worker.py @@ -19,6 +19,7 @@ from __future__ import division import collections +import copy import logging import queue import os @@ -45,9 +46,12 @@ from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import beam_provision_api_pb2 +from apache_beam.portability.api import beam_provision_api_pb2_grpc from apache_beam.portability.api import beam_task_worker_pb2 from apache_beam.portability.api import beam_task_worker_pb2_grpc from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.portability.fn_api_runner.worker_handlers import BasicProvisionService from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture from apache_beam.runners.portability.fn_api_runner import FnApiRunner @@ -57,13 +61,12 @@ from apache_beam.runners.worker.bundle_processor import BundleProcessor from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.data_plane import _GrpcDataChannel -from apache_beam.runners.worker.data_plane import ClosableOutputStream +from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream from apache_beam.runners.worker.data_plane import DataChannelFactory from apache_beam.runners.worker.statecache import StateCache from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor if TYPE_CHECKING: - from apache_beam.portability.api import beam_provision_api_pb2 from apache_beam.runners.portability.fn_api_runner import ExtendedProvisionInfo from apache_beam.runners.worker.data_plane import DataChannelFactory from apache_beam.runners.worker.sdk_worker import CachingStateHandler @@ -161,12 +164,20 @@ def __init__(self, 
self.worker_id = worker_id self.control_address = self.port_from_worker(self._grpc_server.control_port) + self.provision_address = self.control_address self.logging_address = self.port_from_worker( self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) self.artifact_address = self.port_from_worker( self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) - self.provision_address = self.port_from_worker( - self.get_port_from_env_var('PROVISION_API_SERVICE_DESCRIPTOR')) + + # modify provision info + modified_provision = copy.copy(provision_info) + modified_provision.control_endpoint.url = self.control_address + modified_provision.logging_endpoint.url = self.logging_address + modified_provision.artifact_endpoint.url = self.artifact_address + with TaskProvisionServicer._lock: + self._grpc_server.provision_handler.provision_by_worker_id[ + self.worker_id] = modified_provision self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( self.worker_id) @@ -176,12 +187,14 @@ def __init__(self, self.credentials = credentials self.alive = True - def host_from_worker(self): + @staticmethod + def host_from_worker(): # type: () -> str import socket return socket.getfqdn() - def get_port_from_env_var(self, env_var): + @staticmethod + def get_port_from_env_var(env_var): # type: (str) -> str """Extract the service port for a given environment variable.""" from google.protobuf import text_format @@ -318,7 +331,7 @@ class TaskGrpcServer(object): ``TaskWorker`` and ``TaskWorkerHandler``. Contains three servers: - - a control server hosting ``TaskControlService`` + - a control server hosting ``TaskControlService`` amd `TaskProvisionService` - a data server hosting ``TaskFnDataService`` - a state server hosting ``TaskStateService`` @@ -332,7 +345,8 @@ def __init__(self, max_workers, # type: int data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] data_channel_factory, # type: DataChannelFactory - instruction_id # type: str + instruction_id, # type: str + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] ): # type: (...) 
-> None self.state_handler = state_handler @@ -362,8 +376,9 @@ def __init__(self, self.control_handler = TaskControlServicer() beam_task_worker_pb2_grpc.add_TaskControlServicer_to_server( self.control_handler, self.control_server) - # TODO: When we add provision / staging service, it needs to be added to the - # control server too + self.provision_handler = TaskProvisionServicer(provision_info=provision_info) + beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server( + self.provision_handler, self.control_server) self.data_plane_handler = TaskFnDataServicer(data_store, data_channel_factory, @@ -532,6 +547,11 @@ def __init__(self, transmitter_url, credentials=None, worker_id=None): def create_data_channel(self, remote_grpc_port): # type: (beam_fn_api_pb2.RemoteGrpcPort) -> ProxyGrpcClientDataChannel + url = remote_grpc_port.api_service_descriptor.url + return self.create_data_channel_from_url(url) + + def create_data_channel_from_url(self, url): + # type: (str) -> ProxyGrpcClientDataChannel channel_options = [("grpc.max_receive_message_length", -1), ("grpc.max_send_message_length", -1)] if self._credentials is None: @@ -541,7 +561,7 @@ def create_data_channel(self, remote_grpc_port): grpc_channel = GRPCChannelFactory.secure_channel( self.transmitter_url, self._credentials, options=channel_options) return ProxyGrpcClientDataChannel( - remote_grpc_port.api_service_descriptor.url, + url, beam_task_worker_pb2_grpc.TaskFnDataStub(grpc_channel)) def close(self): @@ -564,6 +584,8 @@ def input_elements(self, abort_callback=None # type: Optional[Callable[[], bool]] ): # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data] + if not expected_transforms: + return req = beam_task_worker_pb2.ReceiveRequest( instruction_id=instruction_id, client_data_endpoint=self.client_url) @@ -587,7 +609,7 @@ def input_elements(self, return def output_stream(self, instruction_id, transform_id): - # type: (str, str) -> ClosableOutputStream + # type: (str, str) -> SizeBasedBufferingClosableOutputStream def _add_to_send_queue(data): if data: @@ -606,7 +628,7 @@ def close_callback(data): # when the whole bundle finishes, the bundle processor original output # stream will send that to runner - return ClosableOutputStream( + return SizeBasedBufferingClosableOutputStream( close_callback, flush_callback=_add_to_send_queue) @@ -700,6 +722,38 @@ def State(self, request_stream, context=None): raise NotImplementedError('Unknown state request: %s' % request_type) +# ===== +# Provision +# ===== +class TaskProvisionServicer(BasicProvisionService): + """ + Provide provision info for remote task workers, provision is static because + for each bundle the provision info is static. + """ + + _lock = threading.Lock() + + def __init__(self, provision_info=None): + # type: (Optional[beam_provision_api_pb2.ProvisionInfo]) -> None + self._provision_info = provision_info + self.provision_by_worker_id = dict() + + def GetProvisionInfo(self, request, context=None): + # type: (...) 
-> beam_provision_api_pb2.GetProvisionInfoResponse + # if request comes from task worker that can be found, return the modified + # provision info + if context: + worker_id = dict(context.invocation_metadata())['worker_id'] + provision_info = self.provision_by_worker_id.get(worker_id, + self._provision_info) + else: + # fallback to the generic sdk worker version of provision info if not + # found from a cached task worker + provision_info = self._provision_info + + return beam_provision_api_pb2.GetProvisionInfoResponse(info=provision_info) + + class BundleProcessorTaskWorker(object): """ The remote task worker that communicates with the SDK worker to do the @@ -959,7 +1013,8 @@ def _start_task_grpc_server(self, max_workers, # type: int data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] state_handler, # type: CachingStateHandler - data_channel_factory # type: DataChannelFactory + data_channel_factory, # type: DataChannelFactory + provision_info, # type: beam_provision_api_pb2.ProvisionInfo ): # type:(...) -> TaskGrpcServer """Start up TaskGrpcServer. @@ -972,7 +1027,8 @@ def _start_task_grpc_server(self, data_channel_factory: data channel factory of current BundleProcessor """ return TaskGrpcServer(state_handler, max_workers, data_store, - data_channel_factory, self.instruction_id) + data_channel_factory, self.instruction_id, + provision_info) @staticmethod def get_default_task_env(process_bundle_descriptor): @@ -1005,6 +1061,23 @@ def get_default_task_env(process_bundle_descriptor): return Environment.from_runner_api(env_proto, None) + def get_sdk_worker_provision_info(self, server_url): + # type:(str) -> beam_provision_api_pb2.ProvisionInfo + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + channel = GRPCChannelFactory.insecure_channel(server_url, + options=channel_options) + + worker_id = os.environ['WORKER_ID'] + # add sdk worker id to grpc channel + channel = grpc.intercept_channel(channel, WorkerIdInterceptor(worker_id)) + + provision_stub = beam_provision_api_pb2_grpc.ProvisionServiceStub(channel) + response = provision_stub.GetProvisionInfo( + beam_provision_api_pb2.GetProvisionInfoRequest()) + channel.close() + return response.info + def process_bundle_with_task_workers(self, state_handler, # type: CachingStateHandler data_channel_factory, # type: DataChannelFactory @@ -1024,6 +1097,8 @@ def process_bundle_with_task_workers(self, process_bundle_descriptor: a description of the stage that this ``BundleProcessor``is to execute. 
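+
+ Roughly: fetch the SDK worker's provision info, start a ``TaskGrpcServer``
+ shared by all task workers, create one ``TaskWorkerHandler`` per split
+ element, and execute them in a thread pool, retrying a failed worker up to
+ ``MAX_TASK_WORKER_RETRY`` attempts.
+
+ Returns:
+   A tuple of (delayed applications merged from all task workers, whether
+   any element requires bundle finalization).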
""" + from google.protobuf import text_format + default_env = self.get_default_task_env(process_bundle_descriptor) # start up grpc server splitted_elements, data_store = self.split_taskable_values() @@ -1033,14 +1108,23 @@ def process_bundle_with_task_workers(self, 'Number of element exceeded MAX_TASK_WORKERS ({})'.format( MAX_TASK_WORKERS)) num_task_workers = MAX_TASK_WORKERS + + # get sdk worker provision info first + provision_port = TaskWorkerHandler.get_port_from_env_var( + 'PROVISION_API_SERVICE_DESCRIPTOR') + provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( + TaskWorkerHandler.host_from_worker(), provision_port)) + server = self._start_task_grpc_server(num_task_workers, data_store, - state_handler, data_channel_factory) + state_handler, data_channel_factory, + provision_info) + # modify provision api service descriptor to use the new address that we are + # gonna be using (the control address) + os.environ['PROVISION_API_SERVICE_DESCRIPTOR'] = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url=server.control_address)) # create TaskWorkerHandlers task_worker_handlers = [] - # FIXME: leaving out provision info for now, it should come from - # Environment - provision_info = None for worker_id, elem in splitted_elements.iteritems(): taskable_value = get_taskable_value(elem) # set the env to default env if there is From 141d9acef4e361d6df5e64124eab4ed742175069 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 14 Oct 2020 12:53:53 -0700 Subject: [PATCH 07/15] [Task Worker] update endpoints gathering from sdk env --- .../apache_beam/runners/worker/task_worker.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/task_worker.py b/sdks/python/apache_beam/runners/worker/task_worker.py index e84a6c9c1ddc..cffe5a991cd4 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker.py +++ b/sdks/python/apache_beam/runners/worker/task_worker.py @@ -167,8 +167,17 @@ def __init__(self, self.provision_address = self.control_address self.logging_address = self.port_from_worker( self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) - self.artifact_address = self.port_from_worker( - self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) + # if we are running from subprocess environment, the aritfact staging + # endpoint is the same as control end point, and the env var won't be + # recorded so use control end point and record that to artifact + try: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) + except KeyError: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('CONTROL_API_SERVICE_DESCRIPTOR')) + os.environ['ARTIFACT_API_SERVICE_DESCRIPTOR'] = os.environ[ + 'CONTROL_API_SERVICE_DESCRIPTOR'] # modify provision info modified_provision = copy.copy(provision_info) @@ -1110,10 +1119,17 @@ def process_bundle_with_task_workers(self, num_task_workers = MAX_TASK_WORKERS # get sdk worker provision info first - provision_port = TaskWorkerHandler.get_port_from_env_var( - 'PROVISION_API_SERVICE_DESCRIPTOR') - provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( - TaskWorkerHandler.host_from_worker(), provision_port)) + try: + provision_port = TaskWorkerHandler.get_port_from_env_var( + 'PROVISION_API_SERVICE_DESCRIPTOR') + except KeyError: + # if we are in subprocess environment then there won't be any provision + # service so use default provison info + provision_info = 
beam_provision_api_pb2.ProvisionInfo() + + else: + provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( + TaskWorkerHandler.host_from_worker(), provision_port)) server = self._start_task_grpc_server(num_task_workers, data_store, state_handler, data_channel_factory, From 2beb3e487ef90ff40cd7ea9052a06b32b994d3f1 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 14 Oct 2020 14:51:52 -0700 Subject: [PATCH 08/15] Change task worker into package, add example, add BeamTask xform --- .../__init__.py} | 10 +- .../runners/worker/task_worker/core.py | 169 ++++++++++++++++++ .../runners/worker/task_worker/example.py | 69 +++++++ .../{ => task_worker}/task_worker_main.py | 0 .../{ => task_worker}/task_worker_test.py | 0 sdks/python/build-requirements.txt | 1 + 6 files changed, 244 insertions(+), 5 deletions(-) rename sdks/python/apache_beam/runners/worker/{task_worker.py => task_worker/__init__.py} (99%) create mode 100644 sdks/python/apache_beam/runners/worker/task_worker/core.py create mode 100644 sdks/python/apache_beam/runners/worker/task_worker/example.py rename sdks/python/apache_beam/runners/worker/{ => task_worker}/task_worker_main.py (100%) rename sdks/python/apache_beam/runners/worker/{ => task_worker}/task_worker_test.py (100%) diff --git a/sdks/python/apache_beam/runners/worker/task_worker.py b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py similarity index 99% rename from sdks/python/apache_beam/runners/worker/task_worker.py rename to sdks/python/apache_beam/runners/worker/task_worker/__init__.py index cffe5a991cd4..ac21e4f0ca6d 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py @@ -67,7 +67,7 @@ from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor if TYPE_CHECKING: - from apache_beam.runners.portability.fn_api_runner import ExtendedProvisionInfo + from apache_beam.runners.portability.fn_api_runner.fn_runner import ExtendedProvisionInfo from apache_beam.runners.worker.data_plane import DataChannelFactory from apache_beam.runners.worker.sdk_worker import CachingStateHandler from apache_beam.transforms.environments import Environment @@ -216,7 +216,7 @@ def load_plugins(): # type: () -> None import entrypoints for name, entry_point in entrypoints.get_group_named( - ENTRY_POINT_NAME).iteritems(): + ENTRY_POINT_NAME).items(): logging.info('Loading entry point: {}'.format(name)) entry_point.load() @@ -1005,7 +1005,7 @@ def split_taskable_values(self): splitted = collections.defaultdict(list) data_store = collections.defaultdict(list) worker_count = 0 - for ptransform_id, values in self.wrapped_values.iteritems(): + for ptransform_id, values in self.wrapped_values.items(): for decoded, raw in values: worker_id = 'worker_{}'.format(worker_count) splitted[worker_id].append(decoded) @@ -1055,7 +1055,7 @@ def get_default_task_env(process_bundle_descriptor): # find a ParDo xform in this stage pardo = None - for _, xform in process_bundle_descriptor.transforms.iteritems(): + for _, xform in process_bundle_descriptor.transforms.items(): if xform.spec.urn in PAR_DO_URNS: pardo = xform break @@ -1141,7 +1141,7 @@ def process_bundle_with_task_workers(self, # create TaskWorkerHandlers task_worker_handlers = [] - for worker_id, elem in splitted_elements.iteritems(): + for worker_id, elem in splitted_elements.items(): taskable_value = get_taskable_value(elem) # set the env to default env if there is if taskable_value.env is None and default_env: diff --git 
a/sdks/python/apache_beam/runners/worker/task_worker/core.py b/sdks/python/apache_beam/runners/worker/task_worker/core.py new file mode 100644 index 000000000000..ee8797c2f006 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/core.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +TaskWorker core user facing PTransforms that allows packaging elements as +TaskableValue for beam task workers. +""" + +from typing import TYPE_CHECKING + +import apache_beam as beam +from apache_beam.runners.worker.task_worker import TaskableValue + +TASK_WORKER_ENV_TYPES = {'Docker', 'Process'} +TASK_WORKER_SDK_ENTRYPOINT = 'apache_beam.runners.worker.task_worker.task_worker_main' + +if TYPE_CHECKING: + from typing import Any, Optional, Callable, Iterable, Iterator + + +class WrapFn(beam.DoFn): + """ + Wraps the given element into a TaskableValue if there's non-empty task + payload. User can pass in wrapper callable to modify payload per element. + """ + + def process(self, element, urn='local', wrapper=None, env=None, payload=None): + # type: (Any, str, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment], Optional[Any]) -> Iterator[Any] + """ + Parameters + ---------- + element : Any + urn : str + wrapper : Optional[Callable[[Any, Any], Any]] + env : Optional[beam.transforms.environments.Environment] + payload : Optional[Any] + + Yields + ------ + Any + """ + from apache_beam.runners.worker.task_worker import TaskableValue + + # override payload if given a wrapper function, which will vary per + # element + if wrapper: + payload = wrapper(element, payload) + + if payload: + result = TaskableValue(element, urn, env=env, payload=payload) + else: + result = element + yield result + + +class UnWrapFn(beam.DoFn): + """ + Unwraps the TaskableValue into its original value, so that when constructing + transforms user doesn't need to worry about the element type if it is + taskable or not. + """ + + def process(self, element): + + if isinstance(element, TaskableValue): + yield element.value + else: + yield element + + +class BeamTask(beam.PTransform): + """ + Utility transform that wraps a group of transforms, and makes it a Beam + "Task" that can be delegated to a task worker to run remotely. + + The main structure is like this: + + ( pipe + | Wrap + | Reshuffle + | UnWrap + | User Transform1 + | ... + | User TransformN + | Reshuffle + ) + + The use of reshuffle is to make sure stage fusing doesn't try to fuse the + section we want to run with the inputs of this xform; reason being we need + the start of a stage to get data inputs that are *TaskableValue*, so that + the bundle processor will recognize that and will engage Task Workers. 
+ + We end with a Reshuffle for similar reason, so that the next section of the + pipeline doesn't gets fused with the transforms provided, which would end up + being executed remotely in a remote task worker. + + By default, we use the local task worker, but subclass could specify the + type of task worker to use by specifying the ``urn``, and override the + ``getPayload`` method to return meaningful payloads to that type of task + worker. + """ + + # the urn for the registered task worker handler, default to use local task + # worker + urn = 'local' # type: str + + # the sdk harness entry point + SDK_HARNESS_ENTRY_POINT = TASK_WORKER_SDK_ENTRYPOINT + + def __init__(self, fusedXform, wrapper=None, env=None): + # type: (beam.PTransform, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment]) -> None + self._wrapper = wrapper + self._env = env + self._fusedXform = fusedXform + + def getPayload(self): + # type: () -> Optional[Any] + """ + Subclass should implement this to generate payload for TaskableValue. + Default to None. + + Returns + ------- + Optional[Any] + """ + return None + + @staticmethod + def _hasTaggedOutputs(xform): + # type: (beam.PTransform) -> bool + """Checks to see if we have tagged output for the given PTransform.""" + if isinstance(xform, beam.core._MultiParDo): + return True + elif isinstance(xform, beam.ptransform._ChainedPTransform) \ + and isinstance(xform._parts[-1], beam.core._MultiParDo): + return True + return False + + def expand(self, pcoll): + # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection + payload = self.getPayload() + result = ( + pcoll + | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, + env=self._env, payload=payload) + | 'StartStage' >> beam.Reshuffle() + | 'UnWrap' >> beam.ParDo(UnWrapFn()) + | self._fusedXform + ) + if self._hasTaggedOutputs(self._fusedXform): + # for xforms that ended up with tagged outputs, we don't want to + # add reshuffle, because it will be a stage split point already, + # also adding reshuffle would error since we now have a tuple of + # pcollections. + return result + return result | 'EndStage' >> beam.Reshuffle() diff --git a/sdks/python/apache_beam/runners/worker/task_worker/example.py b/sdks/python/apache_beam/runners/worker/task_worker/example.py new file mode 100644 index 000000000000..9e855ecbddbf --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/example.py @@ -0,0 +1,69 @@ +""" +Basic graph to test using TaskWorker. 
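+
+The second branch wraps a ParDo in a ``BeamTask`` (default ``local`` urn), so
+elements wrapped as ``TaskableValue`` are executed by local task workers; the
+pipeline is meant to be run with the direct runner in multi-processing mode,
+as configured in ``run()`` below.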
+""" +# pytype: skip-file + +from __future__ import absolute_import + +import argparse +import logging +import re + +from past.builtins import unicode + +import apache_beam as beam +from apache_beam.io import ReadFromText +from apache_beam.io import WriteToText +from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions + +import apache_beam as beam + +from apache_beam.runners.worker.task_worker.core import BeamTask + + +class TestFn(beam.DoFn): + + def process(self, element, side): + from apache_beam.runners.worker.task_worker import TaskableValue + + for s in side: + if isinstance(element, TaskableValue): + value = element.value + else: + value = element + print(value + s) + yield value + s + + +def run(argv=None, save_main_session=True): + """Main entry point; defines and runs the test pipeline.""" + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) + + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + + # The pipeline will be run on exiting the with block. + with beam.Pipeline(options=pipeline_options) as pipe: + + A = ( + pipe + | 'A' >> beam.Create(range(3)) + ) + + B = ( + pipe + | beam.Create(range(2)) + | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), + wrapper=lambda x, _: x) + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/runners/worker/task_worker_main.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py similarity index 100% rename from sdks/python/apache_beam/runners/worker/task_worker_main.py rename to sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py diff --git a/sdks/python/apache_beam/runners/worker/task_worker_test.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py similarity index 100% rename from sdks/python/apache_beam/runners/worker/task_worker_test.py rename to sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py diff --git a/sdks/python/build-requirements.txt b/sdks/python/build-requirements.txt index 1ecd6ec10a50..90d766abcdbe 100644 --- a/sdks/python/build-requirements.txt +++ b/sdks/python/build-requirements.txt @@ -17,3 +17,4 @@ grpcio-tools==1.30.0 future==0.18.2 mypy-protobuf==1.18 +entrypoints==0.3 From 1bd26c9f7e52b1cefa4b370563a10606fa7d6537 Mon Sep 17 00:00:00 2001 From: Sam Bourne Date: Mon, 19 Oct 2020 17:05:05 -0700 Subject: [PATCH 09/15] Reorganize taskworker a bit --- .../apache_beam/examples/taskworker/local.py | 79 ++ .../runners/worker/bundle_processor.py | 4 +- .../runners/worker/task_worker/__init__.py | 1232 ---------------- .../runners/worker/task_worker/core.py | 253 ++-- .../runners/worker/task_worker/example.py | 69 +- .../runners/worker/task_worker/handlers.py | 1255 +++++++++++++++++ .../worker/task_worker/task_worker_main.py | 4 +- sdks/python/build-requirements.txt | 1 - sdks/python/setup.py | 4 + 9 files changed, 1499 insertions(+), 1402 deletions(-) create mode 100644 sdks/python/apache_beam/examples/taskworker/local.py create mode 100644 
sdks/python/apache_beam/runners/worker/task_worker/handlers.py diff --git a/sdks/python/apache_beam/examples/taskworker/local.py b/sdks/python/apache_beam/examples/taskworker/local.py new file mode 100644 index 000000000000..4c33c852d7ca --- /dev/null +++ b/sdks/python/apache_beam/examples/taskworker/local.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Basic pipeline to test using TaskWorkers. +""" + +from __future__ import absolute_import + +import argparse +import logging + +import apache_beam as beam +from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions + +from apache_beam.runners.worker.task_worker.core import BeamTask + + +class TestFn(beam.DoFn): + + def process(self, element, side): + from apache_beam.runners.worker.task_worker.handlers import TaskableValue + + for s in side: + if isinstance(element, TaskableValue): + value = element.value + else: + value = element + print(value + s) + yield value + s + + +def run(argv=None, save_main_session=True): + """ + Run a pipeline that executes each element using the local task worker. + """ + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) + + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + + with beam.Pipeline(options=pipeline_options) as pipe: + + A = ( + pipe + | 'A' >> beam.Create(range(3)) + ) + + B = ( + pipe + | beam.Create(range(2)) + | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), + wrapper=lambda x, _: x) + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index aa3ae016004c..568307c0bc41 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -1024,8 +1024,8 @@ def maybe_process_remotely(self, TaskableValue detected from input of this bundle) and task worker is allowed to be used. 
""" - from apache_beam.runners.worker.task_worker import BundleProcessorTaskHelper - from apache_beam.runners.worker.task_worker import get_taskable_value + from apache_beam.runners.worker.task_worker.handlers import BundleProcessorTaskHelper + from apache_beam.runners.worker.task_worker.handlers import get_taskable_value wrapped_values = collections.defaultdict(list) # type: DefaultDict[str, List[Tuple[Any, bytes]]] diff --git a/sdks/python/apache_beam/runners/worker/task_worker/__init__.py b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py index ac21e4f0ca6d..6569e3fe5de4 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/__init__.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/__init__.py @@ -16,1235 +16,3 @@ # from __future__ import absolute_import -from __future__ import division - -import collections -import copy -import logging -import queue -import os -import sys -import threading -from builtins import object -from concurrent import futures -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import DefaultDict -from typing import Dict -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Set -from typing import Tuple -from typing import Type -from typing import Union - -import grpc -from future.utils import raise_ - -from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.portability.api import beam_fn_api_pb2_grpc -from apache_beam.portability.api import beam_provision_api_pb2 -from apache_beam.portability.api import beam_provision_api_pb2_grpc -from apache_beam.portability.api import beam_task_worker_pb2 -from apache_beam.portability.api import beam_task_worker_pb2_grpc -from apache_beam.portability.api import endpoints_pb2 -from apache_beam.runners.portability.fn_api_runner.worker_handlers import BasicProvisionService -from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection -from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture -from apache_beam.runners.portability.fn_api_runner import FnApiRunner -from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcWorkerHandler -from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcStateServicer -from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler -from apache_beam.runners.worker.bundle_processor import BundleProcessor -from apache_beam.runners.worker.channel_factory import GRPCChannelFactory -from apache_beam.runners.worker.data_plane import _GrpcDataChannel -from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream -from apache_beam.runners.worker.data_plane import DataChannelFactory -from apache_beam.runners.worker.statecache import StateCache -from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor - -if TYPE_CHECKING: - from apache_beam.runners.portability.fn_api_runner.fn_runner import ExtendedProvisionInfo - from apache_beam.runners.worker.data_plane import DataChannelFactory - from apache_beam.runners.worker.sdk_worker import CachingStateHandler - from apache_beam.transforms.environments import Environment - -ENTRY_POINT_NAME = 'apache_beam_task_workers_plugins' -MAX_TASK_WORKERS = 300 -MAX_TASK_WORKER_RETRY = 10 - -Taskable = Union[ - 'TaskableValue', - List['TaskableValue'], - Tuple['TaskableValue', ...], - Set['TaskableValue']] - - -class 
TaskableValue(object): - """ - Value that can be distributed to TaskWorkers as tasks. - - Has the original value, and TaskProperties that specifies how the task will be - generated.""" - - def __init__(self, - value, # type: Any - urn, # type: str - env=None, # type: Optional[Environment] - payload=None # type: Optional[Any] - ): - # type: (...) -> None - """ - Args: - value: The wrapped element - urn: id of the task worker handler - env : Environment for the task to be run in - payload : Payload containing settings for the task worker handler - """ - self.value = value - self.urn = urn - self.env = env - self.payload = payload - - -class TaskWorkerProcessBundleError(Exception): - """ - Error thrown when TaskWorker fails to process_bundle. - - Errors encountered when task worker is processing bundle can be retried, up - till max retries defined by ``MAX_TASK_WORKER_RETRY``. - """ - - -class TaskWorkerTerminatedError(Exception): - """ - Error thrown when TaskWorker terminated before it finished working. - - Custom TaskWorkerHandlers can choose to terminate a task but not - affect the whole bundle by setting ``TaskWorkerHandler.alive`` to False, - which will cause this error to be thrown. - """ - - -class TaskWorkerHandler(GrpcWorkerHandler): - """ - Abstract base class for TaskWorkerHandler for a task worker, - - A TaskWorkerHandler is created for each TaskableValue. - - Subclasses must override ``start_remote`` to modify how remote task worker is - started, and register task properties type when defining a subclass. - """ - - _known_urns = {} # type: Dict[str, Type[TaskWorkerHandler]] - - def __init__(self, - state, # type: CachingStateHandler - provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] - grpc_server, # type: TaskGrpcServer - environment, # type: Environment - task_payload, # type: Any - credentials=None, # type: Optional[str] - worker_id=None # type: Optional[str] - ): - # type: (...) 
-> None - self._grpc_server = grpc_server - - # we are manually doing init instead of calling GrpcWorkerHandler's init - # because we want to override worker_id and we don't want extra - # ControlConnection to be established - WorkerHandler.__init__(self, grpc_server.control_handler, - grpc_server.data_plane_handler, state, - provision_info) - # override worker_id if provided - if worker_id: - self.worker_id = worker_id - - self.control_address = self.port_from_worker(self._grpc_server.control_port) - self.provision_address = self.control_address - self.logging_address = self.port_from_worker( - self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) - # if we are running from subprocess environment, the aritfact staging - # endpoint is the same as control end point, and the env var won't be - # recorded so use control end point and record that to artifact - try: - self.artifact_address = self.port_from_worker( - self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) - except KeyError: - self.artifact_address = self.port_from_worker( - self.get_port_from_env_var('CONTROL_API_SERVICE_DESCRIPTOR')) - os.environ['ARTIFACT_API_SERVICE_DESCRIPTOR'] = os.environ[ - 'CONTROL_API_SERVICE_DESCRIPTOR'] - - # modify provision info - modified_provision = copy.copy(provision_info) - modified_provision.control_endpoint.url = self.control_address - modified_provision.logging_endpoint.url = self.logging_address - modified_provision.artifact_endpoint.url = self.artifact_address - with TaskProvisionServicer._lock: - self._grpc_server.provision_handler.provision_by_worker_id[ - self.worker_id] = modified_provision - - self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( - self.worker_id) - - self.environment = environment - self.task_payload = task_payload - self.credentials = credentials - self.alive = True - - @staticmethod - def host_from_worker(): - # type: () -> str - import socket - return socket.getfqdn() - - @staticmethod - def get_port_from_env_var(env_var): - # type: (str) -> str - """Extract the service port for a given environment variable.""" - from google.protobuf import text_format - endpoint = endpoints_pb2.ApiServiceDescriptor() - text_format.Merge(os.environ[env_var], endpoint) - return endpoint.url.split(':')[-1] - - @staticmethod - def load_plugins(): - # type: () -> None - import entrypoints - for name, entry_point in entrypoints.get_group_named( - ENTRY_POINT_NAME).items(): - logging.info('Loading entry point: {}'.format(name)) - entry_point.load() - - @classmethod - def register_urn(cls, urn, constructor=None): - def register(constructor): - cls._known_urns[urn] = constructor - return constructor - if constructor: - return register(constructor) - else: - return register - - @classmethod - def create(cls, - state, # type: CachingStateHandler - provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] - grpc_server, # type: TaskGrpcServer - taskable_value, # type: TaskableValue - credentials=None, # type: Optional[str] - worker_id=None # type: Optional[str] - ): - # type: (...) -> TaskWorkerHandler - constructor = cls._known_urns[taskable_value.urn] - return constructor(state, provision_info, grpc_server, - taskable_value.env, taskable_value.payload, - credentials=credentials, worker_id=worker_id) - - def start_worker(self): - # type: () -> None - self.start_remote() - - def start_remote(self): - # type: () -> None - """Start up a remote TaskWorker to process the current element. 
- - Subclass should implement this.""" - raise NotImplementedError - - def stop_worker(self): - # type: () -> None - # send shutdown request - future = self.control_conn.push( - beam_task_worker_pb2.TaskInstructionRequest( - instruction_id=self.worker_id, - shutdown=beam_task_worker_pb2.ShutdownRequest())) - response = future.get() - if response.error: - logging.warning('Error stopping worker: {}'.format(self.worker_id)) - - # close control conn after stop worker - self.control_conn.close() - - def _get_future(self, future, interval=0.5): - # type: (ControlFuture, float) -> beam_task_worker_pb2.TaskInstructionResponse - result = None - while self.alive: - result = future.get(timeout=interval) - if result: - break - - # if the handler is not alive, meaning task worker is stopped before - # finishing processing, raise ``TaskWorkerTerminatedError`` - if result is None: - raise TaskWorkerTerminatedError() - - return result - - def execute(self, - data_channel_factory, # type: ProxyGrpcClientDataChannelFactory - process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] - """Main entry point of the task execution cycle of a ``TaskWorkerHandler``. - - It will first issue a create request, and wait for the remote bundle - processor to be created; Then it will issue the process bundle request, and - wait for the result. If there's error occurred when processing bundle, - ``TaskWorkerProcessBundleError`` will be raised. - """ - # wait for remote bundle processor to be created - create_future = self.control_conn.push( - beam_task_worker_pb2.TaskInstructionRequest( - instruction_id=self.worker_id, - create=beam_task_worker_pb2.CreateRequest( - process_bundle_descriptor=process_bundle_descriptor, - state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( - url=self._grpc_server.state_address), - data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( - credentials=self.credentials, - worker_id=data_channel_factory.worker_id, - transmitter_url=data_channel_factory.transmitter_url)))) - self._get_future(create_future) - - # process bundle - process_future = self.control_conn.push( - beam_task_worker_pb2.TaskInstructionRequest( - instruction_id=self.worker_id, - process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()) - ) - response = self._get_future(process_future) - if response.error: - # raise here so this task can be retried - raise TaskWorkerProcessBundleError() - else: - delayed_applications = response.process_bundle.delayed_applications - require_finalization = response.process_bundle.require_finalization - - self.stop_worker() - return delayed_applications, require_finalization - - def reset(self): - # type: () -> None - """This is used to retry a failed task.""" - self.control_conn.reset() - - -class TaskGrpcServer(object): - """ - A collection of grpc servicers that handle communication between a - ``TaskWorker`` and ``TaskWorkerHandler``. - - Contains three servers: - - a control server hosting ``TaskControlService`` amd `TaskProvisionService` - - a data server hosting ``TaskFnDataService`` - - a state server hosting ``TaskStateService`` - - This is shared by all TaskWorkerHandlers generated by one bundle. 
- """ - - _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 - - def __init__(self, - state_handler, # type: CachingStateHandler - max_workers, # type: int - data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] - data_channel_factory, # type: DataChannelFactory - instruction_id, # type: str - provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] - ): - # type: (...) -> None - self.state_handler = state_handler - self.max_workers = max_workers - self.control_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=self.max_workers)) - self.control_port = self.control_server.add_insecure_port('[::]:0') - self.control_address = '%s:%s' % (self.get_host_name(), self.control_port) - - # Options to have no limits (-1) on the size of the messages - # received or sent over the data plane. The actual buffer size - # is controlled in a layer above. - no_max_message_sizes = [("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1)] - self.data_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=self.max_workers), - options=no_max_message_sizes) - self.data_port = self.data_server.add_insecure_port('[::]:0') - self.data_address = '%s:%s' % (self.get_host_name(), self.data_port) - - self.state_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=self.max_workers), - options=no_max_message_sizes) - self.state_port = self.state_server.add_insecure_port('[::]:0') - self.state_address = '%s:%s' % (self.get_host_name(), self.state_port) - - self.control_handler = TaskControlServicer() - beam_task_worker_pb2_grpc.add_TaskControlServicer_to_server( - self.control_handler, self.control_server) - self.provision_handler = TaskProvisionServicer(provision_info=provision_info) - beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server( - self.provision_handler, self.control_server) - - self.data_plane_handler = TaskFnDataServicer(data_store, - data_channel_factory, - instruction_id) - beam_task_worker_pb2_grpc.add_TaskFnDataServicer_to_server( - self.data_plane_handler, self.data_server) - - beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( - TaskStateServicer(self.state_handler, instruction_id, - state_handler._context.bundle_cache_token), - self.state_server) - - logging.info('starting control server on port %s', self.control_port) - logging.info('starting data server on port %s', self.data_port) - logging.info('starting state server on port %s', self.state_port) - self.state_server.start() - self.data_server.start() - self.control_server.start() - - @staticmethod - def get_host_name(): - # type: () -> str - import socket - return socket.getfqdn() - - def close(self): - # type: () -> None - self.control_handler.done() - to_wait = [ - self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), - self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), - self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), - ] - for w in to_wait: - w.wait() - - -# ============= -# Control Plane -# ============= -class TaskWorkerConnection(ControlConnection): - """The control connection between a TaskWorker and a TaskWorkerHandler. - - TaskWorkerHandler push InstructionRequests to _push_queue, and receives - InstructionResponses from TaskControlServicer. 
- """ - - _lock = threading.Lock() - - def __init__(self): - self._push_queue = queue.Queue() - self._input = None - self._futures_by_id = {} # type: Dict[Any, ControlFuture] - self._read_thread = threading.Thread( - name='bundle_processor_control_read', target=self._read) - self._state = TaskControlServicer.UNSTARTED_STATE - # marks current TaskConnection as in a state of retrying after failure - self._retrying = False - - def _read(self): - # type: () -> None - for data in self._input: - self._futures_by_id.pop(data.WhichOneof('response')).set(data) - - def push(self, - req # type: Union[TaskControlServicer._DONE_MARKER, beam_task_worker_pb2.TaskInstructionRequest] - ): - # type: (...) -> Optional[ControlFuture] - if req == TaskControlServicer._DONE_MARKER: - self._push_queue.put(req) - return None - if not req.instruction_id: - raise RuntimeError( - 'TaskInstructionRequest has to have instruction id!') - future = ControlFuture(req.instruction_id) - self._futures_by_id[req.WhichOneof('request')] = future - self._push_queue.put(req) - return future - - def set_inputs(self, input): - with TaskWorkerConnection._lock: - if self._input and not self._retrying: - raise RuntimeError('input is already set.') - self._input = input - self._read_thread.start() - self._state = TaskControlServicer.STARTED_STATE - self._retrying = False - - def close(self): - # type: () -> None - with TaskWorkerConnection._lock: - if self._state == TaskControlServicer.STARTED_STATE: - self.push(TaskControlServicer._DONE_MARKER) - self._read_thread.join() - self._state = TaskControlServicer.DONE_STATE - - def reset(self): - # type: () -> None - self.close() - self.__init__() - self._retrying = True - - -class TaskControlServicer(beam_task_worker_pb2_grpc.TaskControlServicer): - - _lock = threading.Lock() - - UNSTARTED_STATE = 'unstarted' - STARTED_STATE = 'started' - DONE_STATE = 'done' - - _DONE_MARKER = object() - - def __init__(self): - # type: () -> None - self._state = self.UNSTARTED_STATE - self._connections_by_worker_id = collections.defaultdict( - TaskWorkerConnection) - - def get_conn_by_worker_id(self, worker_id): - # type: (str) -> TaskWorkerConnection - with self._lock: - result = self._connections_by_worker_id[worker_id] - return result - - def Control(self, request_iterator, context): - with self._lock: - if self._state == self.DONE_STATE: - return - else: - self._state = self.STARTED_STATE - worker_id = dict(context.invocation_metadata()).get('worker_id') - if not worker_id: - raise RuntimeError('Connection does not have worker id.') - conn = self.get_conn_by_worker_id(worker_id) - conn.set_inputs(request_iterator) - - while True: - to_push = conn.get_req() - if to_push is self._DONE_MARKER: - return - yield to_push - - def done(self): - # type: () -> None - self._state = self.DONE_STATE - - -# ========== -# Data Plane -# ========== -class ProxyGrpcClientDataChannelFactory(DataChannelFactory): - """A factory for ``ProxyGrpcClientDataChannel``. 
- - No caching behavior here because we are starting each data channel on - different location.""" - - def __init__(self, transmitter_url, credentials=None, worker_id=None): - # type: (str, Optional[str], Optional[str]) -> None - # These two are not private attributes because it was used in - # ``TaskWorkerHandler.execute`` when issuing TaskInstructionRequest - self.transmitter_url = transmitter_url - self.worker_id = worker_id - - self._credentials = credentials - - def create_data_channel(self, remote_grpc_port): - # type: (beam_fn_api_pb2.RemoteGrpcPort) -> ProxyGrpcClientDataChannel - url = remote_grpc_port.api_service_descriptor.url - return self.create_data_channel_from_url(url) - - def create_data_channel_from_url(self, url): - # type: (str) -> ProxyGrpcClientDataChannel - channel_options = [("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1)] - if self._credentials is None: - grpc_channel = GRPCChannelFactory.insecure_channel( - self.transmitter_url, options=channel_options) - else: - grpc_channel = GRPCChannelFactory.secure_channel( - self.transmitter_url, self._credentials, options=channel_options) - return ProxyGrpcClientDataChannel( - url, - beam_task_worker_pb2_grpc.TaskFnDataStub(grpc_channel)) - - def close(self): - # type: () -> None - pass - - -class ProxyGrpcClientDataChannel(_GrpcDataChannel): - """DataChannel wrapping the client side of a TaskFnDataService connection.""" - - def __init__(self, client_url, proxy_stub): - # type: (str, beam_task_worker_pb2_grpc.TaskFnDataStub) -> None - super(ProxyGrpcClientDataChannel, self).__init__() - self.client_url = client_url - self.proxy_stub = proxy_stub - - def input_elements(self, - instruction_id, # type: str - expected_transforms, # type: List[str] - abort_callback=None # type: Optional[Callable[[], bool]] - ): - # type: (...) 
-> Iterator[beam_fn_api_pb2.Elements.Data] - if not expected_transforms: - return - req = beam_task_worker_pb2.ReceiveRequest( - instruction_id=instruction_id, - client_data_endpoint=self.client_url) - done_transforms = [] - abort_callback = abort_callback or (lambda: False) - - for data in self.proxy_stub.Receive(req): - if self._closed: - raise RuntimeError('Channel closed prematurely.') - if abort_callback(): - return - if self._exc_info: - t, v, tb = self._exc_info - raise_(t, v, tb) - if not data.data and data.transform_id in expected_transforms: - done_transforms.append(data.transform_id) - else: - assert data.transform_id not in done_transforms - yield data - if len(done_transforms) >= len(expected_transforms): - return - - def output_stream(self, instruction_id, transform_id): - # type: (str, str) -> SizeBasedBufferingClosableOutputStream - - def _add_to_send_queue(data): - if data: - self.proxy_stub.Send(beam_task_worker_pb2.SendRequest( - instruction_id=instruction_id, - data=beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - data=data), - client_data_endpoint=self.client_url - )) - - def close_callback(data): - _add_to_send_queue(data) - # no need to send empty bytes to signal end of processing here, because - # when the whole bundle finishes, the bundle processor original output - # stream will send that to runner - - return SizeBasedBufferingClosableOutputStream( - close_callback, flush_callback=_add_to_send_queue) - - -class TaskFnDataServicer(beam_task_worker_pb2_grpc.TaskFnDataServicer): - """Implementation of BeamFnDataTransmitServicer for any number of clients.""" - - def __init__(self, - data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] - orig_data_channel_factory, # type: DataChannelFactory - instruction_id # type: str - ): - # type: (...) -> None - self.data_store = data_store - self.orig_data_channel_factory = orig_data_channel_factory - self.orig_instruction_id = instruction_id - self._orig_data_channel = None # type: Optional[ProxyGrpcClientDataChannel] - - def _get_orig_data_channel(self, url): - # type: (str) -> ProxyGrpcClientDataChannel - remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort( - api_service_descriptor=endpoints_pb2.ApiServiceDescriptor(url=url)) - # the data channel is cached by url - return self.orig_data_channel_factory.create_data_channel(remote_grpc_port) - - def Receive(self, request, context=None): - # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data]] - data = self.data_store[request.instruction_id] - for elem in data: - yield elem - - def Send(self, request, context=None): - # type: (...) 
-> beam_task_worker_pb2.SendResponse - if self._orig_data_channel is None: - self._orig_data_channel = self._get_orig_data_channel( - request.client_data_endpoint) - # We need to replace the instruction_id here with the original instruction - # id, not the current one (which is the task worker id) - request.data.instruction_id = self.orig_instruction_id - if request.data.data: - # only send when there's data, because it is signaling the runner side - # worker handler that element of this has ended if it is empty, and we - # want to send that when every task worker handler is finished - self._orig_data_channel._to_send.put(request.data) - return beam_task_worker_pb2.SendResponse() - - -# ===== -# State -# ===== -class TaskStateServicer(GrpcStateServicer): - - def __init__(self, state, instruction_id, cache_token): - # type: (CachingStateHandler, str, Optional[str]) -> None - self.instruction_id = instruction_id - self.cache_token = cache_token - super(TaskStateServicer, self).__init__(state) - - def State(self, request_stream, context=None): - # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse] - # CachingStateHandler and GrpcStateHandler context is thread local, so we - # need to set it here for each TaskWorker - self._state._context.process_instruction_id = self.instruction_id - self._state._context.cache_token = self.cache_token - self._state._underlying._context.process_instruction_id = self.instruction_id - - # FIXME: This is not currently properly supporting state caching (currently - # state caching behavior only happens within python SDK, so runners like - # the FlinkRunner won't create the state cache anyways for now) - for request in request_stream: - request_type = request.WhichOneof('request') - - if request_type == 'get': - data, continuation_token = self._state._underlying.get_raw( - request.state_key, request.get.continuation_token) - yield beam_fn_api_pb2.StateResponse( - id=request.id, - get=beam_fn_api_pb2.StateGetResponse( - data=data, continuation_token=continuation_token)) - elif request_type == 'append': - self._state._underlying.append_raw(request.state_key, - request.append.data) - yield beam_fn_api_pb2.StateResponse( - id=request.id, - append=beam_fn_api_pb2.StateAppendResponse()) - elif request_type == 'clear': - self._state._underlying.clear(request.state_key) - yield beam_fn_api_pb2.StateResponse( - id=request.id, - clear=beam_fn_api_pb2.StateClearResponse()) - else: - raise NotImplementedError('Unknown state request: %s' % request_type) - - -# ===== -# Provision -# ===== -class TaskProvisionServicer(BasicProvisionService): - """ - Provide provision info for remote task workers, provision is static because - for each bundle the provision info is static. - """ - - _lock = threading.Lock() - - def __init__(self, provision_info=None): - # type: (Optional[beam_provision_api_pb2.ProvisionInfo]) -> None - self._provision_info = provision_info - self.provision_by_worker_id = dict() - - def GetProvisionInfo(self, request, context=None): - # type: (...) 
-> beam_provision_api_pb2.GetProvisionInfoResponse - # if request comes from task worker that can be found, return the modified - # provision info - if context: - worker_id = dict(context.invocation_metadata())['worker_id'] - provision_info = self.provision_by_worker_id.get(worker_id, - self._provision_info) - else: - # fallback to the generic sdk worker version of provision info if not - # found from a cached task worker - provision_info = self._provision_info - - return beam_provision_api_pb2.GetProvisionInfoResponse(info=provision_info) - - -class BundleProcessorTaskWorker(object): - """ - The remote task worker that communicates with the SDK worker to do the - actual work of processing bundles. - - The BundleProcessor will detect inputs and see if there is TaskableValue, and - if there is and the BundleProcessor is set to "use task worker", then - BundleProcessorTaskHelper will create a TaskWorkerHandler that this class - communicates with. - - This class creates a BundleProcessor and receives TaskInstructionRequests and - sends back respective responses via the grpc channels connected to the control - endpoint; - """ - - REQUEST_PREFIX = '_request_' - _lock = threading.Lock() - - def __init__(self, worker_id, server_url, credentials=None): - # type: (str, str, Optional[str]) -> None - """Initialize a BundleProcessorTaskWorker. Lives remotely. - - It will create a BundleProcessor with the provide information and process - the requests using the BundleProcessor created. - - Args: - worker_id: the worker id of current task worker - server_url: control service url for the TaskGrpcServer - credentials: credentials to use when creating client - """ - self.worker_id = worker_id - self._credentials = credentials - self._responses = queue.Queue() - self._alive = None # type: Optional[bool] - self._bundle_processor = None # type: Optional[BundleProcessor] - self._exc_info = None - self.stub = self._create_stub(server_url) - - @classmethod - def execute(cls, worker_id, server_url, credentials=None): - # type: (str, str, Optional[str]) -> None - """Instantiate a BundleProcessorTaskWorker and start running. - - If there's error, it will be raised here so it can be reflected to user. 
- - Args: - worker_id: worker id for the BundleProcessorTaskWorker - server_url: control service url for the TaskGrpcServer - credentials: credentials to use when creating client - """ - self = cls(worker_id, server_url, credentials=credentials) - self.run() - - # raise the error here, so user knows there's a failure and could retry - if self._exc_info: - t, v, tb = self._exc_info - raise_(t, v, tb) - - def _create_stub(self, server_url): - # type: (str) -> beam_task_worker_pb2_grpc.TaskControlStub - """Create the TaskControl client.""" - channel_options = [("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1)] - if self._credentials is None: - channel = GRPCChannelFactory.insecure_channel( - server_url, - options=channel_options) - else: - channel = GRPCChannelFactory.secure_channel(server_url, - self._credentials, - options=channel_options) - - # add instruction_id to grpc channel - channel = grpc.intercept_channel( - channel, - WorkerIdInterceptor(self.worker_id)) - - return beam_task_worker_pb2_grpc.TaskControlStub(channel) - - def do_instruction(self, request): - # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse - """Process the requests with the corresponding method.""" - request_type = request.WhichOneof('request') - if request_type: - return getattr(self, self.REQUEST_PREFIX + request_type)( - getattr(request, request_type)) - else: - raise NotImplementedError - - def run(self): - # type: () -> None - """Start the full running life cycle for a task worker. - - It send TaskWorkerInstructionResponse to TaskWorkerHandler, and wait for - TaskWorkerInstructionRequest. This service is bidirectional. - """ - no_more_work = object() - self._alive = True - - def get_responses(): - while True: - response = self._responses.get() - if response is no_more_work: - return - if response: - yield response - - try: - for request in self.stub.Control(get_responses()): - self._responses.put(self.do_instruction(request)) - finally: - self._alive = False - - self._responses.put(no_more_work) - logging.info('Done consuming work.') - - def _request_create(self, request): - # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse - """Create a BundleProcessor based on the request. - - Should be the first request received from the handler. - """ - from apache_beam.runners.worker.sdk_worker import \ - GrpcStateHandlerFactory - - credentials = None - if request.data_factory.credentials._credentials: - credentials = grpc.ChannelCredentials( - request.data_factory.credentials._credentials) - logging.debug('Credentials: {!r}'.format(credentials)) - - worker_id = request.data_factory.worker_id - transmitter_url = request.data_factory.transmitter_url - state_handler_endpoint = request.state_handler_endpoint - # FIXME: Add support for Caching later - state_factory = GrpcStateHandlerFactory(StateCache(0), credentials) - state_handler = state_factory.create_state_handler( - state_handler_endpoint) - data_channel_factory = ProxyGrpcClientDataChannelFactory( - transmitter_url, credentials, worker_id - ) - - self._bundle_processor = BundleProcessor( - request.process_bundle_descriptor, - state_handler, - data_channel_factory - ) - return beam_task_worker_pb2.TaskInstructionResponse( - create=beam_task_worker_pb2.CreateResponse(), - instruction_id=self.worker_id - ) - - def _request_process_bundle(self, - request # type: beam_task_worker_pb2.TaskInstructionRequest - ): - # type: (...) 
-> beam_task_worker_pb2.TaskInstructionResponse - """Process bundle using the bundle processor based on the request.""" - error = None - - try: - # FIXME: Update this to use the cache_tokens properly - with self._bundle_processor.state_handler._underlying.process_instruction_id( - self.worker_id): - delayed_applications, require_finalization = \ - self._bundle_processor.process_bundle(self.worker_id, - use_task_worker=False) - except Exception as e: - # we want to propagate the error back to the TaskWorkerHandler, so that - # it will raise `TaskWorkerProcessBundleError` which allows for requeue - # behavior (up until MAX_TASK_WORKER_RETRY number of retries) - error = e.message - self._exc_info = sys.exc_info() - delayed_applications = [] - require_finalization = False - - return beam_task_worker_pb2.TaskInstructionResponse( - process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( - delayed_applications=delayed_applications, - require_finalization=require_finalization), - instruction_id=self.worker_id, - error=error - ) - - def _request_shutdown(self, request): - # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse - """Shutdown the bundleprocessor.""" - error = None - try: - # shut down state handler here because it is not created by the state - # handler factory thus won't be closed automatically - self._bundle_processor.state_handler.done() - self._bundle_processor.shutdown() - except Exception as e: - error = e.message - finally: - return beam_task_worker_pb2.TaskInstructionResponse( - shutdown=beam_task_worker_pb2.ShutdownResponse(), - instruction_id=self.worker_id, - error=error - ) - - -class BundleProcessorTaskHelper(object): - """ - A helper object that is used by a BundleProcessor while processing bundle. - - Delegates TaskableValues to TaskWorkers, if enabled. - - It can process TaskableValue using TaskWorkers if inspected, and kept the - default behavior if specified to not use TaskWorker, or there's no - TaskableValue found in this bundle. - - To utilize TaskWorkers, BundleProcessorTaskHelper will split up the input - bundle into tasks based on the wrapped TaskableValue's payload, and create a - TaskWorkerHandler for each task. - """ - - def __init__(self, instruction_id, wrapped_values): - # type: (str, DefaultDict[str, List[Tuple[Any, bytes]]]) -> None - """Initialize a BundleProcessorTaskHelper object. - - Args: - instruction_id: the instruction_id of the bundle that the - BundleProcessor is processing - wrapped_values: The mapping of transform id to raw and encoded data. - - """ - self.instruction_id = instruction_id - self.wrapped_values = wrapped_values - - def split_taskable_values(self): - # type: () -> Tuple[DefaultDict[str, List[Any]], DefaultDict[str, List[beam_fn_api_pb2.Elements.Data]]] - """Split TaskableValues into tasks and pair it with worker. - - Also put the raw bytes along with worker id for data dispatching by data - plane handler. 
- """ - # TODO: Come up with solution on how this can be dynamically changed - # could use window - splitted = collections.defaultdict(list) - data_store = collections.defaultdict(list) - worker_count = 0 - for ptransform_id, values in self.wrapped_values.items(): - for decoded, raw in values: - worker_id = 'worker_{}'.format(worker_count) - splitted[worker_id].append(decoded) - data_store[worker_id].append(beam_fn_api_pb2.Elements.Data( - transform_id=ptransform_id, - data=raw, - instruction_id=worker_id - )) - worker_count += 1 - - return splitted, data_store - - def _start_task_grpc_server(self, - max_workers, # type: int - data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] - state_handler, # type: CachingStateHandler - data_channel_factory, # type: DataChannelFactory - provision_info, # type: beam_provision_api_pb2.ProvisionInfo - ): - # type:(...) -> TaskGrpcServer - """Start up TaskGrpcServer. - - Args: - max_workers: number of max worker - data_store: stored data of worker id and the raw and decoded values for - the worker to process as inputs - state_handler: state handler of current BundleProcessor - data_channel_factory: data channel factory of current BundleProcessor - """ - return TaskGrpcServer(state_handler, max_workers, data_store, - data_channel_factory, self.instruction_id, - provision_info) - - @staticmethod - def get_default_task_env(process_bundle_descriptor): - # type:(beam_fn_api_pb2.ProcessBundleDescriptor) -> Optional[Environment] - """Get the current running beam Environment class. - - Used as the default for the task worker. - - Args: - process_bundle_descriptor: the ProcessBundleDescriptor proto - """ - from apache_beam.runners.portability.fn_api_runner.translations import \ - PAR_DO_URNS - from apache_beam.transforms.environments import Environment - - # find a ParDo xform in this stage - pardo = None - for _, xform in process_bundle_descriptor.transforms.items(): - if xform.spec.urn in PAR_DO_URNS: - pardo = xform - break - - if pardo is None: - # don't set the default task env if no ParDo is found - # FIXME: Use the pipeline default env here? - return None - - env_proto = process_bundle_descriptor.environments.get( - pardo.environment_id) - - return Environment.from_runner_api(env_proto, None) - - def get_sdk_worker_provision_info(self, server_url): - # type:(str) -> beam_provision_api_pb2.ProvisionInfo - channel_options = [("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1)] - channel = GRPCChannelFactory.insecure_channel(server_url, - options=channel_options) - - worker_id = os.environ['WORKER_ID'] - # add sdk worker id to grpc channel - channel = grpc.intercept_channel(channel, WorkerIdInterceptor(worker_id)) - - provision_stub = beam_provision_api_pb2_grpc.ProvisionServiceStub(channel) - response = provision_stub.GetProvisionInfo( - beam_provision_api_pb2.GetProvisionInfoRequest()) - channel.close() - return response.info - - def process_bundle_with_task_workers(self, - state_handler, # type: CachingStateHandler - data_channel_factory, # type: DataChannelFactory - process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] - """Main entry point for task worker system. - - Starts up a group of TaskWorkerHandlers, dispatches tasks and waits for them - to finish. - - Fails if any TaskWorker exceeds maximum retries. 
- - Args: - state_handler: state handler of current BundleProcessor - data_channel_factory: data channel factory of current BundleProcessor - process_bundle_descriptor: a description of the stage that this - ``BundleProcessor``is to execute. - """ - from google.protobuf import text_format - - default_env = self.get_default_task_env(process_bundle_descriptor) - # start up grpc server - splitted_elements, data_store = self.split_taskable_values() - num_task_workers = len(splitted_elements.items()) - if num_task_workers > MAX_TASK_WORKERS: - logging.warning( - 'Number of element exceeded MAX_TASK_WORKERS ({})'.format( - MAX_TASK_WORKERS)) - num_task_workers = MAX_TASK_WORKERS - - # get sdk worker provision info first - try: - provision_port = TaskWorkerHandler.get_port_from_env_var( - 'PROVISION_API_SERVICE_DESCRIPTOR') - except KeyError: - # if we are in subprocess environment then there won't be any provision - # service so use default provison info - provision_info = beam_provision_api_pb2.ProvisionInfo() - - else: - provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( - TaskWorkerHandler.host_from_worker(), provision_port)) - - server = self._start_task_grpc_server(num_task_workers, data_store, - state_handler, data_channel_factory, - provision_info) - # modify provision api service descriptor to use the new address that we are - # gonna be using (the control address) - os.environ['PROVISION_API_SERVICE_DESCRIPTOR'] = text_format.MessageToString( - endpoints_pb2.ApiServiceDescriptor(url=server.control_address)) - - # create TaskWorkerHandlers - task_worker_handlers = [] - for worker_id, elem in splitted_elements.items(): - taskable_value = get_taskable_value(elem) - # set the env to default env if there is - if taskable_value.env is None and default_env: - taskable_value.env = default_env - - task_worker_handler = TaskWorkerHandler.create( - state_handler, provision_info, server, taskable_value, - credentials=data_channel_factory._credentials, worker_id=worker_id) - task_worker_handlers.append(task_worker_handler) - task_worker_handler.start_worker() - - def _execute(handler): - """ - This is the method that runs in the thread pool representing a working - TaskHandler. 
- """ - worker_data_channel_factory = ProxyGrpcClientDataChannelFactory( - server.data_address, - credentials=data_channel_factory._credentials, - worker_id=task_worker_handler.worker_id) - counter = 0 - while True: - try: - counter += 1 - return handler.execute(worker_data_channel_factory, - process_bundle_descriptor) - except TaskWorkerProcessBundleError as e: - if counter >= MAX_TASK_WORKER_RETRY: - logging.error('Task Worker has exceeded max retries!') - handler.stop_worker() - raise - # retry if task worker failed to process bundle - handler.reset() - continue - except TaskWorkerTerminatedError as e: - # This error is thrown only when TaskWorkerHandler is terminated - # before it finished processing - logging.warning('TaskWorker terminated prematurely.') - raise - - # start actual processing of splitted bundle - merged_delayed_applications = [] - bundle_require_finalization = False - with futures.ThreadPoolExecutor(max_workers=num_task_workers) as executor: - try: - for delayed_applications, require_finalization in executor.map( - _execute, task_worker_handlers): - - if delayed_applications is None: - raise RuntimeError('Task Worker failed to process task.') - merged_delayed_applications.extend(delayed_applications) - # if any elem requires finalization, set it to True - if not bundle_require_finalization and require_finalization: - bundle_require_finalization = True - except TaskWorkerProcessBundleError: - raise RuntimeError('Task Worker failed to process task.') - except TaskWorkerTerminatedError: - # This error is thrown only when TaskWorkerHandler is terminated before - # it finished processing, this is only possible if - # `TaskWorkerHandler.alive` is manually set to False by custom user - # defined monitoring function, which user would trigger if they want - # to terminate the task; - # In that case, we want to continue on and not hold the whole bundle - # by the tasks that user manually terminated. - pass - - return merged_delayed_applications, bundle_require_finalization - - -@TaskWorkerHandler.register_urn('local') -class LocalTaskWorkerHandler(TaskWorkerHandler): - """TaskWorkerHandler that starts up task worker locally.""" - - # FIXME: create a class-level thread pool to restrict the number of threads - - def start_remote(self): - # type: () -> None - """start a task worker local to the task worker handler.""" - obj = BundleProcessorTaskWorker(self.worker_id, self.control_address, - self.credentials) - run_thread = threading.Thread(target=obj.run) - run_thread.daemon = True - run_thread.start() - - -TaskWorkerHandler.load_plugins() - - -def get_taskable_value(decoded_value): - # type: (Any) -> Optional[TaskableValue] - """Check whether the given value contains taskable value. - - If taskable, return the TaskableValue - - Args: - decoded_value: decoded value from raw input stream - """ - # FIXME: Come up with a solution that's not so specific - if isinstance(decoded_value, (list, tuple, set)): - for val in decoded_value: - result = get_taskable_value(val) - if result: - return result - elif isinstance(decoded_value, TaskableValue): - return decoded_value - return None diff --git a/sdks/python/apache_beam/runners/worker/task_worker/core.py b/sdks/python/apache_beam/runners/worker/task_worker/core.py index ee8797c2f006..2b23d9996679 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/core.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/core.py @@ -6,7 +6,7 @@ # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -22,148 +22,147 @@ from typing import TYPE_CHECKING import apache_beam as beam -from apache_beam.runners.worker.task_worker import TaskableValue +from apache_beam.runners.worker.task_worker.handlers import TaskableValue + +if TYPE_CHECKING: + from typing import Any, Optional, Callable, Iterator + from apache_beam.transforms.environments import Environment + TASK_WORKER_ENV_TYPES = {'Docker', 'Process'} TASK_WORKER_SDK_ENTRYPOINT = 'apache_beam.runners.worker.task_worker.task_worker_main' -if TYPE_CHECKING: - from typing import Any, Optional, Callable, Iterable, Iterator - class WrapFn(beam.DoFn): + """ + Wraps the given element into a TaskableValue if there's non-empty task + payload. User can pass in wrapper callable to modify payload per element. + """ + + def process( + self, + element, # type: Any + urn='local', # type: str + wrapper=None, # type: Optional[Callable[[Any, Any], Any]] + env=None, # type: Optional[Environment] + payload=None # type: Optional[Any] + ): + # type: (...) -> Iterator[Any] """ - Wraps the given element into a TaskableValue if there's non-empty task - payload. User can pass in wrapper callable to modify payload per element. + Args: + element: any ptransform element + urn: id of the task worker handler + wrapper: optional callable which can be used to modify a payload + per-element. + env : Environment for the task to be run in + payload : Payload containing settings for the task worker handler """ + # override payload if given a wrapper function, which will vary per + # element + if wrapper: + payload = wrapper(element, payload) - def process(self, element, urn='local', wrapper=None, env=None, payload=None): - # type: (Any, str, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment], Optional[Any]) -> Iterator[Any] - """ - Parameters - ---------- - element : Any - urn : str - wrapper : Optional[Callable[[Any, Any], Any]] - env : Optional[beam.transforms.environments.Environment] - payload : Optional[Any] - - Yields - ------ - Any - """ - from apache_beam.runners.worker.task_worker import TaskableValue - - # override payload if given a wrapper function, which will vary per - # element - if wrapper: - payload = wrapper(element, payload) - - if payload: - result = TaskableValue(element, urn, env=env, payload=payload) - else: - result = element - yield result + if payload: + result = TaskableValue(element, urn, env=env, payload=payload) + else: + result = element + yield result class UnWrapFn(beam.DoFn): - """ - Unwraps the TaskableValue into its original value, so that when constructing - transforms user doesn't need to worry about the element type if it is - taskable or not. - """ + """ + Unwraps the TaskableValue into its original value, so that when constructing + transforms user doesn't need to worry about the element type if it is + taskable or not. 
+ """ - def process(self, element): + def process(self, element): - if isinstance(element, TaskableValue): - yield element.value - else: - yield element + if isinstance(element, TaskableValue): + yield element.value + else: + yield element class BeamTask(beam.PTransform): + """ + Utility transform that wraps a group of transforms, and makes it a Beam + "Task" that can be delegated to a task worker to run remotely. + + The main structure is like this: + + ( pipe + | Wrap + | Reshuffle + | UnWrap + | User Transform1 + | ... + | User TransformN + | Reshuffle + ) + + The use of reshuffle is to make sure stage fusing doesn't try to fuse the + section we want to run with the inputs of this xform; reason being we need + the start of a stage to get data inputs that are *TaskableValue*, so that + the bundle processor will recognize that and will engage Task Workers. + + We end with a Reshuffle for similar reason, so that the next section of the + pipeline doesn't gets fused with the transforms provided, which would end up + being executed remotely in a remote task worker. + + By default, we use the local task worker, but subclass could specify the + type of task worker to use by specifying the ``urn``, and override the + ``getPayload`` method to return meaningful payloads to that type of task + worker. + """ + + # the urn for the registered task worker handler, default to use local task + # worker + urn = 'local' # type: str + + # the sdk harness entry point + SDK_HARNESS_ENTRY_POINT = TASK_WORKER_SDK_ENTRYPOINT + + def __init__(self, fusedXform, wrapper=None, env=None): + # type: (beam.PTransform, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment]) -> None + self._wrapper = wrapper + self._env = env + self._fusedXform = fusedXform + + def getPayload(self): + # type: () -> Optional[Any] """ - Utility transform that wraps a group of transforms, and makes it a Beam - "Task" that can be delegated to a task worker to run remotely. - - The main structure is like this: - - ( pipe - | Wrap - | Reshuffle - | UnWrap - | User Transform1 - | ... - | User TransformN - | Reshuffle - ) - - The use of reshuffle is to make sure stage fusing doesn't try to fuse the - section we want to run with the inputs of this xform; reason being we need - the start of a stage to get data inputs that are *TaskableValue*, so that - the bundle processor will recognize that and will engage Task Workers. - - We end with a Reshuffle for similar reason, so that the next section of the - pipeline doesn't gets fused with the transforms provided, which would end up - being executed remotely in a remote task worker. - - By default, we use the local task worker, but subclass could specify the - type of task worker to use by specifying the ``urn``, and override the - ``getPayload`` method to return meaningful payloads to that type of task - worker. + Subclass should implement this to generate payload for TaskableValue. + Default to None. 
""" - - # the urn for the registered task worker handler, default to use local task - # worker - urn = 'local' # type: str - - # the sdk harness entry point - SDK_HARNESS_ENTRY_POINT = TASK_WORKER_SDK_ENTRYPOINT - - def __init__(self, fusedXform, wrapper=None, env=None): - # type: (beam.PTransform, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment]) -> None - self._wrapper = wrapper - self._env = env - self._fusedXform = fusedXform - - def getPayload(self): - # type: () -> Optional[Any] - """ - Subclass should implement this to generate payload for TaskableValue. - Default to None. - - Returns - ------- - Optional[Any] - """ - return None - - @staticmethod - def _hasTaggedOutputs(xform): - # type: (beam.PTransform) -> bool - """Checks to see if we have tagged output for the given PTransform.""" - if isinstance(xform, beam.core._MultiParDo): - return True - elif isinstance(xform, beam.ptransform._ChainedPTransform) \ - and isinstance(xform._parts[-1], beam.core._MultiParDo): - return True - return False - - def expand(self, pcoll): - # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection - payload = self.getPayload() - result = ( - pcoll - | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, - env=self._env, payload=payload) - | 'StartStage' >> beam.Reshuffle() - | 'UnWrap' >> beam.ParDo(UnWrapFn()) - | self._fusedXform - ) - if self._hasTaggedOutputs(self._fusedXform): - # for xforms that ended up with tagged outputs, we don't want to - # add reshuffle, because it will be a stage split point already, - # also adding reshuffle would error since we now have a tuple of - # pcollections. - return result - return result | 'EndStage' >> beam.Reshuffle() + return None + + @staticmethod + def _hasTaggedOutputs(xform): + # type: (beam.PTransform) -> bool + """Checks to see if we have tagged output for the given PTransform.""" + if isinstance(xform, beam.core._MultiParDo): + return True + elif isinstance(xform, beam.ptransform._ChainedPTransform) \ + and isinstance(xform._parts[-1], beam.core._MultiParDo): + return True + return False + + def expand(self, pcoll): + # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection + payload = self.getPayload() + result = ( + pcoll + | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, + env=self._env, payload=payload) + | 'StartStage' >> beam.Reshuffle() + | 'UnWrap' >> beam.ParDo(UnWrapFn()) + | self._fusedXform + ) + if self._hasTaggedOutputs(self._fusedXform): + # for xforms that ended up with tagged outputs, we don't want to + # add reshuffle, because it will be a stage split point already, + # also adding reshuffle would error since we now have a tuple of + # pcollections. 
+ return result + return result | 'EndStage' >> beam.Reshuffle() diff --git a/sdks/python/apache_beam/runners/worker/task_worker/example.py b/sdks/python/apache_beam/runners/worker/task_worker/example.py index 9e855ecbddbf..99a9c2986b8e 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/example.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/example.py @@ -7,63 +7,56 @@ import argparse import logging -import re - -from past.builtins import unicode import apache_beam as beam -from apache_beam.io import ReadFromText -from apache_beam.io import WriteToText from apache_beam.options.pipeline_options import DirectOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions -import apache_beam as beam - from apache_beam.runners.worker.task_worker.core import BeamTask class TestFn(beam.DoFn): - def process(self, element, side): - from apache_beam.runners.worker.task_worker import TaskableValue + def process(self, element, side): + from apache_beam.runners.worker.task_worker.handlers import TaskableValue - for s in side: - if isinstance(element, TaskableValue): - value = element.value - else: - value = element - print(value + s) - yield value + s + for s in side: + if isinstance(element, TaskableValue): + value = element.value + else: + value = element + print(value + s) + yield value + s def run(argv=None, save_main_session=True): - """Main entry point; defines and runs the test pipeline.""" - parser = argparse.ArgumentParser() - known_args, pipeline_args = parser.parse_known_args(argv) + """Main entry point; defines and runs the test pipeline.""" + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) - # We use the save_main_session option because one or more DoFn's in this - # workflow rely on global context (e.g., a module imported at module level). - pipeline_options = PipelineOptions(pipeline_args) - pipeline_options.view_as(SetupOptions).save_main_session = save_main_session - pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' - # The pipeline will be run on exiting the with block. - with beam.Pipeline(options=pipeline_options) as pipe: + # The pipeline will be run on exiting the with block. 
+ with beam.Pipeline(options=pipeline_options) as pipe: - A = ( - pipe - | 'A' >> beam.Create(range(3)) - ) + A = ( + pipe + | 'A' >> beam.Create(range(3)) + ) - B = ( - pipe - | beam.Create(range(2)) - | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), - wrapper=lambda x, _: x) - ) + B = ( + pipe + | beam.Create(range(2)) + | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), + wrapper=lambda x, _: x) + ) if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - run() + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py new file mode 100644 index 000000000000..d6c7918547d9 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py @@ -0,0 +1,1255 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import +from __future__ import division + +import collections +import copy +import logging +import queue +import os +import sys +import threading +from builtins import object +from concurrent import futures +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union + +import grpc +from future.utils import raise_ + +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import beam_provision_api_pb2 +from apache_beam.portability.api import beam_provision_api_pb2_grpc +from apache_beam.portability.api import beam_task_worker_pb2 +from apache_beam.portability.api import beam_task_worker_pb2_grpc +from apache_beam.portability.api import endpoints_pb2 +from apache_beam.runners.portability.fn_api_runner.worker_handlers import BasicProvisionService +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection +from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture +from apache_beam.runners.portability.fn_api_runner import FnApiRunner +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcWorkerHandler +from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcStateServicer +from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler +from apache_beam.runners.worker.bundle_processor import BundleProcessor +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory +from apache_beam.runners.worker.data_plane import 
_GrpcDataChannel
+from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream
+from apache_beam.runners.worker.data_plane import DataChannelFactory
+from apache_beam.runners.worker.statecache import StateCache
+from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
+
+if TYPE_CHECKING:
+  from apache_beam.runners.portability.fn_api_runner.fn_runner import ExtendedProvisionInfo
+  from apache_beam.runners.worker.data_plane import DataChannelFactory
+  from apache_beam.runners.worker.sdk_worker import CachingStateHandler
+  from apache_beam.transforms.environments import Environment
+
+ENTRY_POINT_NAME = 'apache_beam_task_workers_plugins'
+MAX_TASK_WORKERS = 300
+MAX_TASK_WORKER_RETRY = 10
+
+Taskable = Union[
+    'TaskableValue',
+    List['TaskableValue'],
+    Tuple['TaskableValue', ...],
+    Set['TaskableValue']]
+
+
+class TaskableValue(object):
+  """
+  Value that can be distributed to TaskWorkers as tasks.
+
+  Holds the original value, and the TaskProperties that specify how the task
+  will be generated.
+  """
+
+  def __init__(
+      self,
+      value,  # type: Any
+      urn,  # type: str
+      env=None,  # type: Optional[Environment]
+      payload=None  # type: Optional[Any]
+  ):
+    # type: (...) -> None
+    """
+    Args:
+      value: The wrapped element
+      urn: id of the task worker handler
+      env: Environment for the task to be run in
+      payload: Payload containing settings for the task worker handler
+    """
+    self.value = value
+    self.urn = urn
+    self.env = env
+    self.payload = payload
+
+
+class TaskWorkerProcessBundleError(Exception):
+  """
+  Error thrown when TaskWorker fails to process_bundle.
+
+  Errors encountered while a task worker is processing a bundle can be retried,
+  up to the maximum number of retries defined by ``MAX_TASK_WORKER_RETRY``.
+  """
+
+
+class TaskWorkerTerminatedError(Exception):
+  """
+  Error thrown when a TaskWorker terminates before it finishes working.
+
+  Custom TaskWorkerHandlers can choose to terminate a task but not
+  affect the whole bundle by setting ``TaskWorkerHandler.alive`` to False,
+  which will cause this error to be thrown.
+  """
+
+
+class TaskWorkerHandler(GrpcWorkerHandler):
+  """
+  Abstract base class for task worker handlers.
+
+  A TaskWorkerHandler is created for each TaskableValue.
+
+  Subclasses must override ``start_remote`` to modify how the remote task
+  worker is started, and register the task properties type when defining a
+  subclass.
+  """
+
+  _known_urns = {}  # type: Dict[str, Type[TaskWorkerHandler]]
+
+  def __init__(
+      self,
+      state,  # type: CachingStateHandler
+      provision_info,  # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo]
+      grpc_server,  # type: TaskGrpcServer
+      environment,  # type: Environment
+      task_payload,  # type: Any
+      credentials=None,  # type: Optional[str]
+      worker_id=None  # type: Optional[str]
+  ):
+    # type: (...)
-> None + self._grpc_server = grpc_server + + # we are manually doing init instead of calling GrpcWorkerHandler's init + # because we want to override worker_id and we don't want extra + # ControlConnection to be established + WorkerHandler.__init__(self, grpc_server.control_handler, + grpc_server.data_plane_handler, state, + provision_info) + # override worker_id if provided + if worker_id: + self.worker_id = worker_id + + self.control_address = self.port_from_worker(self._grpc_server.control_port) + self.provision_address = self.control_address + self.logging_address = self.port_from_worker( + self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) + # if we are running from subprocess environment, the aritfact staging + # endpoint is the same as control end point, and the env var won't be + # recorded so use control end point and record that to artifact + try: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('ARTIFACT_API_SERVICE_DESCRIPTOR')) + except KeyError: + self.artifact_address = self.port_from_worker( + self.get_port_from_env_var('CONTROL_API_SERVICE_DESCRIPTOR')) + os.environ['ARTIFACT_API_SERVICE_DESCRIPTOR'] = os.environ[ + 'CONTROL_API_SERVICE_DESCRIPTOR'] + + # modify provision info + modified_provision = copy.copy(provision_info) + modified_provision.control_endpoint.url = self.control_address + modified_provision.logging_endpoint.url = self.logging_address + modified_provision.artifact_endpoint.url = self.artifact_address + with TaskProvisionServicer._lock: + self._grpc_server.provision_handler.provision_by_worker_id[ + self.worker_id] = modified_provision + + self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( + self.worker_id) + + self.environment = environment + self.task_payload = task_payload + self.credentials = credentials + self.alive = True + + @staticmethod + def host_from_worker(): + # type: () -> str + import socket + return socket.getfqdn() + + @staticmethod + def get_port_from_env_var(env_var): + # type: (str) -> str + """Extract the service port for a given environment variable.""" + from google.protobuf import text_format + endpoint = endpoints_pb2.ApiServiceDescriptor() + text_format.Merge(os.environ[env_var], endpoint) + return endpoint.url.split(':')[-1] + + @staticmethod + def load_plugins(): + # type: () -> None + import entrypoints + for name, entry_point in entrypoints.get_group_named( + ENTRY_POINT_NAME).items(): + logging.info('Loading entry point: {}'.format(name)) + entry_point.load() + + @classmethod + def register_urn(cls, urn, constructor=None): + def register(constructor): + cls._known_urns[urn] = constructor + return constructor + if constructor: + return register(constructor) + else: + return register + + @classmethod + def create(cls, + state, # type: CachingStateHandler + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + grpc_server, # type: TaskGrpcServer + taskable_value, # type: TaskableValue + credentials=None, # type: Optional[str] + worker_id=None # type: Optional[str] + ): + # type: (...) -> TaskWorkerHandler + constructor = cls._known_urns[taskable_value.urn] + return constructor(state, provision_info, grpc_server, + taskable_value.env, taskable_value.payload, + credentials=credentials, worker_id=worker_id) + + def start_worker(self): + # type: () -> None + self.start_remote() + + def start_remote(self): + # type: () -> None + """Start up a remote TaskWorker to process the current element. 
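+
+    A minimal override is just a launcher for ``task_worker_main`` pointed
+    back at this handler, e.g. (illustrative sketch only; the ``TASK_WORKER_*``
+    env var names mirror the ones used by the kubernetes handler)::
+
+        def start_remote(self):
+            import os, subprocess
+            # Launch the task worker out of process; it dials back into the
+            # control service using this handler's worker id.
+            subprocess.Popen(
+                ['python', '-m',
+                 'apache_beam.runners.worker.task_worker.task_worker_main'],
+                env=dict(os.environ,
+                         TASK_WORKER_ID=self.worker_id,
+                         TASK_WORKER_CONTROL_ADDRESS=self.control_address))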
+ + Subclass should implement this.""" + raise NotImplementedError + + def stop_worker(self): + # type: () -> None + # send shutdown request + future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + shutdown=beam_task_worker_pb2.ShutdownRequest())) + response = future.get() + if response.error: + logging.warning('Error stopping worker: {}'.format(self.worker_id)) + + # close control conn after stop worker + self.control_conn.close() + + def _get_future(self, future, interval=0.5): + # type: (ControlFuture, float) -> beam_task_worker_pb2.TaskInstructionResponse + result = None + while self.alive: + result = future.get(timeout=interval) + if result: + break + + # if the handler is not alive, meaning task worker is stopped before + # finishing processing, raise ``TaskWorkerTerminatedError`` + if result is None: + raise TaskWorkerTerminatedError() + + return result + + def execute(self, + data_channel_factory, # type: ProxyGrpcClientDataChannelFactory + process_bundle_descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point of the task execution cycle of a ``TaskWorkerHandler``. + + It will first issue a create request, and wait for the remote bundle + processor to be created; Then it will issue the process bundle request, and + wait for the result. If there's error occurred when processing bundle, + ``TaskWorkerProcessBundleError`` will be raised. + """ + # wait for remote bundle processor to be created + create_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + create=beam_task_worker_pb2.CreateRequest( + process_bundle_descriptor=process_bundle_descriptor, + state_handler_endpoint=endpoints_pb2.ApiServiceDescriptor( + url=self._grpc_server.state_address), + data_factory=beam_task_worker_pb2.GrpcClientDataChannelFactory( + credentials=self.credentials, + worker_id=data_channel_factory.worker_id, + transmitter_url=data_channel_factory.transmitter_url)))) + self._get_future(create_future) + + # process bundle + process_future = self.control_conn.push( + beam_task_worker_pb2.TaskInstructionRequest( + instruction_id=self.worker_id, + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()) + ) + response = self._get_future(process_future) + if response.error: + # raise here so this task can be retried + raise TaskWorkerProcessBundleError() + else: + delayed_applications = response.process_bundle.delayed_applications + require_finalization = response.process_bundle.require_finalization + + self.stop_worker() + return delayed_applications, require_finalization + + def reset(self): + # type: () -> None + """This is used to retry a failed task.""" + self.control_conn.reset() + + +class TaskGrpcServer(object): + """ + A collection of grpc servicers that handle communication between a + ``TaskWorker`` and ``TaskWorkerHandler``. + + Contains three servers: + - a control server hosting ``TaskControlService`` amd `TaskProvisionService` + - a data server hosting ``TaskFnDataService`` + - a state server hosting ``TaskStateService`` + + This is shared by all TaskWorkerHandlers generated by one bundle. 
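+
+  Each server binds an ephemeral port and is addressed as ``'<fqdn>:<port>'``,
+  e.g. (sketch, assuming the state handler and data channel factory of the
+  current BundleProcessor)::
+
+      server = TaskGrpcServer(state_handler, max_workers=4, data_store={},
+                              data_channel_factory=factory,
+                              instruction_id='instruction-1',
+                              provision_info=provision_info)
+      server.control_address  # e.g. 'host.example.com:41233'
+      server.close()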
+ """ + + _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 + + def __init__( + self, + state_handler, # type: CachingStateHandler + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + data_channel_factory, # type: DataChannelFactory + instruction_id, # type: str + provision_info, # type: Union[beam_provision_api_pb2.ProvisionInfo, ExtendedProvisionInfo] + ): + # type: (...) -> None + self.state_handler = state_handler + self.max_workers = max_workers + self.control_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers)) + self.control_port = self.control_server.add_insecure_port('[::]:0') + self.control_address = '%s:%s' % (self.get_host_name(), self.control_port) + + # Options to have no limits (-1) on the size of the messages + # received or sent over the data plane. The actual buffer size + # is controlled in a layer above. + no_max_message_sizes = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + self.data_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.data_port = self.data_server.add_insecure_port('[::]:0') + self.data_address = '%s:%s' % (self.get_host_name(), self.data_port) + + self.state_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.max_workers), + options=no_max_message_sizes) + self.state_port = self.state_server.add_insecure_port('[::]:0') + self.state_address = '%s:%s' % (self.get_host_name(), self.state_port) + + self.control_handler = TaskControlServicer() + beam_task_worker_pb2_grpc.add_TaskControlServicer_to_server( + self.control_handler, self.control_server) + self.provision_handler = TaskProvisionServicer(provision_info=provision_info) + beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server( + self.provision_handler, self.control_server) + + self.data_plane_handler = TaskFnDataServicer(data_store, + data_channel_factory, + instruction_id) + beam_task_worker_pb2_grpc.add_TaskFnDataServicer_to_server( + self.data_plane_handler, self.data_server) + + beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( + TaskStateServicer(self.state_handler, instruction_id, + state_handler._context.bundle_cache_token), + self.state_server) + + logging.info('starting control server on port %s', self.control_port) + logging.info('starting data server on port %s', self.data_port) + logging.info('starting state server on port %s', self.state_port) + self.state_server.start() + self.data_server.start() + self.control_server.start() + + @staticmethod + def get_host_name(): + # type: () -> str + import socket + return socket.getfqdn() + + def close(self): + # type: () -> None + self.control_handler.done() + to_wait = [ + self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), + ] + for w in to_wait: + w.wait() + + +# ============= +# Control Plane +# ============= +class TaskWorkerConnection(ControlConnection): + """The control connection between a TaskWorker and a TaskWorkerHandler. + + TaskWorkerHandler push InstructionRequests to _push_queue, and receives + InstructionResponses from TaskControlServicer. 
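+
+  A typical exchange from the handler side looks like (sketch; the real
+  requests are built in ``TaskWorkerHandler.execute``)::
+
+      conn = grpc_server.control_handler.get_conn_by_worker_id(worker_id)
+      future = conn.push(
+          beam_task_worker_pb2.TaskInstructionRequest(
+              instruction_id=worker_id,
+              process_bundle=beam_task_worker_pb2.ProcessorProcessBundleRequest()))
+      response = future.get()  # resolved by the _read thread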
+ """ + + _lock = threading.Lock() + + def __init__(self): + self._push_queue = queue.Queue() + self._input = None + self._futures_by_id = {} # type: Dict[Any, ControlFuture] + self._read_thread = threading.Thread( + name='bundle_processor_control_read', target=self._read) + self._state = TaskControlServicer.UNSTARTED_STATE + # marks current TaskConnection as in a state of retrying after failure + self._retrying = False + + def _read(self): + # type: () -> None + for data in self._input: + self._futures_by_id.pop(data.WhichOneof('response')).set(data) + + def push(self, + req # type: Union[TaskControlServicer._DONE_MARKER, beam_task_worker_pb2.TaskInstructionRequest] + ): + # type: (...) -> Optional[ControlFuture] + if req == TaskControlServicer._DONE_MARKER: + self._push_queue.put(req) + return None + if not req.instruction_id: + raise RuntimeError( + 'TaskInstructionRequest has to have instruction id!') + future = ControlFuture(req.instruction_id) + self._futures_by_id[req.WhichOneof('request')] = future + self._push_queue.put(req) + return future + + def set_inputs(self, input): + with TaskWorkerConnection._lock: + if self._input and not self._retrying: + raise RuntimeError('input is already set.') + self._input = input + self._read_thread.start() + self._state = TaskControlServicer.STARTED_STATE + self._retrying = False + + def close(self): + # type: () -> None + with TaskWorkerConnection._lock: + if self._state == TaskControlServicer.STARTED_STATE: + self.push(TaskControlServicer._DONE_MARKER) + self._read_thread.join() + self._state = TaskControlServicer.DONE_STATE + + def reset(self): + # type: () -> None + self.close() + self.__init__() + self._retrying = True + + +class TaskControlServicer(beam_task_worker_pb2_grpc.TaskControlServicer): + + _lock = threading.Lock() + + UNSTARTED_STATE = 'unstarted' + STARTED_STATE = 'started' + DONE_STATE = 'done' + + _DONE_MARKER = object() + + def __init__(self): + # type: () -> None + self._state = self.UNSTARTED_STATE + self._connections_by_worker_id = collections.defaultdict( + TaskWorkerConnection) + + def get_conn_by_worker_id(self, worker_id): + # type: (str) -> TaskWorkerConnection + with self._lock: + result = self._connections_by_worker_id[worker_id] + return result + + def Control(self, request_iterator, context): + with self._lock: + if self._state == self.DONE_STATE: + return + else: + self._state = self.STARTED_STATE + worker_id = dict(context.invocation_metadata()).get('worker_id') + if not worker_id: + raise RuntimeError('Connection does not have worker id.') + conn = self.get_conn_by_worker_id(worker_id) + conn.set_inputs(request_iterator) + + while True: + to_push = conn.get_req() + if to_push is self._DONE_MARKER: + return + yield to_push + + def done(self): + # type: () -> None + self._state = self.DONE_STATE + + +# ========== +# Data Plane +# ========== +class ProxyGrpcClientDataChannelFactory(DataChannelFactory): + """A factory for ``ProxyGrpcClientDataChannel``. 
+ + No caching behavior here because we are starting each data channel on + different location.""" + + def __init__(self, transmitter_url, credentials=None, worker_id=None): + # type: (str, Optional[str], Optional[str]) -> None + # These two are not private attributes because it was used in + # ``TaskWorkerHandler.execute`` when issuing TaskInstructionRequest + self.transmitter_url = transmitter_url + self.worker_id = worker_id + + self._credentials = credentials + + def create_data_channel(self, remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> ProxyGrpcClientDataChannel + url = remote_grpc_port.api_service_descriptor.url + return self.create_data_channel_from_url(url) + + def create_data_channel_from_url(self, url): + # type: (str) -> ProxyGrpcClientDataChannel + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + grpc_channel = GRPCChannelFactory.insecure_channel( + self.transmitter_url, options=channel_options) + else: + grpc_channel = GRPCChannelFactory.secure_channel( + self.transmitter_url, self._credentials, options=channel_options) + return ProxyGrpcClientDataChannel( + url, + beam_task_worker_pb2_grpc.TaskFnDataStub(grpc_channel)) + + def close(self): + # type: () -> None + pass + + +class ProxyGrpcClientDataChannel(_GrpcDataChannel): + """DataChannel wrapping the client side of a TaskFnDataService connection.""" + + def __init__(self, client_url, proxy_stub): + # type: (str, beam_task_worker_pb2_grpc.TaskFnDataStub) -> None + super(ProxyGrpcClientDataChannel, self).__init__() + self.client_url = client_url + self.proxy_stub = proxy_stub + + def input_elements(self, + instruction_id, # type: str + expected_transforms, # type: List[str] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) 
-> Iterator[beam_fn_api_pb2.Elements.Data] + if not expected_transforms: + return + req = beam_task_worker_pb2.ReceiveRequest( + instruction_id=instruction_id, + client_data_endpoint=self.client_url) + done_transforms = [] + abort_callback = abort_callback or (lambda: False) + + for data in self.proxy_stub.Receive(req): + if self._closed: + raise RuntimeError('Channel closed prematurely.') + if abort_callback(): + return + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + if not data.data and data.transform_id in expected_transforms: + done_transforms.append(data.transform_id) + else: + assert data.transform_id not in done_transforms + yield data + if len(done_transforms) >= len(expected_transforms): + return + + def output_stream(self, instruction_id, transform_id): + # type: (str, str) -> SizeBasedBufferingClosableOutputStream + + def _add_to_send_queue(data): + if data: + self.proxy_stub.Send(beam_task_worker_pb2.SendRequest( + instruction_id=instruction_id, + data=beam_fn_api_pb2.Elements.Data( + instruction_id=instruction_id, + transform_id=transform_id, + data=data), + client_data_endpoint=self.client_url + )) + + def close_callback(data): + _add_to_send_queue(data) + # no need to send empty bytes to signal end of processing here, because + # when the whole bundle finishes, the bundle processor original output + # stream will send that to runner + + return SizeBasedBufferingClosableOutputStream( + close_callback, flush_callback=_add_to_send_queue) + + +class TaskFnDataServicer(beam_task_worker_pb2_grpc.TaskFnDataServicer): + """Implementation of BeamFnDataTransmitServicer for any number of clients.""" + + def __init__(self, + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + orig_data_channel_factory, # type: DataChannelFactory + instruction_id # type: str + ): + # type: (...) -> None + self.data_store = data_store + self.orig_data_channel_factory = orig_data_channel_factory + self.orig_instruction_id = instruction_id + self._orig_data_channel = None # type: Optional[ProxyGrpcClientDataChannel] + + def _get_orig_data_channel(self, url): + # type: (str) -> ProxyGrpcClientDataChannel + remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort( + api_service_descriptor=endpoints_pb2.ApiServiceDescriptor(url=url)) + # the data channel is cached by url + return self.orig_data_channel_factory.create_data_channel(remote_grpc_port) + + def Receive(self, request, context=None): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data]] + data = self.data_store[request.instruction_id] + for elem in data: + yield elem + + def Send(self, request, context=None): + # type: (...) 
-> beam_task_worker_pb2.SendResponse + if self._orig_data_channel is None: + self._orig_data_channel = self._get_orig_data_channel( + request.client_data_endpoint) + # We need to replace the instruction_id here with the original instruction + # id, not the current one (which is the task worker id) + request.data.instruction_id = self.orig_instruction_id + if request.data.data: + # only send when there's data, because it is signaling the runner side + # worker handler that element of this has ended if it is empty, and we + # want to send that when every task worker handler is finished + self._orig_data_channel._to_send.put(request.data) + return beam_task_worker_pb2.SendResponse() + + +# ===== +# State +# ===== +class TaskStateServicer(GrpcStateServicer): + + def __init__(self, state, instruction_id, cache_token): + # type: (CachingStateHandler, str, Optional[str]) -> None + self.instruction_id = instruction_id + self.cache_token = cache_token + super(TaskStateServicer, self).__init__(state) + + def State(self, request_stream, context=None): + # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse] + # CachingStateHandler and GrpcStateHandler context is thread local, so we + # need to set it here for each TaskWorker + self._state._context.process_instruction_id = self.instruction_id + self._state._context.cache_token = self.cache_token + self._state._underlying._context.process_instruction_id = self.instruction_id + + # FIXME: This is not currently properly supporting state caching (currently + # state caching behavior only happens within python SDK, so runners like + # the FlinkRunner won't create the state cache anyways for now) + for request in request_stream: + request_type = request.WhichOneof('request') + + if request_type == 'get': + data, continuation_token = self._state._underlying.get_raw( + request.state_key, request.get.continuation_token) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + get=beam_fn_api_pb2.StateGetResponse( + data=data, continuation_token=continuation_token)) + elif request_type == 'append': + self._state._underlying.append_raw(request.state_key, + request.append.data) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + append=beam_fn_api_pb2.StateAppendResponse()) + elif request_type == 'clear': + self._state._underlying.clear(request.state_key) + yield beam_fn_api_pb2.StateResponse( + id=request.id, + clear=beam_fn_api_pb2.StateClearResponse()) + else: + raise NotImplementedError('Unknown state request: %s' % request_type) + + +# ===== +# Provision +# ===== +class TaskProvisionServicer(BasicProvisionService): + """ + Provide provision info for remote task workers, provision is static because + for each bundle the provision info is static. + """ + + _lock = threading.Lock() + + def __init__(self, provision_info=None): + # type: (Optional[beam_provision_api_pb2.ProvisionInfo]) -> None + self._provision_info = provision_info + self.provision_by_worker_id = dict() + + def GetProvisionInfo(self, request, context=None): + # type: (...) 
-> beam_provision_api_pb2.GetProvisionInfoResponse + # if request comes from task worker that can be found, return the modified + # provision info + if context: + worker_id = dict(context.invocation_metadata())['worker_id'] + provision_info = self.provision_by_worker_id.get(worker_id, + self._provision_info) + else: + # fallback to the generic sdk worker version of provision info if not + # found from a cached task worker + provision_info = self._provision_info + + return beam_provision_api_pb2.GetProvisionInfoResponse(info=provision_info) + + +class BundleProcessorTaskWorker(object): + """ + The remote task worker that communicates with the SDK worker to do the + actual work of processing bundles. + + The BundleProcessor will detect inputs and see if there is TaskableValue, and + if there is and the BundleProcessor is set to "use task worker", then + BundleProcessorTaskHelper will create a TaskWorkerHandler that this class + communicates with. + + This class creates a BundleProcessor and receives TaskInstructionRequests and + sends back respective responses via the grpc channels connected to the control + endpoint; + """ + + REQUEST_PREFIX = '_request_' + _lock = threading.Lock() + + def __init__(self, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Initialize a BundleProcessorTaskWorker. Lives remotely. + + It will create a BundleProcessor with the provide information and process + the requests using the BundleProcessor created. + + Args: + worker_id: the worker id of current task worker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self.worker_id = worker_id + self._credentials = credentials + self._responses = queue.Queue() + self._alive = None # type: Optional[bool] + self._bundle_processor = None # type: Optional[BundleProcessor] + self._exc_info = None + self.stub = self._create_stub(server_url) + + @classmethod + def execute(cls, worker_id, server_url, credentials=None): + # type: (str, str, Optional[str]) -> None + """Instantiate a BundleProcessorTaskWorker and start running. + + If there's error, it will be raised here so it can be reflected to user. 
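+
+    A remote process started by a ``TaskWorkerHandler`` would typically call
+    this directly, e.g. (sketch; the ``TASK_WORKER_*`` env vars are the ones
+    exported by the kubernetes handler, not something this method reads
+    itself)::
+
+        BundleProcessorTaskWorker.execute(
+            os.environ['TASK_WORKER_ID'],
+            os.environ['TASK_WORKER_CONTROL_ADDRESS'],
+            credentials=os.environ.get('TASK_WORKER_CREDENTIALS'))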
+ + Args: + worker_id: worker id for the BundleProcessorTaskWorker + server_url: control service url for the TaskGrpcServer + credentials: credentials to use when creating client + """ + self = cls(worker_id, server_url, credentials=credentials) + self.run() + + # raise the error here, so user knows there's a failure and could retry + if self._exc_info: + t, v, tb = self._exc_info + raise_(t, v, tb) + + def _create_stub(self, server_url): + # type: (str) -> beam_task_worker_pb2_grpc.TaskControlStub + """Create the TaskControl client.""" + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + if self._credentials is None: + channel = GRPCChannelFactory.insecure_channel( + server_url, + options=channel_options) + else: + channel = GRPCChannelFactory.secure_channel(server_url, + self._credentials, + options=channel_options) + + # add instruction_id to grpc channel + channel = grpc.intercept_channel( + channel, + WorkerIdInterceptor(self.worker_id)) + + return beam_task_worker_pb2_grpc.TaskControlStub(channel) + + def do_instruction(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Process the requests with the corresponding method.""" + request_type = request.WhichOneof('request') + if request_type: + return getattr(self, self.REQUEST_PREFIX + request_type)( + getattr(request, request_type)) + else: + raise NotImplementedError + + def run(self): + # type: () -> None + """Start the full running life cycle for a task worker. + + It send TaskWorkerInstructionResponse to TaskWorkerHandler, and wait for + TaskWorkerInstructionRequest. This service is bidirectional. + """ + no_more_work = object() + self._alive = True + + def get_responses(): + while True: + response = self._responses.get() + if response is no_more_work: + return + if response: + yield response + + try: + for request in self.stub.Control(get_responses()): + self._responses.put(self.do_instruction(request)) + finally: + self._alive = False + + self._responses.put(no_more_work) + logging.info('Done consuming work.') + + def _request_create(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Create a BundleProcessor based on the request. + + Should be the first request received from the handler. + """ + from apache_beam.runners.worker.sdk_worker import \ + GrpcStateHandlerFactory + + credentials = None + if request.data_factory.credentials._credentials: + credentials = grpc.ChannelCredentials( + request.data_factory.credentials._credentials) + logging.debug('Credentials: {!r}'.format(credentials)) + + worker_id = request.data_factory.worker_id + transmitter_url = request.data_factory.transmitter_url + state_handler_endpoint = request.state_handler_endpoint + # FIXME: Add support for Caching later + state_factory = GrpcStateHandlerFactory(StateCache(0), credentials) + state_handler = state_factory.create_state_handler( + state_handler_endpoint) + data_channel_factory = ProxyGrpcClientDataChannelFactory( + transmitter_url, credentials, worker_id + ) + + self._bundle_processor = BundleProcessor( + request.process_bundle_descriptor, + state_handler, + data_channel_factory + ) + return beam_task_worker_pb2.TaskInstructionResponse( + create=beam_task_worker_pb2.CreateResponse(), + instruction_id=self.worker_id + ) + + def _request_process_bundle(self, + request # type: beam_task_worker_pb2.TaskInstructionRequest + ): + # type: (...) 
-> beam_task_worker_pb2.TaskInstructionResponse + """Process bundle using the bundle processor based on the request.""" + error = None + + try: + # FIXME: Update this to use the cache_tokens properly + with self._bundle_processor.state_handler._underlying.process_instruction_id( + self.worker_id): + delayed_applications, require_finalization = \ + self._bundle_processor.process_bundle(self.worker_id, + use_task_worker=False) + except Exception as e: + # we want to propagate the error back to the TaskWorkerHandler, so that + # it will raise `TaskWorkerProcessBundleError` which allows for requeue + # behavior (up until MAX_TASK_WORKER_RETRY number of retries) + error = e.message + self._exc_info = sys.exc_info() + delayed_applications = [] + require_finalization = False + + return beam_task_worker_pb2.TaskInstructionResponse( + process_bundle=beam_task_worker_pb2.ProcessorProcessBundleResponse( + delayed_applications=delayed_applications, + require_finalization=require_finalization), + instruction_id=self.worker_id, + error=error + ) + + def _request_shutdown(self, request): + # type: (beam_task_worker_pb2.TaskInstructionRequest) -> beam_task_worker_pb2.TaskInstructionResponse + """Shutdown the bundleprocessor.""" + error = None + try: + # shut down state handler here because it is not created by the state + # handler factory thus won't be closed automatically + self._bundle_processor.state_handler.done() + self._bundle_processor.shutdown() + except Exception as e: + error = e.message + finally: + return beam_task_worker_pb2.TaskInstructionResponse( + shutdown=beam_task_worker_pb2.ShutdownResponse(), + instruction_id=self.worker_id, + error=error + ) + + +class BundleProcessorTaskHelper(object): + """ + A helper object that is used by a BundleProcessor while processing bundle. + + Delegates TaskableValues to TaskWorkers, if enabled. + + It can process TaskableValue using TaskWorkers if inspected, and kept the + default behavior if specified to not use TaskWorker, or there's no + TaskableValue found in this bundle. + + To utilize TaskWorkers, BundleProcessorTaskHelper will split up the input + bundle into tasks based on the wrapped TaskableValue's payload, and create a + TaskWorkerHandler for each task. + """ + + def __init__(self, instruction_id, wrapped_values): + # type: (str, DefaultDict[str, List[Tuple[Any, bytes]]]) -> None + """Initialize a BundleProcessorTaskHelper object. + + Args: + instruction_id: the instruction_id of the bundle that the + BundleProcessor is processing + wrapped_values: The mapping of transform id to raw and encoded data. + """ + self.instruction_id = instruction_id + self.wrapped_values = wrapped_values + + def split_taskable_values(self): + # type: () -> Tuple[DefaultDict[str, List[Any]], DefaultDict[str, List[beam_fn_api_pb2.Elements.Data]]] + """Split TaskableValues into tasks and pair it with worker. + + Also put the raw bytes along with worker id for data dispatching by data + plane handler. 
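+
+    For example, a bundle carrying two taskable elements is split roughly as
+    (sketch)::
+
+        splitted   == {'worker_0': [elem_0], 'worker_1': [elem_1]}
+        data_store == {'worker_0': [Elements.Data(transform_id=...,
+                                                  data=raw_0,
+                                                  instruction_id='worker_0')],
+                       'worker_1': [...]}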
+ """ + # TODO: Come up with solution on how this can be dynamically changed + # could use window + splitted = collections.defaultdict(list) + data_store = collections.defaultdict(list) + worker_count = 0 + for ptransform_id, values in self.wrapped_values.items(): + for decoded, raw in values: + worker_id = 'worker_{}'.format(worker_count) + splitted[worker_id].append(decoded) + data_store[worker_id].append(beam_fn_api_pb2.Elements.Data( + transform_id=ptransform_id, + data=raw, + instruction_id=worker_id + )) + worker_count += 1 + + return splitted, data_store + + def _start_task_grpc_server( + self, + max_workers, # type: int + data_store, # type: Mapping[str, List[beam_fn_api_pb2.Elements.Data]] + state_handler, # type: CachingStateHandler + data_channel_factory, # type: DataChannelFactory + provision_info, # type: beam_provision_api_pb2.ProvisionInfo + ): + # type:(...) -> TaskGrpcServer + """Start up TaskGrpcServer. + + Args: + max_workers: number of max worker + data_store: stored data of worker id and the raw and decoded values for + the worker to process as inputs + state_handler: state handler of current BundleProcessor + data_channel_factory: data channel factory of current BundleProcessor + """ + return TaskGrpcServer(state_handler, max_workers, data_store, + data_channel_factory, self.instruction_id, + provision_info) + + @staticmethod + def get_default_task_env(process_bundle_descriptor): + # type:(beam_fn_api_pb2.ProcessBundleDescriptor) -> Optional[Environment] + """Get the current running beam Environment class. + + Used as the default for the task worker. + + Args: + process_bundle_descriptor: the ProcessBundleDescriptor proto + """ + from apache_beam.runners.portability.fn_api_runner.translations import \ + PAR_DO_URNS + from apache_beam.transforms.environments import Environment + + # find a ParDo xform in this stage + pardo = None + for _, xform in process_bundle_descriptor.transforms.items(): + if xform.spec.urn in PAR_DO_URNS: + pardo = xform + break + + if pardo is None: + # don't set the default task env if no ParDo is found + # FIXME: Use the pipeline default env here? + return None + + env_proto = process_bundle_descriptor.environments.get( + pardo.environment_id) + + return Environment.from_runner_api(env_proto, None) + + def get_sdk_worker_provision_info(self, server_url): + # type:(str) -> beam_provision_api_pb2.ProvisionInfo + channel_options = [("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)] + channel = GRPCChannelFactory.insecure_channel(server_url, + options=channel_options) + + worker_id = os.environ['WORKER_ID'] + # add sdk worker id to grpc channel + channel = grpc.intercept_channel(channel, WorkerIdInterceptor(worker_id)) + + provision_stub = beam_provision_api_pb2_grpc.ProvisionServiceStub(channel) + response = provision_stub.GetProvisionInfo( + beam_provision_api_pb2.GetProvisionInfoRequest()) + channel.close() + return response.info + + def process_bundle_with_task_workers( + self, + state_handler, # type: CachingStateHandler + data_channel_factory, # type: DataChannelFactory + process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + """Main entry point for task worker system. + + Starts up a group of TaskWorkerHandlers, dispatches tasks and waits for them + to finish. + + Fails if any TaskWorker exceeds maximum retries. 
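+
+    Elements only reach this code path when they were wrapped into
+    ``TaskableValue``s upstream, normally via a ``BeamTask`` transform, e.g.
+    (sketch; ``MyExpensiveFn`` is a placeholder DoFn)::
+
+        pcoll | BeamTask(beam.ParDo(MyExpensiveFn()))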
+ + Args: + state_handler: state handler of current BundleProcessor + data_channel_factory: data channel factory of current BundleProcessor + process_bundle_descriptor: a description of the stage that this + ``BundleProcessor``is to execute. + """ + from google.protobuf import text_format + + default_env = self.get_default_task_env(process_bundle_descriptor) + # start up grpc server + splitted_elements, data_store = self.split_taskable_values() + num_task_workers = len(splitted_elements.items()) + if num_task_workers > MAX_TASK_WORKERS: + logging.warning( + 'Number of element exceeded MAX_TASK_WORKERS ({})'.format( + MAX_TASK_WORKERS)) + num_task_workers = MAX_TASK_WORKERS + + # get sdk worker provision info first + try: + provision_port = TaskWorkerHandler.get_port_from_env_var( + 'PROVISION_API_SERVICE_DESCRIPTOR') + except KeyError: + # if we are in subprocess environment then there won't be any provision + # service so use default provison info + provision_info = beam_provision_api_pb2.ProvisionInfo() + + else: + provision_info = self.get_sdk_worker_provision_info('{}:{}'.format( + TaskWorkerHandler.host_from_worker(), provision_port)) + + server = self._start_task_grpc_server(num_task_workers, data_store, + state_handler, data_channel_factory, + provision_info) + # modify provision api service descriptor to use the new address that we are + # gonna be using (the control address) + os.environ['PROVISION_API_SERVICE_DESCRIPTOR'] = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url=server.control_address)) + + # create TaskWorkerHandlers + task_worker_handlers = [] + for worker_id, elem in splitted_elements.items(): + taskable_value = get_taskable_value(elem) + # set the env to default env if there is + if taskable_value.env is None and default_env: + taskable_value.env = default_env + + task_worker_handler = TaskWorkerHandler.create( + state_handler, provision_info, server, taskable_value, + credentials=data_channel_factory._credentials, worker_id=worker_id) + task_worker_handlers.append(task_worker_handler) + task_worker_handler.start_worker() + + def _execute(handler): + """ + This is the method that runs in the thread pool representing a working + TaskHandler. 
+ """ + worker_data_channel_factory = ProxyGrpcClientDataChannelFactory( + server.data_address, + credentials=data_channel_factory._credentials, + worker_id=task_worker_handler.worker_id) + counter = 0 + while True: + try: + counter += 1 + return handler.execute(worker_data_channel_factory, + process_bundle_descriptor) + except TaskWorkerProcessBundleError as e: + if counter >= MAX_TASK_WORKER_RETRY: + logging.error('Task Worker has exceeded max retries!') + handler.stop_worker() + raise + # retry if task worker failed to process bundle + handler.reset() + continue + except TaskWorkerTerminatedError as e: + # This error is thrown only when TaskWorkerHandler is terminated + # before it finished processing + logging.warning('TaskWorker terminated prematurely.') + raise + + # start actual processing of splitted bundle + merged_delayed_applications = [] + bundle_require_finalization = False + with futures.ThreadPoolExecutor(max_workers=num_task_workers) as executor: + try: + for delayed_applications, require_finalization in executor.map( + _execute, task_worker_handlers): + + if delayed_applications is None: + raise RuntimeError('Task Worker failed to process task.') + merged_delayed_applications.extend(delayed_applications) + # if any elem requires finalization, set it to True + if not bundle_require_finalization and require_finalization: + bundle_require_finalization = True + except TaskWorkerProcessBundleError: + raise RuntimeError('Task Worker failed to process task.') + except TaskWorkerTerminatedError: + # This error is thrown only when TaskWorkerHandler is terminated before + # it finished processing, this is only possible if + # `TaskWorkerHandler.alive` is manually set to False by custom user + # defined monitoring function, which user would trigger if they want + # to terminate the task; + # In that case, we want to continue on and not hold the whole bundle + # by the tasks that user manually terminated. + pass + + return merged_delayed_applications, bundle_require_finalization + + +@TaskWorkerHandler.register_urn('local') +class LocalTaskWorkerHandler(TaskWorkerHandler): + """TaskWorkerHandler that starts up task worker locally.""" + + # FIXME: create a class-level thread pool to restrict the number of threads + + def start_remote(self): + # type: () -> None + """start a task worker local to the task worker handler.""" + obj = BundleProcessorTaskWorker(self.worker_id, self.control_address, + self.credentials) + run_thread = threading.Thread(target=obj.run) + run_thread.daemon = True + run_thread.start() + + +TaskWorkerHandler.load_plugins() + + +def get_taskable_value(decoded_value): + # type: (Any) -> Optional[TaskableValue] + """Check whether the given value contains taskable value. 
+ + If taskable, return the TaskableValue + + Args: + decoded_value: decoded value from raw input stream + """ + # FIXME: Come up with a solution that's not so specific + if isinstance(decoded_value, (list, tuple, set)): + for val in decoded_value: + result = get_taskable_value(val) + if result: + return result + elif isinstance(decoded_value, TaskableValue): + return decoded_value + return None diff --git a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py index 6cbb2ea405a1..5d8604c9a66f 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_main.py @@ -24,8 +24,8 @@ import traceback from apache_beam.runners.worker.sdk_worker_main import _load_main_session -from apache_beam.runners.worker.task_worker import BundleProcessorTaskWorker -from apache_beam.runners.worker.task_worker import TaskWorkerHandler +from apache_beam.runners.worker.task_worker.handlers import BundleProcessorTaskWorker +from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler # This module is experimental. No backwards-compatibility guarantees. diff --git a/sdks/python/build-requirements.txt b/sdks/python/build-requirements.txt index 90d766abcdbe..1ecd6ec10a50 100644 --- a/sdks/python/build-requirements.txt +++ b/sdks/python/build-requirements.txt @@ -17,4 +17,3 @@ grpcio-tools==1.30.0 future==0.18.2 mypy-protobuf==1.18 -entrypoints==0.3 diff --git a/sdks/python/setup.py b/sdks/python/setup.py index b792ddff28ca..e48dcaefc812 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -141,6 +141,7 @@ def get_version(): # server, therefore list of allowed versions is very narrow. # See: https://github.com/uqfoundation/dill/issues/341. 
'dill>=0.3.1.1,<0.3.2', + 'entrypoints>=0.3', 'fastavro>=0.21.4,<2', 'funcsigs>=1.0.2,<2; python_version < "3.0"', 'future>=0.18.2,<1.0.0', @@ -333,6 +334,9 @@ def run(self): entry_points={ 'nose.plugins.0.10': [ 'beam_test_plugin = test_config:BeamTestPlugin', + ], + 'apache_beam_task_workers_plugins': [ + 'k8s_task_worker = apache_beam.task.task_worker.k8s.handler:KubeTaskWorkerHandler', ]}, cmdclass={ 'build_py': generate_protos_first(build_py), From 1cbbd798ca0b9b4724a777469656331bc1f8a36b Mon Sep 17 00:00:00 2001 From: Sam Bourne Date: Mon, 19 Oct 2020 17:05:36 -0700 Subject: [PATCH 10/15] Add k8s KubeTask taskworker --- .../apache_beam/examples/taskworker/k8s.py | 79 +++++++ sdks/python/apache_beam/task/__init__.py | 18 ++ .../apache_beam/task/task_worker/__init__.py | 18 ++ .../task/task_worker/k8s/__init__.py | 21 ++ .../task/task_worker/k8s/handler.py | 210 ++++++++++++++++++ .../task/task_worker/k8s/transforms.py | 51 +++++ 6 files changed, 397 insertions(+) create mode 100644 sdks/python/apache_beam/examples/taskworker/k8s.py create mode 100644 sdks/python/apache_beam/task/__init__.py create mode 100644 sdks/python/apache_beam/task/task_worker/__init__.py create mode 100644 sdks/python/apache_beam/task/task_worker/k8s/__init__.py create mode 100644 sdks/python/apache_beam/task/task_worker/k8s/handler.py create mode 100644 sdks/python/apache_beam/task/task_worker/k8s/transforms.py diff --git a/sdks/python/apache_beam/examples/taskworker/k8s.py b/sdks/python/apache_beam/examples/taskworker/k8s.py new file mode 100644 index 000000000000..bdf9082595be --- /dev/null +++ b/sdks/python/apache_beam/examples/taskworker/k8s.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import logging +from typing import TYPE_CHECKING + +import apache_beam as beam +from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions + +from apache_beam.task.task_worker.k8s.transforms import KubeTask + +if TYPE_CHECKING: + from typing import Any + from apache_beam.task.task_worker.k8s.handler import KubeTaskProperties + + +def process(element): + # type: (int) -> int + import time + start = time.time() + print('processing...') + # represents an expensive process of some sort + time.sleep(element) + print('processing took {:0.6f} s'.format(time.time() - start)) + return element + + +def _per_element_wrapper(element, payload): + # type: (Any, KubeTaskProperties) -> KubeTaskProperties + """ + Callback to modify the kubernetes job properties per-element. 
+ """ + import copy + result = copy.copy(payload) + result.name += '-{}'.format(element) + return result + + +def run(argv=None, save_main_session=True): + """ + Run a pipeline that submits two kubernetes jobs, each that simply sleep + """ + parser = argparse.ArgumentParser() + known_args, pipeline_args = parser.parse_known_args(argv) + + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' + + with beam.Pipeline(options=pipeline_options) as pipe: + ( + pipe + | beam.Create([20, 42]) + | 'kubetask' >> KubeTask( + beam.Map(process), + wrapper=_per_element_wrapper) + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/task/__init__.py b/sdks/python/apache_beam/task/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/task/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/__init__.py b/sdks/python/apache_beam/task/task_worker/__init__.py new file mode 100644 index 000000000000..6569e3fe5de4 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/k8s/__init__.py b/sdks/python/apache_beam/task/task_worker/k8s/__init__.py new file mode 100644 index 000000000000..49168a7b1fe4 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/k8s/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +""" + +from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/k8s/handler.py b/sdks/python/apache_beam/task/task_worker/k8s/handler.py new file mode 100644 index 000000000000..6f0887818763 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/k8s/handler.py @@ -0,0 +1,210 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Kubernetes TaskWorker. +""" +from __future__ import absolute_import + +import threading +import time +from typing import TYPE_CHECKING + +from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler + +try: + from kubernetes.client import BatchV1Api + from kubernetes.client import V1EnvVar + from kubernetes.client import V1Container + from kubernetes.client import V1PodTemplateSpec + from kubernetes.client import V1PodSpec + from kubernetes.client import V1ObjectMeta + from kubernetes.client import V1Job + from kubernetes.client import V1JobSpec + from kubernetes.client import V1DeleteOptions + from kubernetes.client.rest import ApiException + from kubernetes.config import load_kube_config +except ImportError: + BatchV1Api = None + V1EnvVar = None + V1Container = None + V1PodTemplateSpec = None + V1PodSpec = None + V1ObjectMeta = None + V1Job = None + V1JobSpec = None + V1DeleteOptions = None + ApiException = None + load_kube_config = None + +if TYPE_CHECKING: + from typing import List + +__all__ = [ + 'KubeTaskProperties', + 'KubeTaskWorkerHandler', +] + + +class KubeTaskProperties(object): + """ + Object for describing kubernetes job properties for a task worker. + """ + + def __init__( + self, + name, # type: str + namespace='default', + # FIXME: Use PortableOptions.environment_config? + container='apache/beam_python3.8_sdk:2.26.0.dev', + command=('python', '-m', + 'apache_beam.runners.worker.task_worker.task_worker_main') + ): + # type: (...) -> None + self.name = name + self.namespace = namespace + self.container = container + self.command = command + + +class KubeJobManager(object): + """ + Monitors jobs submitted by the KubeTaskWorkerHandler. 
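+
+  Usage (sketch)::
+
+      manager = KubeJobManager()
+      manager.watch(handler)  # the first call starts the polling daemon thread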
+ """ + + _thread = None # type: threading.Thread + _lock = threading.Lock() + + def __init__(self): + self._handlers = [] # type: List[KubeTaskWorkerHandler] + + def is_started(self): + """ + Return if the manager is currently running or not. + """ + return self._thread is not None + + def _start(self): + self._thread = threading.Thread(target=self._run) + self._thread.daemon = True + self._thread.start() + + def _run(self, interval=5.0): + while True: + for handler in self._handlers: + # If the handler thinks it's alive but it's not actually, change its + # alive state. + if handler.alive and not handler.is_alive(): + handler.alive = False + time.sleep(interval) + + def watch(self, handler): + # type: (KubeTaskWorkerHandler) -> None + """ + Monitor the passed handler checking periodically that the job is still + running. + """ + if not self.is_started(): + self._start() + self._handlers.append(handler) + + +@TaskWorkerHandler.register_urn('k8s') +class KubeTaskWorkerHandler(TaskWorkerHandler): + """ + The kubernetes task handler. + """ + + _lock = threading.Lock() + _monitor = None # type: KubeJobManager + + api = None # type: BatchV1Api + + @property + def monitor(self): + # type: () -> KubeJobManager + if KubeTaskWorkerHandler._monitor is None: + with KubeJobManager._lock: + KubeTaskWorkerHandler._monitor = KubeJobManager() + return KubeTaskWorkerHandler._monitor + + def is_alive(self): + try: + self.api.read_namespaced_job_status( + self.task_payload.name, self.task_payload.namespace) + except ApiException: + return False + return True + + def create_job(self): + + env = [ + V1EnvVar(name='TASK_WORKER_ID', value=self.worker_id), + V1EnvVar(name='TASK_WORKER_CONTROL_ADDRESS', value=self.control_address), + ] + if self.credentials: + env.extend([ + V1EnvVar(name='TASK_WORKER_CREDENTIALS', value=self.credentials), + ]) + + # Configure Pod template container + container = V1Container( + name=self.task_payload.name, + image=self.task_payload.container, + command=self.task_payload.command, + env=env) + # Create and configure a spec section + template = V1PodTemplateSpec( + metadata=V1ObjectMeta( + labels={'app': self.task_payload.name}), + spec=V1PodSpec(restart_policy='Never', containers=[container])) + # Create the specification of deployment + spec = V1JobSpec( + template=template, + backoff_limit=4) + # Instantiate the job object + job = V1Job( + api_version='batch/v1', + kind='Job', + metadata=V1ObjectMeta(name=self.task_payload.name), + spec=spec) + + return job + + def submit_job(self, job): + # type: (V1Job) -> str + api_response = self.api.create_namespaced_job( + body=job, + namespace=self.task_payload.namespace) + return api_response.metadata.uid + + def delete_job(self): + return self.api.delete_namespaced_job( + name=self.task_payload.name, + namespace=self.task_payload.namespace, + body=V1DeleteOptions( + propagation_policy='Foreground', + grace_period_seconds=5)) + + def start_remote(self): + # type: () -> None + with KubeTaskWorkerHandler._lock: + load_kube_config() + self.api = BatchV1Api() + job = self.create_job() + self.submit_job(job) + self.monitor.watch(self) diff --git a/sdks/python/apache_beam/task/task_worker/k8s/transforms.py b/sdks/python/apache_beam/task/task_worker/k8s/transforms.py new file mode 100644 index 000000000000..b753912330e3 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/k8s/transforms.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Transforms for kubernetes task workers. +""" + +from __future__ import absolute_import + +import re + +from apache_beam.runners.worker.task_worker.core import BeamTask + + +class KubeTask(BeamTask): + """ + Kubernetes task worker transform. + """ + + urn = 'k8s' + + K8S_NAME_RE = re.compile(r'[a-z][a-zA-Z0-9-.]*') + + def getPayload(self): + """ + Get a task payload for configuring a kubernetes job. + """ + from apache_beam.task.task_worker.k8s.handler import KubeTaskProperties + + name = self.label + name = name[0].lower() + name[1:] + # NOTE: k8s jobs have very restricted names + assert self.K8S_NAME_RE.match(name), 'Kubernetes job name must start ' \ + 'with a lowercase letter and use ' \ + 'only - or . special characters' + # FIXME: How to make this name auto-unique? Default wrapper? + return KubeTaskProperties(name=name) From 64a7285d97bea24bfd505d13ac7189a92fd3d8f5 Mon Sep 17 00:00:00 2001 From: Sam Bourne Date: Mon, 19 Oct 2020 18:30:22 -0700 Subject: [PATCH 11/15] Minor fixes and cleanup --- .../runners/worker/task_worker/core.py | 16 ++--- .../runners/worker/task_worker/example.py | 62 ------------------- .../runners/worker/task_worker/handlers.py | 11 ++-- .../task/task_worker/k8s/__init__.py | 3 - .../task/task_worker/k8s/handler.py | 13 +++- .../task/task_worker/k8s/transforms.py | 2 +- 6 files changed, 26 insertions(+), 81 deletions(-) delete mode 100644 sdks/python/apache_beam/runners/worker/task_worker/example.py diff --git a/sdks/python/apache_beam/runners/worker/task_worker/core.py b/sdks/python/apache_beam/runners/worker/task_worker/core.py index 2b23d9996679..d9129c452c88 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/core.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/core.py @@ -112,7 +112,7 @@ class BeamTask(beam.PTransform): By default, we use the local task worker, but subclass could specify the type of task worker to use by specifying the ``urn``, and override the - ``getPayload`` method to return meaningful payloads to that type of task + ``get_payload`` method to return meaningful payloads to that type of task worker. """ @@ -123,13 +123,13 @@ class BeamTask(beam.PTransform): # the sdk harness entry point SDK_HARNESS_ENTRY_POINT = TASK_WORKER_SDK_ENTRYPOINT - def __init__(self, fusedXform, wrapper=None, env=None): + def __init__(self, transform, wrapper=None, env=None): # type: (beam.PTransform, Optional[Callable[[Any, Any], Any]], Optional[beam.transforms.environments.Environment]) -> None self._wrapper = wrapper self._env = env - self._fusedXform = fusedXform + self._transform = transform - def getPayload(self): + def get_payload(self): # type: () -> Optional[Any] """ Subclass should implement this to generate payload for TaskableValue. 
@@ -138,7 +138,7 @@ def getPayload(self): return None @staticmethod - def _hasTaggedOutputs(xform): + def _has_tagged_outputs(xform): # type: (beam.PTransform) -> bool """Checks to see if we have tagged output for the given PTransform.""" if isinstance(xform, beam.core._MultiParDo): @@ -150,16 +150,16 @@ def _hasTaggedOutputs(xform): def expand(self, pcoll): # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection - payload = self.getPayload() + payload = self.get_payload() result = ( pcoll | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, env=self._env, payload=payload) | 'StartStage' >> beam.Reshuffle() | 'UnWrap' >> beam.ParDo(UnWrapFn()) - | self._fusedXform + | self._transform ) - if self._hasTaggedOutputs(self._fusedXform): + if self._has_tagged_outputs(self._transform): # for xforms that ended up with tagged outputs, we don't want to # add reshuffle, because it will be a stage split point already, # also adding reshuffle would error since we now have a tuple of diff --git a/sdks/python/apache_beam/runners/worker/task_worker/example.py b/sdks/python/apache_beam/runners/worker/task_worker/example.py deleted file mode 100644 index 99a9c2986b8e..000000000000 --- a/sdks/python/apache_beam/runners/worker/task_worker/example.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Basic graph to test using TaskWorker. -""" -# pytype: skip-file - -from __future__ import absolute_import - -import argparse -import logging - -import apache_beam as beam -from apache_beam.options.pipeline_options import DirectOptions -from apache_beam.options.pipeline_options import PipelineOptions -from apache_beam.options.pipeline_options import SetupOptions - -from apache_beam.runners.worker.task_worker.core import BeamTask - - -class TestFn(beam.DoFn): - - def process(self, element, side): - from apache_beam.runners.worker.task_worker.handlers import TaskableValue - - for s in side: - if isinstance(element, TaskableValue): - value = element.value - else: - value = element - print(value + s) - yield value + s - - -def run(argv=None, save_main_session=True): - """Main entry point; defines and runs the test pipeline.""" - parser = argparse.ArgumentParser() - known_args, pipeline_args = parser.parse_known_args(argv) - - # We use the save_main_session option because one or more DoFn's in this - # workflow rely on global context (e.g., a module imported at module level). - pipeline_options = PipelineOptions(pipeline_args) - pipeline_options.view_as(SetupOptions).save_main_session = save_main_session - pipeline_options.view_as(DirectOptions).direct_running_mode = 'multi_processing' - - # The pipeline will be run on exiting the with block. 
- with beam.Pipeline(options=pipeline_options) as pipe: - - A = ( - pipe - | 'A' >> beam.Create(range(3)) - ) - - B = ( - pipe - | beam.Create(range(2)) - | BeamTask(beam.ParDo(TestFn(), beam.pvalue.AsList(A)), - wrapper=lambda x, _: x) - ) - - -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - run() diff --git a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py index d6c7918547d9..c0cc5f677fcc 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py @@ -54,7 +54,6 @@ from apache_beam.runners.portability.fn_api_runner.worker_handlers import BasicProvisionService from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlConnection from apache_beam.runners.portability.fn_api_runner.worker_handlers import ControlFuture -from apache_beam.runners.portability.fn_api_runner import FnApiRunner from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcWorkerHandler from apache_beam.runners.portability.fn_api_runner.worker_handlers import GrpcStateServicer from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler @@ -87,8 +86,8 @@ class TaskableValue(object): """ Value that can be distributed to TaskWorkers as tasks. - Has the original value, and TaskProperties that specifies how the task will be - generated. + Has the original value, and TaskProperties that specifies how the task will + be generated. """ def __init__( @@ -338,8 +337,7 @@ def reset(self): class TaskGrpcServer(object): - """ - A collection of grpc servicers that handle communication between a + """A collection of grpc servicers that handle communication between a ``TaskWorker`` and ``TaskWorkerHandler``. Contains three servers: @@ -547,7 +545,8 @@ class ProxyGrpcClientDataChannelFactory(DataChannelFactory): """A factory for ``ProxyGrpcClientDataChannel``. No caching behavior here because we are starting each data channel on - different location.""" + different location. + """ def __init__(self, transmitter_url, credentials=None, worker_id=None): # type: (str, Optional[str], Optional[str]) -> None diff --git a/sdks/python/apache_beam/task/task_worker/k8s/__init__.py b/sdks/python/apache_beam/task/task_worker/k8s/__init__.py index 49168a7b1fe4..6569e3fe5de4 100644 --- a/sdks/python/apache_beam/task/task_worker/k8s/__init__.py +++ b/sdks/python/apache_beam/task/task_worker/k8s/__init__.py @@ -15,7 +15,4 @@ # limitations under the License. # -""" -""" - from __future__ import absolute_import diff --git a/sdks/python/apache_beam/task/task_worker/k8s/handler.py b/sdks/python/apache_beam/task/task_worker/k8s/handler.py index 6f0887818763..55307c488176 100644 --- a/sdks/python/apache_beam/task/task_worker/k8s/handler.py +++ b/sdks/python/apache_beam/task/task_worker/k8s/handler.py @@ -16,8 +16,9 @@ # """ -Kubernetes TaskWorker. +Kubernetes task worker implementation. """ + from __future__ import absolute_import import threading @@ -151,6 +152,10 @@ def is_alive(self): return True def create_job(self): + # type: () -> V1Job + """ + Create a kubernetes job object. + """ env = [ V1EnvVar(name='TASK_WORKER_ID', value=self.worker_id), @@ -187,12 +192,18 @@ def create_job(self): def submit_job(self, job): # type: (V1Job) -> str + """ + Submit a kubernetes job. 
+ """ api_response = self.api.create_namespaced_job( body=job, namespace=self.task_payload.namespace) return api_response.metadata.uid def delete_job(self): + """ + Delete the kubernetes job. + """ return self.api.delete_namespaced_job( name=self.task_payload.name, namespace=self.task_payload.namespace, diff --git a/sdks/python/apache_beam/task/task_worker/k8s/transforms.py b/sdks/python/apache_beam/task/task_worker/k8s/transforms.py index b753912330e3..873783bae2ed 100644 --- a/sdks/python/apache_beam/task/task_worker/k8s/transforms.py +++ b/sdks/python/apache_beam/task/task_worker/k8s/transforms.py @@ -35,7 +35,7 @@ class KubeTask(BeamTask): K8S_NAME_RE = re.compile(r'[a-z][a-zA-Z0-9-.]*') - def getPayload(self): + def get_payload(self): """ Get a task payload for configuring a kubernetes job. """ From e7f688dfb991f12fc750922ea91e34f9cd6d5379 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Tue, 20 Oct 2020 16:44:57 -0700 Subject: [PATCH 12/15] Fix unittest because of task worker module refactor --- .../worker/task_worker/task_worker_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py index 0b61aeb11e1d..ca13f9e1aba1 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py @@ -35,14 +35,14 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import beam_task_worker_pb2 from apache_beam.portability.api import endpoints_pb2 -from apache_beam.runners.worker import task_worker +from apache_beam.runners.worker.task_worker import handlers from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory from apache_beam.runners.worker.sdk_worker import CachingStateHandler from apache_beam.transforms import window # -- utilities for testing, mocking up test objects -class _MockBundleProcessorTaskWorker(task_worker.BundleProcessorTaskWorker): +class _MockBundleProcessorTaskWorker(handlers.BundleProcessorTaskWorker): """ A mocked version of BundleProcessorTaskWorker, responsible for recording the requests it received, and provide response to each request type by a passed @@ -65,8 +65,8 @@ def do_instruction(self, request): return self.responsesByRequestType.get(request_type) -@task_worker.TaskWorkerHandler.register_urn('unittest') -class _MockTaskWorkerHandler(task_worker.TaskWorkerHandler): +@handlers.TaskWorkerHandler.register_urn('unittest') +class _MockTaskWorkerHandler(handlers.TaskWorkerHandler): """ Register a mocked version of task handler only used for "unittest"; will start a ``_MockBundleProcessorTaskWorker`` for each discovered TaskableValue. @@ -112,7 +112,7 @@ def start_worker(self): return self.start_remote() -class _MockTaskGrpcServer(task_worker.TaskGrpcServer): +class _MockTaskGrpcServer(handlers.TaskGrpcServer): """ Mocked version of TaskGrpcServer, using mocked version of data channel factory and cache handler. 
@@ -363,9 +363,9 @@ def test_process_normally_without_task_worker(self): # so it will be recorded in the `decoded` list self.assertEquals(mocked_op.decoded, test_elems) - @mock.patch.object(task_worker, 'MAX_TASK_WORKER_RETRY', 2) + @mock.patch.object(handlers, 'MAX_TASK_WORKER_RETRY', 2) @mock.patch('__main__._MockTaskWorkerHandler.execute', - side_effect=task_worker.TaskWorkerProcessBundleError('test')) + side_effect=handlers.TaskWorkerProcessBundleError('test')) @mock.patch('__main__._MockTaskWorkerHandler.start_worker') @mock.patch('__main__._MockTaskWorkerHandler.stop_worker') def test_exceed_max_retries(self, unused_mock_stop, unused_mock_start, @@ -379,7 +379,7 @@ def test_exceed_max_retries(self, unused_mock_stop, unused_mock_start, test_coder = self._get_test_pickle_coder() test_elems = self._prep_elements( - [task_worker.TaskableValue(i, 'unittest') for i in range(2)] + [handlers.TaskableValue(i, 'unittest') for i in range(2)] ) instruction_id = 'test_instruction_3' transform_id = 'test_transform_1' From 109eb034dcc05a430ddfcd06c212cd679d301ae0 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Wed, 21 Oct 2020 13:57:45 -0700 Subject: [PATCH 13/15] Cleanup, update unittest --- .../runners/worker/bundle_processor.py | 7 +- .../runners/worker/task_worker/handlers.py | 7 +- .../worker/task_worker/task_worker_test.py | 154 +++--------------- 3 files changed, 28 insertions(+), 140 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 568307c0bc41..074ba67046d5 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -1020,9 +1020,8 @@ def maybe_process_remotely(self, # type: (...) -> Union[Tuple[None, None], Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]] """Process the current bundle remotely with task workers, if applicable. - Processes remotely if ``wrapped_values`` is not None (meaning there are - TaskableValue detected from input of this bundle) and task worker is - allowed to be used. + Processes remotely if there are TaskableValues detected from the input of this bundle + and task workers are allowed to be used. """ from apache_beam.runners.worker.task_worker.handlers import BundleProcessorTaskHelper from apache_beam.runners.worker.task_worker.handlers import get_taskable_value @@ -1073,7 +1072,7 @@ def maybe_process_remotely(self, # fallback to process it as normal, trigger receivers to process with input_op.splitting_lock: if input_op.index == input_op.stop - 1: - return + return None, None input_op.index += 1 input_op.output(decoded_value) diff --git a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py index c0cc5f677fcc..bf320d5652e1 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py @@ -86,8 +86,9 @@ class TaskableValue(object): """ Value that can be distributed to TaskWorkers as tasks. - Has the original value, and TaskProperties that specifies how the task will - be generated. + Has the original value, urn, env and payload that specifies how the task will + be generated. urn is used for getting the correct type of TaskWorkerHandler to + handle processing the current value. 
""" def __init__( @@ -165,7 +166,7 @@ def __init__( if worker_id: self.worker_id = worker_id - self.control_address = self.port_from_worker(self._grpc_server.control_port) + self.control_address = self._grpc_server.control_address self.provision_address = self.control_address self.logging_address = self.port_from_worker( self.get_port_from_env_var('LOGGING_API_SERVICE_DESCRIPTOR')) diff --git a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py index ca13f9e1aba1..8df68a6171e8 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/task_worker_test.py @@ -20,25 +20,21 @@ from __future__ import division from __future__ import print_function +import os import logging -import mock import threading import unittest -from builtins import range -from collections import defaultdict -import grpc -from future.utils import raise_ +from google.protobuf import text_format -from apache_beam.coders import coders, coder_impl from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_provision_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import beam_task_worker_pb2 from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.worker.task_worker import handlers from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory from apache_beam.runners.worker.sdk_worker import CachingStateHandler -from apache_beam.transforms import window # -- utilities for testing, mocking up test objects @@ -86,6 +82,7 @@ def __init__(self, worker_id=None, # type: Optional[str] responseByRequestType=None ): + provision_info = beam_provision_api_pb2.ProvisionInfo() super(_MockTaskWorkerHandler, self).__init__(state, provision_info, grpc_server, environment, task_payload, @@ -121,11 +118,16 @@ class _MockTaskGrpcServer(handlers.TaskGrpcServer): def __init__(self, instruction_id, max_workers=1, data_store=None): dummy_state_handler = _MockCachingStateHandler(None, None) dummy_data_channel_factory = GrpcClientDataChannelFactory() - + provision_info = beam_provision_api_pb2.ProvisionInfo() super(_MockTaskGrpcServer, self).__init__(dummy_state_handler, max_workers, data_store or {}, dummy_data_channel_factory, - instruction_id) + instruction_id, provision_info) + self.control_address = 'localhost:{}'.format(self.control_port) + control_descriptor = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url=self.control_address)) + print(control_descriptor) + os.environ['CONTROL_API_SERVICE_DESCRIPTOR'] = control_descriptor class _MockCachingStateHandler(CachingStateHandler): @@ -138,8 +140,7 @@ def __init__(self, underlying_state, global_state_cache): self._underlying = underlying_state self._state_cache = global_state_cache self._context = threading.local() - - self._context.cache_token = '' + self._context.bundle_cache_token = '' class _MockDataInputOperation(object): @@ -195,6 +196,12 @@ def prep_bundle_processor_descriptor(bundle_id): class TaskWorkerHandlerTest(unittest.TestCase): + def setUp(self): + # put endpoints environment variable in for testing + logging_descriptor = text_format.MessageToString( + endpoints_pb2.ApiServiceDescriptor(url='localhost:10000')) + os.environ['LOGGING_API_SERVICE_DESCRIPTOR'] = logging_descriptor + @staticmethod def _get_task_worker_handler(worker_id, resp_by_type, 
instruction_id, max_workers=1, data_store=None): @@ -218,7 +225,7 @@ def test_execute_success(self): test_handler = self._get_task_worker_handler(worker_id, resp_by_type, instruction_id) - proxy_data_channel_factory = task_worker.ProxyGrpcClientDataChannelFactory( + proxy_data_channel_factory = handlers.ProxyGrpcClientDataChannelFactory( test_handler._grpc_server.data_address ) @@ -267,14 +274,14 @@ def test_execute_failure(self): test_handler = self._get_task_worker_handler(worker_id, resp_by_type, instruction_id) - proxy_data_channel_factory = task_worker.ProxyGrpcClientDataChannelFactory( + proxy_data_channel_factory = handlers.ProxyGrpcClientDataChannelFactory( test_handler._grpc_server.data_address ) test_handler.start_worker() try: - with self.assertRaises(task_worker.TaskWorkerProcessBundleError): + with self.assertRaises(handlers.TaskWorkerProcessBundleError): print(test_handler.execute) test_handler.execute( proxy_data_channel_factory, @@ -285,125 +292,6 @@ def test_execute_failure(self): test_handler._grpc_server.close() -class BundleProcessorTaskHelperTest(unittest.TestCase): - - @staticmethod - def _get_test_int_coder(): - return coders.WindowedValueCoder(coders.VarIntCoder(), - coders.GlobalWindowCoder()) - - @staticmethod - def _get_test_pickle_coder(): - return coders.WindowedValueCoder(coders.FastPrimitivesCoder(), - coders.GlobalWindowCoder()) - - @staticmethod - def _prep_elements(elements): - return [window.GlobalWindows.windowed_value(elem) for elem in elements] - - @staticmethod - def _prep_encoded_data(coder, elements, instruction_id, transform_id): - temp_out = coder_impl.create_OutputStream() - raw_bytes = [] - - for elem in elements: - encoded = coder.encode(elem) - raw_bytes.append(encoded) - coder.get_impl().encode_to_stream(elem, temp_out, True) - - data = beam_fn_api_pb2.Elements.Data( - instruction_id=instruction_id, - transform_id=transform_id, - data=temp_out.get() - ) - return raw_bytes, data - - def test_data_split_with_task_worker(self): - """ - Test that input data is split correctly by BundleProcessorTaskHelper. - """ - test_coder = self._get_test_pickle_coder() - mocked_op = _MockDataInputOperation(test_coder) - - test_elems = self._prep_elements( - [task_worker.TaskableValue(i, 'unittest') for i in range(5)]) - instruction_id = 'test_instruction_1' - transform_id = 'test_transform_1' - - test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) - raw_bytes, data = self._prep_encoded_data(test_coder, test_elems, - instruction_id, transform_id) - - test_task_helper.process_encoded(mocked_op, data) - self.assertEquals(mocked_op.decoded, []) - expected_wrapped_values = defaultdict(list) - for decode, raw in zip(test_elems, raw_bytes): - expected_wrapped_values[transform_id].append((decode.value, raw)) - self.assertItemsEqual(test_task_helper.wrapped_values, expected_wrapped_values) - - def test_process_normally_without_task_worker(self): - """ - Test that when input data doesn't consists of TaskableValue, it is processed - not using task worker but normally via DataInputOperation's process. 
- """ - test_coder = self._get_test_int_coder() - mocked_op = _MockDataInputOperation(test_coder) - - test_elems = self._prep_elements(range(5)) - instruction_id = 'test_instruction_2' - transform_id = 'test_transform_1' - - test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) - _, data = self._prep_encoded_data(test_coder, test_elems, instruction_id, - transform_id) - - test_task_helper.process_encoded(mocked_op, data) - - # when processed normally, it will use the DataInputOperation to process, - # so it will be recorded in the `decoded` list - self.assertEquals(mocked_op.decoded, test_elems) - - @mock.patch.object(handlers, 'MAX_TASK_WORKER_RETRY', 2) - @mock.patch('__main__._MockTaskWorkerHandler.execute', - side_effect=handlers.TaskWorkerProcessBundleError('test')) - @mock.patch('__main__._MockTaskWorkerHandler.start_worker') - @mock.patch('__main__._MockTaskWorkerHandler.stop_worker') - def test_exceed_max_retries(self, unused_mock_stop, unused_mock_start, - mock_execute): - """ - Test the scenario when task worker fails exceed max retries. - """ - test_coder = self._get_test_pickle_coder() - mocked_op = _MockDataInputOperation(test_coder) - - test_coder = self._get_test_pickle_coder() - - test_elems = self._prep_elements( - [handlers.TaskableValue(i, 'unittest') for i in range(2)] - ) - instruction_id = 'test_instruction_3' - transform_id = 'test_transform_1' - - test_task_helper = task_worker.BundleProcessorTaskHelper(instruction_id) - _, data = self._prep_encoded_data(test_coder, test_elems, instruction_id, - transform_id) - test_task_helper.process_encoded(mocked_op, data) - - dummy_process_bundle_descriptor = prep_bundle_processor_descriptor(1) - dummy_data_channel_factory = GrpcClientDataChannelFactory() - dummy_state_handler = _MockCachingStateHandler(None, None) - - with self.assertRaises(RuntimeError): - test_task_helper.process_bundle_with_task_workers( - dummy_state_handler, - dummy_data_channel_factory, - dummy_process_bundle_descriptor - ) - - # num(elems) * MAX_TASK_WORKER_RETRY - self.assertEquals(mock_execute.call_count, 4) - - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) unittest.main() From 470d4a4525207de5b6bed974113811b1e00d8c05 Mon Sep 17 00:00:00 2001 From: Sam Bourne Date: Wed, 21 Oct 2020 16:39:12 -0700 Subject: [PATCH 14/15] Update kubernetes task worker so the k8s job object is part of the KubeTask payload; misc cleanup and organization --- .../apache_beam/examples/taskworker/k8s.py | 19 ++- .../runners/worker/task_worker/core.py | 12 +- .../task/task_worker/k8s/transforms.py | 51 ------ .../task_worker/{k8s => kubejob}/__init__.py | 0 .../task_worker/{k8s => kubejob}/handler.py | 154 ++++++++---------- .../task/task_worker/kubejob/transforms.py | 107 ++++++++++++ .../apache_beam/transforms/environments.py | 29 ++-- sdks/python/setup.py | 5 +- 8 files changed, 211 insertions(+), 166 deletions(-) delete mode 100644 sdks/python/apache_beam/task/task_worker/k8s/transforms.py rename sdks/python/apache_beam/task/task_worker/{k8s => kubejob}/__init__.py (100%) rename sdks/python/apache_beam/task/task_worker/{k8s => kubejob}/handler.py (50%) create mode 100644 sdks/python/apache_beam/task/task_worker/kubejob/transforms.py diff --git a/sdks/python/apache_beam/examples/taskworker/k8s.py b/sdks/python/apache_beam/examples/taskworker/k8s.py index bdf9082595be..749c1915a7e2 100644 --- a/sdks/python/apache_beam/examples/taskworker/k8s.py +++ b/sdks/python/apache_beam/examples/taskworker/k8s.py @@ -15,6 +15,12 @@ # limitations 
under the License. # +""" +Sample pipeline showcasing how to use the kubernetes task worker. +""" + +from __future__ import absolute_import + import argparse import logging from typing import TYPE_CHECKING @@ -23,12 +29,11 @@ from apache_beam.options.pipeline_options import DirectOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions - -from apache_beam.task.task_worker.k8s.transforms import KubeTask +from apache_beam.task.task_worker.kubejob.transforms import KubeTask if TYPE_CHECKING: from typing import Any - from apache_beam.task.task_worker.k8s.handler import KubeTaskProperties + from apache_beam.task.task_worker.kubejob.handler import KubePayload def process(element): @@ -43,14 +48,12 @@ def process(element): def _per_element_wrapper(element, payload): - # type: (Any, KubeTaskProperties) -> KubeTaskProperties + # type: (Any, KubePayload) -> KubePayload """ Callback to modify the kubernetes job properties per-element. """ - import copy - result = copy.copy(payload) - result.name += '-{}'.format(element) - return result + payload.job_name += '-{}'.format(element) + return payload def run(argv=None, save_main_session=True): diff --git a/sdks/python/apache_beam/runners/worker/task_worker/core.py b/sdks/python/apache_beam/runners/worker/task_worker/core.py index d9129c452c88..f5346e23ff8c 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/core.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/core.py @@ -19,6 +19,7 @@ TaskableValue for beam task workers. """ +import copy from typing import TYPE_CHECKING import apache_beam as beam @@ -26,6 +27,7 @@ if TYPE_CHECKING: from typing import Any, Optional, Callable, Iterator + from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.transforms.environments import Environment @@ -60,7 +62,7 @@ def process( # override payload if given a wrapper function, which will vary per # element if wrapper: - payload = wrapper(element, payload) + payload = wrapper(element, copy.deepcopy(payload)) if payload: result = TaskableValue(element, urn, env=env, payload=payload) @@ -129,8 +131,8 @@ def __init__(self, transform, wrapper=None, env=None): self._env = env self._transform = transform - def get_payload(self): - # type: () -> Optional[Any] + def get_payload(self, options): + # type: (PipelineOptions) -> Optional[Any] """ Subclass should implement this to generate payload for TaskableValue. Default to None. @@ -150,11 +152,11 @@ def _has_tagged_outputs(xform): def expand(self, pcoll): # type: (beam.pvalue.PCollection) -> beam.pvalue.PCollection - payload = self.get_payload() + payload = self.get_payload(pcoll.pipeline.options) result = ( pcoll | 'Wrap' >> beam.ParDo(WrapFn(), urn=self.urn, wrapper=self._wrapper, - env=self._env, payload=payload) + env=self._env, payload=payload) | 'StartStage' >> beam.Reshuffle() | 'UnWrap' >> beam.ParDo(UnWrapFn()) | self._transform diff --git a/sdks/python/apache_beam/task/task_worker/k8s/transforms.py b/sdks/python/apache_beam/task/task_worker/k8s/transforms.py deleted file mode 100644 index 873783bae2ed..000000000000 --- a/sdks/python/apache_beam/task/task_worker/k8s/transforms.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Transforms for kubernetes task workers. -""" - -from __future__ import absolute_import - -import re - -from apache_beam.runners.worker.task_worker.core import BeamTask - - -class KubeTask(BeamTask): - """ - Kubernetes task worker transform. - """ - - urn = 'k8s' - - K8S_NAME_RE = re.compile(r'[a-z][a-zA-Z0-9-.]*') - - def get_payload(self): - """ - Get a task payload for configuring a kubernetes job. - """ - from apache_beam.task.task_worker.k8s.handler import KubeTaskProperties - - name = self.label - name = name[0].lower() + name[1:] - # NOTE: k8s jobs have very restricted names - assert self.K8S_NAME_RE.match(name), 'Kubernetes job name must start ' \ - 'with a lowercase letter and use ' \ - 'only - or . special characters' - # FIXME: How to make this name auto-unique? Default wrapper? - return KubeTaskProperties(name=name) diff --git a/sdks/python/apache_beam/task/task_worker/k8s/__init__.py b/sdks/python/apache_beam/task/task_worker/kubejob/__init__.py similarity index 100% rename from sdks/python/apache_beam/task/task_worker/k8s/__init__.py rename to sdks/python/apache_beam/task/task_worker/kubejob/__init__.py diff --git a/sdks/python/apache_beam/task/task_worker/k8s/handler.py b/sdks/python/apache_beam/task/task_worker/kubejob/handler.py similarity index 50% rename from sdks/python/apache_beam/task/task_worker/k8s/handler.py rename to sdks/python/apache_beam/task/task_worker/kubejob/handler.py index 55307c488176..a14716b2908e 100644 --- a/sdks/python/apache_beam/task/task_worker/k8s/handler.py +++ b/sdks/python/apache_beam/task/task_worker/kubejob/handler.py @@ -21,70 +21,60 @@ from __future__ import absolute_import +import copy import threading import time from typing import TYPE_CHECKING +from typing import NamedTuple from apache_beam.runners.worker.task_worker.handlers import TaskWorkerHandler +# This module will be imported by the task worker handler plugin system +# regardless of whether the kubernetes API is installed. It must be safe to +# import whether it will be used or not. 
try: - from kubernetes.client import BatchV1Api - from kubernetes.client import V1EnvVar - from kubernetes.client import V1Container - from kubernetes.client import V1PodTemplateSpec - from kubernetes.client import V1PodSpec - from kubernetes.client import V1ObjectMeta - from kubernetes.client import V1Job - from kubernetes.client import V1JobSpec - from kubernetes.client import V1DeleteOptions + import kubernetes.client as client + import kubernetes.config as config from kubernetes.client.rest import ApiException - from kubernetes.config import load_kube_config except ImportError: - BatchV1Api = None - V1EnvVar = None - V1Container = None - V1PodTemplateSpec = None - V1PodSpec = None - V1ObjectMeta = None - V1Job = None - V1JobSpec = None - V1DeleteOptions = None + client = None + config = None ApiException = None - load_kube_config = None if TYPE_CHECKING: from typing import List __all__ = [ - 'KubeTaskProperties', 'KubeTaskWorkerHandler', ] -class KubeTaskProperties(object): +class KubePayload(object): """ - Object for describing kubernetes job properties for a task worker. + Object for holding attributes for a kubernetes job. """ - def __init__( - self, - name, # type: str - namespace='default', - # FIXME: Use PortableOptions.environment_config? - container='apache/beam_python3.8_sdk:2.26.0.dev', - command=('python', '-m', - 'apache_beam.runners.worker.task_worker.task_worker_main') - ): - # type: (...) -> None - self.name = name + def __init__(self, job, namespace='default'): + # type: (client.V1Job, str) -> None + self.job = job self.namespace = namespace - self.container = container - self.command = command + + @property + def job_name(self): + return self.job.metadata.name + + @job_name.setter + def job_name(self, value): + self.job.metadata.name = value + self.job.spec.template.metadata.labels['app'] = value + for container in self.job.spec.template.spec.containers: + container.name = value class KubeJobManager(object): """ - Monitors jobs submitted by the KubeTaskWorkerHandler. + Monitors jobs submitted by the KubeTaskWorkerHandler. Responsible for + notifying `watch`ed handlers if their Kubernetes job is deleted. """ _thread = None # type: threading.Thread @@ -109,7 +99,7 @@ def _run(self, interval=5.0): for handler in self._handlers: # If the handler thinks it's alive but it's not actually, change its # alive state. - if handler.alive and not handler.is_alive(): + if handler.alive and not handler.job_exists(): handler.alive = False time.sleep(interval) @@ -127,13 +117,13 @@ def watch(self, handler): @TaskWorkerHandler.register_urn('k8s') class KubeTaskWorkerHandler(TaskWorkerHandler): """ - The kubernetes task handler. + The Kubernetes task handler. """ _lock = threading.Lock() _monitor = None # type: KubeJobManager - api = None # type: BatchV1Api + api = None # type: client.BatchV1Api @property def monitor(self): @@ -143,79 +133,63 @@ def monitor(self): KubeTaskWorkerHandler._monitor = KubeJobManager() return KubeTaskWorkerHandler._monitor - def is_alive(self): + def job_exists(self): + # type: () -> bool + """ + Return whether or not the Kubernetes job exists. + """ try: self.api.read_namespaced_job_status( - self.task_payload.name, self.task_payload.namespace) + self.task_payload.job.metadata.name, self.task_payload.namespace) except ApiException: return False return True - def create_job(self): - # type: () -> V1Job + def submit_job(self, payload): + # type: (KubePayload) -> client.V1Job """ - Create a kubernetes job object. + Submit a Kubernetes job. 
""" + # Patch some handler specific env variables into the job + job = copy.deepcopy(payload.job) # type: client.V1Job env = [ - V1EnvVar(name='TASK_WORKER_ID', value=self.worker_id), - V1EnvVar(name='TASK_WORKER_CONTROL_ADDRESS', value=self.control_address), + client.V1EnvVar(name='TASK_WORKER_ID', value=self.worker_id), + client.V1EnvVar(name='TASK_WORKER_CONTROL_ADDRESS', + value=self.control_address), ] if self.credentials: env.extend([ - V1EnvVar(name='TASK_WORKER_CREDENTIALS', value=self.credentials), + client.V1EnvVar(name='TASK_WORKER_CREDENTIALS', + value=self.credentials), ]) - # Configure Pod template container - container = V1Container( - name=self.task_payload.name, - image=self.task_payload.container, - command=self.task_payload.command, - env=env) - # Create and configure a spec section - template = V1PodTemplateSpec( - metadata=V1ObjectMeta( - labels={'app': self.task_payload.name}), - spec=V1PodSpec(restart_policy='Never', containers=[container])) - # Create the specification of deployment - spec = V1JobSpec( - template=template, - backoff_limit=4) - # Instantiate the job object - job = V1Job( - api_version='batch/v1', - kind='Job', - metadata=V1ObjectMeta(name=self.task_payload.name), - spec=spec) - - return job - - def submit_job(self, job): - # type: (V1Job) -> str - """ - Submit a kubernetes job. - """ - api_response = self.api.create_namespaced_job( - body=job, - namespace=self.task_payload.namespace) - return api_response.metadata.uid + for container in job.spec.template.spec.containers: + if container.env is None: + container.env = [] + container.env.extend(env) + + return self.api.create_namespaced_job( + body=job, + namespace=payload.namespace) def delete_job(self): + # type: () -> client.V1Status """ Delete the kubernetes job. """ return self.api.delete_namespaced_job( - name=self.task_payload.name, - namespace=self.task_payload.namespace, - body=V1DeleteOptions( - propagation_policy='Foreground', - grace_period_seconds=5)) + name=self.task_payload.job.metadata.name, + namespace=self.task_payload.namespace, + body=client.V1DeleteOptions( + propagation_policy='Foreground', + grace_period_seconds=5)) def start_remote(self): # type: () -> None with KubeTaskWorkerHandler._lock: - load_kube_config() - self.api = BatchV1Api() - job = self.create_job() - self.submit_job(job) + config.load_kube_config() + self.api = client.BatchV1Api() + + self.submit_job(self.task_payload) self.monitor.watch(self) diff --git a/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py b/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py new file mode 100644 index 000000000000..17a59a5c4566 --- /dev/null +++ b/sdks/python/apache_beam/task/task_worker/kubejob/transforms.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +Transforms for kubernetes task workers. +""" + +from __future__ import absolute_import + +import re +from typing import TYPE_CHECKING + +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms.environments import DockerEnvironment +from apache_beam.runners.worker.task_worker.core import BeamTask + +try: + import kubernetes.client +except ImportError: + raise ImportError( + 'Kubernetes task worker is not supported by this environment ' + '(could not import kubernetes API).') + +if TYPE_CHECKING: + from apache_beam.task.task_worker.kubejob.handler import KubePayload + +__all__ = [ + 'KubeJobOptions', + 'KubeTask', +] + + +class KubeJobOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--namespace', + type=str, + default='default', + help='The namespace to submit kubernetes task worker jobs to.') + + +class KubeTask(BeamTask): + """ + Kubernetes task worker transform. + """ + + urn = 'k8s' + + K8S_JOB_NAME_RE = re.compile(r'[a-z][a-zA-Z0-9-.]*') + + def get_payload(self, options): + # type: (PipelineOptions) -> KubePayload + """ + Get a task payload for configuring a kubernetes job. + """ + from apache_beam.task.task_worker.kubejob.handler import KubePayload + + name = self.label + name = name[0].lower() + name[1:] + if not self.K8S_JOB_NAME_RE.match(name): + raise ValueError( + 'Kubernetes job name must start with a lowercase letter and use ' + 'only - or . special characters') + + image = DockerEnvironment.get_container_image_from_options(options) + + container = kubernetes.client.V1Container( + name=name, + image=image, + command=['python', '-m', self.SDK_HARNESS_ENTRY_POINT], + env=[]) + + template = kubernetes.client.V1PodTemplateSpec( + metadata=kubernetes.client.V1ObjectMeta( + labels={'app': name}), + spec=kubernetes.client.V1PodSpec( + restart_policy='Never', + containers=[container])) + + spec = kubernetes.client.V1JobSpec( + template=template, + backoff_limit=4) + + job = kubernetes.client.V1Job( + api_version='batch/v1', + kind='Job', + metadata=kubernetes.client.V1ObjectMeta(name=name), + spec=spec) + + return KubePayload( + job, + namespace=options.view_as(KubeJobOptions).namespace) diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py index 9d48e892371d..5cf522bd9a99 100644 --- a/sdks/python/apache_beam/transforms/environments.py +++ b/sdks/python/apache_beam/transforms/environments.py @@ -45,6 +45,7 @@ from google.protobuf import message from apache_beam import coders +from apache_beam.options.pipeline_options import PortableOptions from apache_beam.options.pipeline_options import SetupOptions from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -55,7 +56,7 @@ from apache_beam.utils import proto_utils if TYPE_CHECKING: - from apache_beam.options.pipeline_options import PortableOptions + from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners.pipeline_context import PipelineContext __all__ = [ @@ -276,17 +277,25 @@ def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.DockerPayloa artifacts=artifacts) @classmethod - def from_options(cls, options): - # type: (PortableOptions) -> DockerEnvironment + def get_container_image_from_options(cls, options): + # type: (PipelineOptions) -> str if options.view_as(SetupOptions).prebuild_sdk_container_engine: - prebuilt_container_image = SdkContainerImageBuilder.build_container_image( - 
options) - return cls.from_container_image( - container_image=prebuilt_container_image, - artifacts=python_sdk_dependencies(options)) + return SdkContainerImageBuilder.build_container_image(options) + portable_options = options.view_as(PortableOptions) + result = portable_options.lookup_environment_option( + 'docker_container_image') + if not result: + result = portable_options.environment_config + if not result: + result = cls.default_docker_image() + return result + + @classmethod + def from_options(cls, options): + # type: (PipelineOptions) -> DockerEnvironment + container_image = cls.get_container_image_from_options(options) return cls.from_container_image( - container_image=options.lookup_environment_option( - 'docker_container_image') or options.environment_config, + container_image=container_image, artifacts=python_sdk_dependencies(options)) @classmethod diff --git a/sdks/python/setup.py b/sdks/python/setup.py index e48dcaefc812..c00f1603b7d7 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -313,7 +313,8 @@ def run(self): 'interactive': INTERACTIVE_BEAM, 'interactive_test': INTERACTIVE_BEAM_TEST, 'aws': AWS_REQUIREMENTS, - 'azure': AZURE_REQUIREMENTS + 'azure': AZURE_REQUIREMENTS, + 'taskworker_k8s': ['kubernetes>=11.0.0'], }, zip_safe=False, # PyPI package information. @@ -336,7 +337,7 @@ def run(self): 'beam_test_plugin = test_config:BeamTestPlugin', ], 'apache_beam_task_workers_plugins': [ - 'k8s_task_worker = apache_beam.task.task_worker.k8s.handler:KubeTaskWorkerHandler', + 'k8s_task_worker = apache_beam.task.task_worker.kubejob.handler:KubeTaskWorkerHandler', ]}, cmdclass={ 'build_py': generate_protos_first(build_py), From 884ffd332a342760bb50c2b64c135dabd300d4f7 Mon Sep 17 00:00:00 2001 From: Viola Lyu Date: Thu, 3 Dec 2020 00:23:18 +0800 Subject: [PATCH 15/15] Moved `use_task_worker` as init arg for `BundleProcessor` --- .../runners/worker/bundle_processor.py | 15 ++++++++------- .../runners/worker/task_worker/handlers.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 074ba67046d5..6c321642f3ab 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -823,7 +823,8 @@ class BundleProcessor(object): def __init__(self, process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory # type: data_plane.DataChannelFactory + data_channel_factory, # type: data_plane.DataChannelFactory + use_task_worker=True # type: bool ): # type: (...) -> None @@ -834,10 +835,12 @@ def __init__(self, a description of the stage that this ``BundleProcessor``is to execute. state_handler (CachingStateHandler). data_channel_factory (``data_plane.DataChannelFactory``). + use_task_worker : whether to engage the task worker system when processing """ self.process_bundle_descriptor = process_bundle_descriptor self.state_handler = state_handler self.data_channel_factory = data_channel_factory + self.use_task_worker = use_task_worker # There is no guarantee that the runner only set # timer_api_service_descriptor when having timers. 
So this field cannot be @@ -928,8 +931,8 @@ def reset(self): for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id, use_task_worker=True): - # type: (str, bool) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle(self, instruction_id): + # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] expected_input_ops = [] # type: List[DataInputOperation] @@ -982,8 +985,7 @@ def process_bundle(self, instruction_id, use_task_worker=True): # Process data and timer inputs delayed_applications, requires_finalization = \ self.maybe_process_remotely(data_channels, instruction_id, - input_op_by_transform_id, - use_task_worker=use_task_worker) + input_op_by_transform_id) # Finish all operations. for op in self.ops.values(): @@ -1015,7 +1017,6 @@ def maybe_process_remotely(self, data_channels, # type: DefaultDict[DataChannel, list[str]] instruction_id, # type: str input_op_by_transform_id, # type: Dict[str, DataInputOperation] - use_task_worker=True # type: bool ): # type: (...) -> Union[Tuple[None, None], Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]] """Process the current bundle remotely with task workers, if applicable. @@ -1043,7 +1044,7 @@ def maybe_process_remotely(self, input_op = input_op_by_transform_id[data.transform_id] # process normally if not using task worker - if use_task_worker is False: + if self.use_task_worker is False: input_op.process_encoded(data.data) continue diff --git a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py index bf320d5652e1..57a62ea542a6 100644 --- a/sdks/python/apache_beam/runners/worker/task_worker/handlers.py +++ b/sdks/python/apache_beam/runners/worker/task_worker/handlers.py @@ -912,7 +912,8 @@ def _request_create(self, request): self._bundle_processor = BundleProcessor( request.process_bundle_descriptor, state_handler, - data_channel_factory + data_channel_factory, + use_task_worker=False ) return beam_task_worker_pb2.TaskInstructionResponse( create=beam_task_worker_pb2.CreateResponse(), @@ -931,8 +932,7 @@ def _request_process_bundle(self, with self._bundle_processor.state_handler._underlying.process_instruction_id( self.worker_id): delayed_applications, require_finalization = \ - self._bundle_processor.process_bundle(self.worker_id, - use_task_worker=False) + self._bundle_processor.process_bundle(self.worker_id) except Exception as e: # we want to propagate the error back to the TaskWorkerHandler, so that # it will raise `TaskWorkerProcessBundleError` which allows for requeue