Ensure client_desires_keys does not corrupt Scheduler state (#8827)
fjetter authored Aug 20, 2024
1 parent 3075b08 commit fe79a36
Showing 10 changed files with 140 additions and 198 deletions.
2 changes: 1 addition & 1 deletion distributed/actor.py
@@ -77,7 +77,7 @@ def _try_bind_worker_client(self):
         if not self._client:
             try:
                 self._client = get_client()
-                self._future = Future(self._key, inform=False)
+                self._future = Future(self._key, self._client)
                 # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
             except ValueError:
                 self._client = None
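Since the inform flag is removed in this commit (see the client.py changes below), the actor now hands its freshly resolved client to the Future explicitly; constructing the Future binds it locally but no longer notifies the scheduler. A minimal sketch of the new call, with a hypothetical key, assuming the code runs inside a task on a worker:

from distributed import get_client
from distributed.client import Future

client = get_client()                     # resolves the worker's client, as in _try_bind_worker_client
future = Future("actor-key-123", client)  # hypothetical key; binds locally, sends nothing to the scheduler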
54 changes: 25 additions & 29 deletions distributed/client.py
@@ -163,6 +163,9 @@ def __str__(self) -> str:
             result = "\n".join([result, self.msg])
         return result
 
+    def __reduce__(self):
+        return self.__class__, (self.key, self.reason, self.msg)
+
 
 class FuturesCancelledError(CancelledError):
     error_groups: list[CancelledFuturesGroup]
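The new __reduce__ lets this cancellation error survive a pickle round trip, since exceptions whose __init__ takes extra positional arguments do not repickle cleanly by default. A minimal sketch, assuming the enclosing class is FutureCancelledError with the (key, reason, msg) signature the hunk implies:

import pickle

from distributed.client import FutureCancelledError

err = FutureCancelledError(key="x-1", reason="already-cancelled", msg="details")
restored = pickle.loads(pickle.dumps(err))
assert (restored.key, restored.reason, restored.msg) == (err.key, err.reason, err.msg)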
@@ -297,13 +300,12 @@ class Future(WrappedKey):
     # Make sure this stays unique even across multiple processes or hosts
     _uid = uuid.uuid4().hex
 
-    def __init__(self, key, client=None, inform=True, state=None, _id=None):
+    def __init__(self, key, client=None, state=None, _id=None):
         self.key = key
         self._cleared = False
         self._client = client
         self._id = _id or (Future._uid, next(Future._counter))
         self._input_state = state
-        self._inform = inform
         self._state = None
         self._bind_late()
 
@@ -312,13 +314,11 @@ def client(self):
         self._bind_late()
         return self._client
 
+    def bind_client(self, client):
+        self._client = client
+        self._bind_late()
+
     def _bind_late(self):
-        if not self._client:
-            try:
-                client = get_client()
-            except ValueError:
-                client = None
-            self._client = client
         if self._client and not self._state:
             self._client._inc_ref(self.key)
             self._generation = self._client.generation
@@ -328,15 +328,6 @@ def _bind_late(self):
         else:
             self._state = self._client.futures[self.key] = FutureState(self.key)
 
-            if self._inform:
-                self._client._send_to_scheduler(
-                    {
-                        "op": "client-desires-keys",
-                        "keys": [self.key],
-                        "client": self._client.id,
-                    }
-                )
-
         if self._input_state is not None:
             try:
                 handler = self._client._state_handlers[self._input_state]
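These two hunks are the heart of the client-side change: a Future never sends its own per-key "client-desires-keys" message anymore, and late binding only wires up client-local refcounting. A minimal sketch of the resulting behavior, assuming an in-process cluster:

from distributed import Client
from distributed.client import Future

client = Client(processes=False)
fut = client.submit(sum, [1, 2, 3])

# Rebuilding a Future from its key reuses the existing FutureState and
# bumps the local refcount; no message goes to the scheduler.
alias = Future(fut.key, client)
assert alias.result() == 6

# A Future created without a client stays unbound until bind_client() is called.
bare = Future(fut.key)
bare.bind_client(client)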
@@ -588,13 +579,8 @@ def release(self):
         except TypeError:  # pragma: no cover
             pass  # Shutting down, add_callback may be None
 
-    @staticmethod
-    def make_future(key, id):
-        # Can't use kwargs in pickle __reduce__ methods
-        return Future(key=key, _id=id)
-
     def __reduce__(self) -> str | tuple[Any, ...]:
-        return Future.make_future, (self.key, self._id)
+        return Future, (self.key,)
 
     def __dask_tokenize__(self):
         return (type(self).__name__, self.key, self._id)
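With make_future gone, a pickled Future carries nothing but its key: the receiving side reconstructs it unbound (and with a fresh _id) instead of registering interest with the scheduler as a side effect of deserialization. A sketch, reusing client and fut from the example above:

import pickle

payload = pickle.dumps(fut)        # reduces to (Future, (fut.key,))
restored = pickle.loads(payload)   # unbound; no scheduler traffic happens here
restored.bind_client(client)       # rebinding is now the caller's responsibility
assert restored.key == fut.key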
@@ -2161,7 +2147,7 @@ def submit(

         with self._refcount_lock:
             if key in self.futures:
-                return Future(key, self, inform=False)
+                return Future(key, self)
 
         if allow_other_workers and workers is None:
             raise ValueError("Only use allow_other_workers= if using workers=")
@@ -2661,7 +2647,7 @@ async def _scatter(
             timeout=timeout,
         )
 
-        out = {k: Future(k, self, inform=False) for k in data}
+        out = {k: Future(k, self) for k in data}
         for key, typ in types.items():
             self.futures[key].finish(type=typ)
 
@@ -2969,12 +2955,14 @@ def list_datasets(self, **kwargs):
     async def _get_dataset(self, name, default=no_default):
         with self.as_current():
             out = await self.scheduler.publish_get(name=name, client=self.id)
-
         if out is None:
             if default is no_default:
                 raise KeyError(f"Dataset '{name}' not found")
             else:
                 return default
+        for fut in futures_of(out["data"]):
+            fut.bind_client(self)
+        self._inform_scheduler_of_futures()
         return out["data"]
 
     def get_dataset(self, name, default=no_default, **kwargs):
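Dataset retrieval is one of the paths that used to depend on inform-on-construction; it now binds every contained future explicitly and registers the client's keys with the scheduler in one bulk message via _inform_scheduler_of_futures (added below). A sketch, reusing client and fut from above:

client.publish_dataset(numbers=fut)    # scheduler keeps the keys alive under the published name
same = client.get_dataset("numbers")   # futures come back bound to this client
assert same.result() == 6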
@@ -3300,6 +3288,14 @@ def _get_computation_code(

         return tuple(reversed(code))
 
+    def _inform_scheduler_of_futures(self):
+        self._send_to_scheduler(
+            {
+                "op": "client-desires-keys",
+                "keys": list(self.refcount),
+            }
+        )
+
     def _graph_to_futures(
         self,
         dsk,
@@ -3348,7 +3344,7 @@ def _graph_to_futures(
             validate_key(key)
 
         # Create futures before sending graph (helps avoid contention)
-        futures = {key: Future(key, self, inform=False) for key in keyset}
+        futures = {key: Future(key, self) for key in keyset}
         # Circular import
         from distributed.protocol import serialize
         from distributed.protocol.serialize import ToPickle
@@ -3507,7 +3503,7 @@ def _optimize_insert_futures(self, dsk, keys):
                 if not changed:
                     changed = True
                     dsk = ensure_dict(dsk)
-                dsk[key] = Future(key, self, inform=False)
+                dsk[key] = Future(key, self)
 
         if changed:
             dsk, _ = dask.optimization.cull(dsk, keys)
@@ -6092,7 +6088,7 @@ def futures_of(o, client=None):
             stack.extend(x.values())
         elif type(x) is SubgraphCallable:
             stack.extend(x.dsk.values())
-        elif isinstance(x, Future):
+        elif isinstance(x, WrappedKey):
             if x not in seen:
                 seen.add(x)
                 futures.append(x)
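Broadening the isinstance check from Future to its WrappedKey base class lets futures_of discover any key-carrying wrapper nested in collections, which _get_dataset above relies on to bind everything it receives. A small sketch, reusing fut from above:

from distributed.client import futures_of

# Traverses lists, tuples, sets, and dicts, deduplicates, and now collects
# every WrappedKey subclass rather than only Future instances.
found = futures_of({"a": fut, "b": [fut]})
assert found == [fut]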
13 changes: 10 additions & 3 deletions distributed/queues.py
@@ -8,7 +8,7 @@
 from dask.utils import parse_timedelta
 
 from distributed.client import Future
-from distributed.utils import wait_for
+from distributed.utils import Deadline, wait_for
 from distributed.worker import get_client
 
 logger = logging.getLogger(__name__)
@@ -67,15 +67,22 @@ def release(self, name=None, client=None):
         self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name)
 
     async def put(self, name=None, key=None, data=None, client=None, timeout=None):
+        deadline = Deadline.after(timeout)
         if key is not None:
+            while key not in self.scheduler.tasks:
+                await asyncio.sleep(0.01)
+                if deadline.expired:
+                    raise TimeoutError(f"Task {key} unknown to scheduler.")
+
             record = {"type": "Future", "value": key}
             self.future_refcount[name, key] += 1
             self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name)
         else:
             record = {"type": "msgpack", "value": data}
-        await wait_for(self.queues[name].put(record), timeout=timeout)
+        await wait_for(self.queues[name].put(record), timeout=deadline.remaining)
 
     def future_release(self, name=None, key=None, client=None):
+        self.scheduler.client_desires_keys(keys=[key], client=client)
         self.future_refcount[name, key] -= 1
         if self.future_refcount[name, key] == 0:
             self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
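Because the scheduler no longer fabricates a TaskState for unknown keys (see the scheduler.py hunk below), put() must wait until the key really exists before registering the queue's interest, and future_release() re-registers the consuming client's interest before the queue drops its own. Deadline folds the polling loop and the final put into one timeout budget; a sketch of the semantics assumed here:

from distributed.utils import Deadline

deadline = Deadline.after(5)   # Deadline.after(None) never expires
deadline.expired               # False until the five seconds are used up
deadline.remaining             # seconds left, or None when unbounded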
@@ -265,7 +272,7 @@ async def _get(self, timeout=None, batch=False):

         def process(d):
             if d["type"] == "Future":
-                value = Future(d["value"], self.client, inform=True, state=d["state"])
+                value = Future(d["value"], self.client, state=d["state"])
                 if d["state"] == "erred":
                     value._state.set_error(d["exception"], d["traceback"])
                 self.client._send_to_scheduler(
12 changes: 5 additions & 7 deletions distributed/scheduler.py
@@ -670,9 +670,7 @@ def clean(self) -> WorkerState:
         )
         ws._occupancy_cache = self.occupancy
 
-        ws.executing = {
-            ts.key: duration for ts, duration in self.executing.items()  # type: ignore
-        }
+        ws.executing = {ts.key: duration for ts, duration in self.executing.items()}  # type: ignore
         return ws
 
     def __repr__(self) -> str:
@@ -4634,7 +4632,7 @@ def _match_graph_with_tasks(
                 ):  # bad key
                     lost_keys.add(k)
                     logger.info("User asked for computation on lost data, %s", k)
-                    del dsk[k]
+                    dsk.pop(k, None)
                     del dependencies[k]
                     if k in keys:
                         keys.remove(k)
@@ -5595,8 +5593,8 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
         for k in keys:
             ts = self.tasks.get(k)
             if ts is None:
-                # For publish, queues etc.
-                ts = self.new_task(k, None, "released")
+                warnings.warn(f"Client desires key {k!r} but key is unknown.")
+                continue
             if ts.who_wants is None:
                 ts.who_wants = set()
             ts.who_wants.add(cs)
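This is the fix the commit title refers to: client_desires_keys used to mint a phantom TaskState in "released" state, with no run_spec or dependencies, for any key it had never heard of, and those orphans could corrupt later scheduler transitions. Unknown keys are now skipped with a warning, which is why the client, queue, and publish paths above were reworked to register only keys that genuinely exist. A sketch of the new behavior, with hypothetical key and client names:

# No TaskState is created; the stray request is dropped with a warning:
# UserWarning: Client desires key 'no-such-key' but key is unknown.
scheduler.client_desires_keys(keys=["no-such-key"], client="client-abc")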
@@ -9345,7 +9343,7 @@
 def _materialize_graph(
     graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool
 ) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
-    dsk = ensure_dict(graph)
+    dsk: dict = ensure_dict(graph)
     if validate:
         for k in dsk:
             validate_key(k)
(The remaining six changed files did not load and are not shown.)
