Add blocked_handlers to servers (#2556)
This lets users opt out of specific handlers, which is particularly useful for addressing security concerns.
cicdw authored and mrocklin committed Mar 11, 2019
1 parent ae18f65 commit 8e843cd
Showing 7 changed files with 81 additions and 5 deletions.
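
For orientation before the file diffs, a short, hedged usage sketch (not part of the commit itself) of what the new option enables from the client side; the handler name and the LocalCluster/Client calls mirror the tests added below.

# Hedged usage sketch: block the scheduler's 'run_function' handler so clients
# cannot execute arbitrary functions on the scheduler process.
from distributed import Client, LocalCluster

with LocalCluster(blocked_handlers=['run_function']) as cluster:
    with Client(cluster) as client:
        try:
            client.run_on_scheduler(lambda x: x, 42)
        except ValueError as err:
            # e.g. "The 'run_function' handler has been explicitly disallowed in Scheduler, ..."
            print(err)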
19 changes: 17 additions & 2 deletions distributed/core.py
@@ -43,6 +43,12 @@ def get_total_physical_memory():
     return 2e9


+def raise_later(exc):
+    def _raise(*args, **kwargs):
+        raise exc
+    return _raise
+
+
 MAX_BUFFER_SIZE = get_total_physical_memory()

 tick_maximum_delay = parse_timedelta(dask.config.get('distributed.admin.tick.limit'), default='ms')
@@ -89,13 +95,16 @@ class Server(object):
     default_ip = ''
     default_port = 0

-    def __init__(self, handlers, stream_handlers=None, connection_limit=512,
+    def __init__(self, handlers, blocked_handlers=None, stream_handlers=None, connection_limit=512,
                  deserialize=True, io_loop=None):
         self.handlers = {
             'identity': self.identity,
             'connection_stream': self.handle_stream,
         }
         self.handlers.update(handlers)
+        if blocked_handlers is None:
+            blocked_handlers = dask.config.get('distributed.%s.blocked-handlers' % type(self).__name__.lower(), [])
+        self.blocked_handlers = blocked_handlers
         self.stream_handlers = {}
         self.stream_handlers.update(stream_handlers or {})

@@ -330,7 +339,13 @@ def handle_comm(self, comm, shutting_down=shutting_down):

                 result = None
                 try:
-                    handler = self.handlers[op]
+                    if op in self.blocked_handlers:
+                        _msg = ("The '{op}' handler has been explicitly disallowed "
+                                "in {obj}, possibly due to security concerns.")
+                        exc = ValueError(_msg.format(op=op, obj=type(self).__name__))
+                        handler = raise_later(exc)
+                    else:
+                        handler = self.handlers[op]
                 except KeyError:
                     logger.warning("No handler %s found in %s", op,
                                    type(self).__name__, exc_info=True)
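Note on the dispatch change above: a blocked op is not rejected up front; it is swapped for a stand-in handler built by raise_later, so the error only fires when the handler is invoked and therefore travels through handle_comm's existing exception-serialization path back to the caller, just like any other handler failure. A minimal, self-contained illustration of that pattern (not the library code itself):

# Standalone sketch of the deferred-raise pattern; names mirror the hunk above.
def raise_later(exc):
    def _raise(*args, **kwargs):
        raise exc
    return _raise

handlers = {'ping': lambda: 'pong'}
blocked_handlers = ['ping']
op = 'ping'

if op in blocked_handlers:
    handler = raise_later(ValueError("The '%s' handler has been explicitly disallowed" % op))
else:
    handler = handlers[op]

try:
    handler()  # the deferred exception fires only here, at call time
except ValueError as exc:
    print('would be serialized back to the caller as:', exc)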
8 changes: 6 additions & 2 deletions distributed/deploy/local.py
@@ -54,6 +54,9 @@ class LocalCluster(Cluster):
         Tornado gen.coroutines. This should remain False for normal use.
     kwargs: dict
         Extra worker arguments, will be passed to the Worker constructor.
+    blocked_handlers: List[str]
+        A list of strings specifying a blacklist of handlers to disallow on the Scheduler,
+        like ``['feed', 'run_function']``
     service_kwargs: Dict[str, Dict]
         Extra keywords to hand to the running services
     security : Security
@@ -82,7 +85,7 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True,
                  loop=None, start=None, ip=None, scheduler_port=0,
                  silence_logs=logging.WARN, diagnostics_port=8787,
                  services=None, worker_services=None, service_kwargs=None,
-                 asynchronous=False, security=None, **worker_kwargs):
+                 asynchronous=False, security=None, blocked_handlers=None, **worker_kwargs):
         if start is not None:
             msg = ("The start= parameter is deprecated. "
                    "LocalCluster always starts. "
@@ -133,7 +136,8 @@ def __init__(self, n_workers=None, threads_per_worker=None, processes=True,

         self.scheduler = Scheduler(loop=self.loop,
                                    services=services,
-                                   security=security)
+                                   security=security,
+                                   blocked_handlers=blocked_handlers)
         self.scheduler_port = scheduler_port

         self.workers = []
9 changes: 9 additions & 0 deletions distributed/deploy/tests/test_local.py
@@ -41,6 +41,15 @@ def test_simple(loop):
             assert e.loop is c.loop


+def test_local_cluster_supports_blocked_handlers(loop):
+    with LocalCluster(blocked_handlers=['run_function'], loop=loop) as c:
+        with Client(c) as client:
+            with pytest.raises(ValueError) as exc:
+                client.run_on_scheduler(lambda x: x, 42)
+
+    assert "'run_function' handler has been explicitly disallowed in Scheduler" in str(exc.value)
+
+
 @pytest.mark.skipif('sys.version_info[0] == 2', reason='fork issues')
 def test_close_twice():
     with LocalCluster() as cluster:
2 changes: 2 additions & 0 deletions distributed/distributed.yaml
@@ -11,6 +11,7 @@ distributed:
   scheduler:
     allowed-failures: 3  # number of retries before a task is considered bad
     bandwidth: 100000000  # 100 MB/s estimated worker-worker bandwidth
+    blocked-handlers: []
     default-data-size: 1000
     # Number of seconds to wait until workers or clients are removed from the events log
     # after they have been removed from the scheduler
@@ -22,6 +23,7 @@
     preload-argv: []

   worker:
+    blocked-handlers: []
     multiprocessing-method: forkserver
     use-file-locking: True
     connections:  # Maximum concurrent connections for data
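Both new keys default to an empty list. Because Server.__init__ falls back to dask.config.get('distributed.<servername>.blocked-handlers') when no explicit list is passed, the feature can also be driven entirely from configuration. A hedged sketch (the handler names here are illustrative, not mandated by the commit):

# Configure blocked handlers through dask's config system instead of a
# constructor argument; the keys mirror the distributed.yaml entries above.
import dask
from distributed import Scheduler

with dask.config.set({'distributed.scheduler.blocked-handlers': ['feed', 'run_function']}):
    s = Scheduler()  # picks the list up from config in Server.__init__
    assert s.blocked_handlers == ['feed', 'run_function']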
3 changes: 2 additions & 1 deletion distributed/node.py
@@ -31,7 +31,7 @@ class ServerNode(Node, Server):
     # XXX avoid inheriting from Server? there is some large potential for confusion
     # between base and derived attribute namespaces...

-    def __init__(self, handlers=None, stream_handlers=None,
+    def __init__(self, handlers=None, blocked_handlers=None, stream_handlers=None,
                  connection_limit=512, deserialize=True,
                  connection_args=None, io_loop=None, serializers=None,
                  deserializers=None):
@@ -42,6 +42,7 @@ def __init__(self, handlers=None, stream_handlers=None,
                       serializers=serializers,
                       deserializers=deserializers)
         Server.__init__(self, handlers=handlers,
+                        blocked_handlers=blocked_handlers,
                         stream_handlers=stream_handlers,
                         connection_limit=connection_limit,
                         deserialize=deserialize, io_loop=self.io_loop)
20 changes: 20 additions & 0 deletions distributed/tests/test_core.py
@@ -96,6 +96,26 @@ def f():
     loop.run_sync(f)


+def test_server_raises_on_blocked_handlers(loop):
+    @gen.coroutine
+    def f():
+        server = Server({'ping': pingpong}, blocked_handlers=['ping'])
+        server.listen(8881)
+
+        comm = yield connect(server.address)
+        yield comm.write({'op': 'ping'})
+        msg = yield comm.read()
+
+        assert 'exception' in msg
+        assert isinstance(msg['exception'], ValueError)
+        assert "'ping' handler has been explicitly disallowed" in repr(msg['exception'])
+
+        comm.close()
+        server.stop()
+
+    res = loop.run_sync(f)
+
+
 class MyServer(Server):
     default_port = 8756

25 changes: 25 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -252,6 +252,31 @@ def test_add_worker(s, a, b):
     yield w._close()


+@gen_cluster(scheduler_kwargs={'blocked_handlers': ['feed']})
+def test_blocked_handlers_are_respected(s, a, b):
+    def func(scheduler):
+        return dumps(dict(scheduler.worker_info))
+
+    comm = yield connect(s.address)
+    yield comm.write({'op': 'feed',
+                      'function': dumps(func),
+                      'interval': 0.01})
+
+    response = yield comm.read()
+
+    assert 'exception' in response
+    assert isinstance(response['exception'], ValueError)
+    assert "'feed' handler has been explicitly disallowed" in repr(response['exception'])
+
+    yield comm.close()
+
+
+def test_scheduler_init_pulls_blocked_handlers_from_config():
+    with dask.config.set({'distributed.scheduler.blocked-handlers': ['test-handler']}):
+        s = Scheduler()
+    assert s.blocked_handlers == ['test-handler']
+
+
 @gen_cluster()
 def test_feed(s, a, b):
     def func(scheduler):