pause receiving while submitting tasks #534

Draft · wants to merge 2 commits into base: main
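Note: the idea behind this PR is that a burst of task submissions can be interleaved with result delivery on the client's streams, slowing down submission. The diff below pauses the receive callbacks on the engine-facing streams while tasks are being sent, then resumes them. As a rough illustration of the underlying pyzmq pattern (a hypothetical sketch, not code from the PR; the stream and handler names are made up), receiving on a ZMQStream can be suspended and resumed like this:

# Hypothetical sketch of pausing a pyzmq ZMQStream's receive callback
# around a burst of sends; `stream` and `handler` are illustrative names.
from contextlib import contextmanager


@contextmanager
def paused_receiving(stream, handler):
    """Stop dispatching incoming messages on `stream` for the duration of the block."""
    stream.stop_on_recv()        # the ZMQStream stops invoking its recv callback
    try:
        yield
    finally:
        stream.on_recv(handler)  # resume delivering replies to the handler


# usage (on the IOLoop thread):
# with paused_receiving(stream, handle_reply):
#     for task in tasks:
#         stream.send_multipart(task_frames(task))  # task_frames is illustrative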
11 changes: 10 additions & 1 deletion ipyparallel/client/asyncresult.py
@@ -84,7 +84,7 @@ def __init__(

         self._return_exceptions = return_exceptions

-        if isinstance(children[0], string_types):
+        if children and isinstance(children[0], string_types):
             self.msg_ids = children
             self._children = []
         else:
@@ -96,6 +96,15 @@ def __init__(
         self._targets = targets
         self.owner = owner

+        if not children:
+            # empty result!
+            self._ready = True
+            self._success = True
+            f = Future()
+            f.set_result([])
+            self._resolve_result(f)
+            return
+
         self._ready = False
         self._ready_event = Event()
         self._output_ready = False
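The asyncresult.py change lets an AsyncResult be constructed with no children at all: it marks itself ready and successful and resolves with an empty list. The key trick is resolving against an already-completed Future. A tiny standalone illustration, assuming a plain concurrent.futures.Future (the PR may use a different future class):

# Illustration only: a pre-resolved Future reports done() immediately,
# so anything waiting on it gets [] without blocking.
from concurrent.futures import Future

f = Future()
f.set_result([])      # resolve with an empty result list
assert f.done()
print(f.result())     # -> []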
83 changes: 73 additions & 10 deletions ipyparallel/client/client.py
@@ -12,6 +12,7 @@
 import warnings
 from collections.abc import Iterable
 from concurrent.futures import Future
+from contextlib import contextmanager
 from getpass import getpass
 from pprint import pprint
 from threading import current_thread
@@ -990,21 +991,59 @@ def _stop_io_thread(self):
         self._io_thread.join()

     def _setup_streams(self):
-        self._query_stream = ZMQStream(self._query_socket, self._io_loop)
-        self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
-        self._control_stream = ZMQStream(self._control_socket, self._io_loop)
-        self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
-        self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
-        self._mux_stream.on_recv(self._dispatch_reply, copy=False)
-        self._task_stream = ZMQStream(self._task_socket, self._io_loop)
-        self._task_stream.on_recv(self._dispatch_reply, copy=False)
-        self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
-        self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
-        self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
-        self._notification_stream.on_recv(self._dispatch_notification, copy=False)
-
-        self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
-        self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
+        self._streams = []  # all streams
+        self._engine_streams = []  # streams that talk to engines
+        self._query_stream = s = ZMQStream(self._query_socket, self._io_loop)
+        self._streams.append(s)
+        self._notification_stream = s = ZMQStream(
+            self._notification_socket, self._io_loop
+        )
+        self._streams.append(s)
+
+        self._control_stream = s = ZMQStream(self._control_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._mux_stream = s = ZMQStream(self._mux_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._task_stream = s = ZMQStream(self._task_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._broadcast_stream = s = ZMQStream(self._broadcast_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._iopub_stream = s = ZMQStream(self._iopub_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._start_receiving(all=True)
+
+    def _start_receiving(self, all=False):
+        """Start receiving on streams
+
+        default: only engine streams
+
+        if all: include hub streams
+        """
+        if all:
+            self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
+            self._notification_stream.on_recv(self._dispatch_notification, copy=False)
+        self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
+        self._mux_stream.on_recv(self._dispatch_reply, copy=False)
+        self._task_stream.on_recv(self._dispatch_reply, copy=False)
+        self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
+        self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
+
+    def _stop_receiving(self, all=False):
+        """Stop receiving on engine streams
+
+        If all: include hub streams
+        """
+        if all:
+            streams = self._streams
+        else:
+            streams = self._engine_streams
+        for s in streams:
+            s.stop_on_recv()

     def _start_io_thread(self):
         """Start IOLoop in a background thread."""
@@ -1034,6 +1073,30 @@ def _io_main(self, start_evt=None):
         self._io_loop.start()
         self._io_loop.close()

+    @contextmanager
+    def _pause_results(self):
+        """Context manager to pause receiving results
+
+        When submitting lots of tasks,
+        the arrival of results can disrupt the processing
+        of new submissions.
+
+        Threadsafe.
+        """
+        f = Future()
+
+        def _stop():
+            self._stop_receiving()
+            f.set_result(None)
+
+        # use add_callback to make it threadsafe
+        self._io_loop.add_callback(_stop)
+        f.result()
+        try:
+            yield
+        finally:
+            self._io_loop.add_callback(self._start_receiving)
+
     @unpack_message
     def _dispatch_single_reply(self, msg):
         """Dispatch single (non-execution) replies"""
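The `_pause_results` context manager relies on a standard tornado pattern: hand the work to the IOLoop thread with `add_callback` (which is safe to call from any thread) and block the calling thread on a Future until it has run. A minimal sketch of that pattern, assuming an IOLoop is already running in a background thread (function and variable names here are illustrative, not from the PR):

# Sketch: run `action` on the IOLoop thread and wait for it from another thread.
from concurrent.futures import Future

from tornado.ioloop import IOLoop


def run_on_loop_and_wait(loop: IOLoop, action):
    """Schedule `action` on the loop thread and block until it has completed."""
    f = Future()

    def _run():
        try:
            f.set_result(action())
        except Exception as e:
            f.set_exception(e)

    loop.add_callback(_run)   # threadsafe hand-off to the loop thread
    return f.result()         # blocks the caller until the loop has run _run()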
4 changes: 4 additions & 0 deletions ipyparallel/client/map.py
@@ -66,6 +66,8 @@ def joinPartitions(self, listOfPartitions):
         return self.concatenate(listOfPartitions)

     def concatenate(self, listOfPartitions):
+        if len(listOfPartitions) == 0:
+            return listOfPartitions
         testObject = listOfPartitions[0]
         # First see if we have a known array type
         if is_array(testObject):
@@ -88,6 +90,8 @@ def getPartition(self, seq, p, q, n=None):
         return seq[p:n:q]

     def joinPartitions(self, listOfPartitions):
+        if len(listOfPartitions) == 0:
+            return listOfPartitions
         testObject = listOfPartitions[0]
         # First see if we have a known array type
         if is_array(testObject):
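These two-line guards in map.py exist so that partition joining works when there are no partitions at all, which is what an empty map produces. A simplified standalone illustration of the failure mode being avoided (not the real concatenate, which also handles array types):

# Without the guard, an empty partition list fails on listOfPartitions[0]
# with an IndexError; with it, empty input maps to empty output.
def concatenate(listOfPartitions):
    if len(listOfPartitions) == 0:
        return listOfPartitions
    return [item for part in listOfPartitions for item in part]


assert concatenate([]) == []
assert concatenate([[1, 2], [3]]) == [1, 2, 3]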
12 changes: 11 additions & 1 deletion ipyparallel/client/remotefunction.py
@@ -244,7 +244,17 @@ def __call__(self, *sequences, **kwargs):

         if maxlen == 0:
             # nothing to iterate over
-            return []
+            if self.block:
+                return []
+            else:
+                return AsyncMapResult(
+                    self.view.client,
+                    [],
+                    self.mapObject,
+                    fname=getname(self.func),
+                    ordered=self.ordered,
+                    return_exceptions=self.return_exceptions,
+                )

         # check that the length of sequences match
         if not _mapping and minlen != maxlen:
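With the remotefunction.py change, a non-blocking map over an empty sequence returns an AsyncMapResult that is already resolved to [], instead of a bare list, so callers get a consistent type. A hypothetical usage sketch (requires a running ipyparallel cluster; the exact behavior is assumed from the diff):

# Hypothetical usage; assumes a cluster is already running.
import ipyparallel as ipp

rc = ipp.Client()
view = rc.load_balanced_view()


def double(x):
    return 2 * x


ar = view.map_async(double, [])   # empty input: no tasks are submitted
print(ar.get())                   # -> [] immediately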
44 changes: 30 additions & 14 deletions ipyparallel/client/view.py
@@ -578,11 +578,12 @@ def _really_apply(
         pargs = [PrePickled(arg) for arg in args]
         pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}

-        for ident in _idents:
-            future = self.client.send_apply_request(
-                self._socket, pf, pargs, pkwargs, track=track, ident=ident
-            )
-            futures.append(future)
+        with self.client._pause_results():
+            for ident in _idents:
+                future = self.client.send_apply_request(
+                    self._socket, pf, pargs, pkwargs, track=track, ident=ident
+                )
+                futures.append(future)
         if track:
             trackers = [_.tracker for _ in futures]
         else:
@@ -641,9 +642,16 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):

         assert len(sequences) > 0, "must have some sequences to map onto!"
         pf = ParallelFunction(
-            self, f, block=block, track=track, return_exceptions=return_exceptions
+            self, f, block=False, track=track, return_exceptions=return_exceptions
         )
-        return pf.map(*sequences)
+        with self.client._pause_results():
+            ar = pf.map(*sequences)
+        if block:
+            try:
+                return ar.get()
+            except KeyboardInterrupt:
+                return ar
+        return ar

     @sync_results
     @save_ids
@@ -665,11 +673,12 @@ def execute(self, code, silent=True, targets=None, block=None):

         _idents, _targets = self.client._build_targets(targets)
         futures = []
-        for ident in _idents:
-            future = self.client.send_execute_request(
-                self._socket, code, silent=silent, ident=ident
-            )
-            futures.append(future)
+        with self.client._pause_results():
+            for ident in _idents:
+                future = self.client.send_execute_request(
+                    self._socket, code, silent=silent, ident=ident
+                )
+                futures.append(future)
         if isinstance(targets, int):
             futures = futures[0]
         ar = AsyncResult(
@@ -1292,12 +1301,19 @@ def map(
         pf = ParallelFunction(
             self,
             f,
-            block=block,
+            block=False,
             chunksize=chunksize,
             ordered=ordered,
             return_exceptions=return_exceptions,
         )
-        return pf.map(*sequences)
+        with self.client._pause_results():
+            ar = pf.map(*sequences)
+        if block:
+            try:
+                return ar.get()
+            except KeyboardInterrupt:
+                return ar
+        return ar

     def imap(
         self,
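In view.py, both map implementations now always submit asynchronously inside `_pause_results()` and only then wait if `block=True`; a KeyboardInterrupt during the wait hands back the still-running AsyncMapResult instead of discarding it. A hypothetical interactive sketch of what that means for callers (behavior assumed from the diff; requires a running cluster):

# Hypothetical sketch; assumes a running cluster and the behavior shown in the diff.
import ipyparallel as ipp

rc = ipp.Client()
view = rc.load_balanced_view()


def square(x):
    return x * x


out = view.map(square, range(1000), block=True)
# Normally `out` is the list of results. If Ctrl-C is pressed while waiting,
# `out` is instead the still-running AsyncMapResult, which can be waited on again:
if isinstance(out, ipp.AsyncResult):
    out = out.get()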