Skip to content

Commit

Permalink
Make 'steal' command atomic
Browse files Browse the repository at this point in the history
Either unschedule all of the requested tests or, if that is not possible,
none of them. It is not possible when some of the requested tests have
already been processed by the time the request arrives, which may happen
if the worker runs tests faster than the controller receives and processes
status updates. In that case it is probably better to just let the worker
keep running.

This is a prerequisite for group/scope support in the worksteal scheduler:
it ensures that groups/scopes won't be broken up incorrectly.

This change could break schedulers that use the "steal" command. However:

1) The worksteal scheduler doesn't need any adjustments.

2) I'm not aware of any external schedulers relying on this command yet.

So I think it's better to keep the protocol simple, not complicate it for
imaginary compatibility with some unknown and likely non-existent
schedulers.
  • Loading branch information
amezin committed Oct 23, 2024
1 parent 9c24f0f commit aa1324c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 31 deletions.
1 change: 1 addition & 0 deletions changelog/1144.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make "steal" command atomic - make it unschedule either all requested tests or none.
84 changes: 53 additions & 31 deletions src/xdist/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

from __future__ import annotations

import collections
import contextlib
import enum
import os
import sys
import time
from typing import Any
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import TypedDict
from typing import Union
Expand Down Expand Up @@ -66,7 +68,44 @@ def worker_title(title: str) -> None:

class Marker(enum.Enum):
    """Sentinel values that travel through the worker's test queue."""

    # End-of-work sentinel: the run loop stops once this is dequeued.
    SHUTDOWN = 0
    # Tells ``_get_next_item_index`` to call ``get()`` on the queue again.
    QUEUE_REPLACED = 1


class TestQueue:
    """A simple queue that can be inspected and modified while the lock is held.

    Items live in a deque guarded by an ``RLock`` created by the given
    execution model.  An event mirrors "the queue is non-empty" so that
    :meth:`get` can block without polling.
    """

    # This is a runtime assignment, not an annotation, so it is evaluated
    # eagerly even under ``from __future__ import annotations``.  The PEP 604
    # spelling ``int | Literal[...]`` raises TypeError before Python 3.10,
    # hence ``Union``.
    Item = Union[int, Literal[Marker.SHUTDOWN]]

    def __init__(self, execmodel: execnet.gateway_base.ExecModel):
        """Create an empty queue using primitives from ``execmodel``."""
        self._items: collections.deque[TestQueue.Item] = collections.deque()
        self._lock = execmodel.RLock()  # type: ignore[no-untyped-call]
        self._has_items_event = execmodel.Event()

    def get(self) -> Item:
        """Remove and return the leftmost item, blocking while the queue is empty."""
        while True:
            with self.lock() as locked_items:
                if locked_items:
                    return locked_items.popleft()

            # Wait outside the lock so that other threads can add items.
            self._has_items_event.wait()

    def put(self, item: Item) -> None:
        """Append ``item`` to the right end of the queue."""
        with self.lock() as locked_items:
            locked_items.append(item)

    def replace(self, iterable: Iterable[Item]) -> None:
        """Replace the whole queue content with the items of ``iterable``.

        A new deque is built and bound; a deque previously obtained from
        :meth:`lock` is left unmodified and no longer backs the queue.
        """
        with self.lock():
            self._items = collections.deque(iterable)

    @contextlib.contextmanager
    def lock(self) -> Generator[collections.deque[Item], None, None]:
        """Hold the queue lock and yield the underlying deque.

        On exit, while the lock is still held, the "has items" event is set
        or cleared to match whether the current deque is non-empty.
        """
        with self._lock:
            try:
                yield self._items
            finally:
                # Re-read ``self._items``: ``replace`` may have rebound it.
                if self._items:
                    self._has_items_event.set()
                else:
                    self._has_items_event.clear()


class WorkerInteractor:
Expand All @@ -77,22 +116,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
self.testrunuid = workerinput["testrunuid"]
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
self.channel = channel
self.torun = self._make_queue()
self.torun = TestQueue(self.channel.gateway.execmodel)
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
config.pluginmanager.register(self)

def _make_queue(self) -> Any:
return self.channel.gateway.execmodel.queue.Queue()

def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
    """Return the next entry from the test queue.

    ``Marker.QUEUE_REPLACED`` entries are skipped: ``self.torun`` is looked
    up again and read until a different value is obtained, which covers the
    queue being replaced concurrently in another thread.
    """
    while True:
        candidate = self.torun.get()
        if candidate is not Marker.QUEUE_REPLACED:
            return candidate  # type: ignore[no-any-return]

def sendevent(self, name: str, **kwargs: object) -> None:
self.log("sending", name, kwargs)
self.channel.send((name, kwargs))
Expand Down Expand Up @@ -146,30 +173,25 @@ def handle_command(
self.steal(kwargs["indices"])

def steal(self, indices: Sequence[int]) -> None:
indices_set = set(indices)
stolen = []

old_queue, self.torun = self.torun, self._make_queue()

def old_queue_get_nowait_noraise() -> int | None:
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
return old_queue.get_nowait() # type: ignore[no-any-return]
return None

for i in iter(old_queue_get_nowait_noraise, None):
if i in indices_set:
stolen.append(i)
with self.torun.lock() as locked_queue:
requested_set = set(indices)
stolen = list(item for item in locked_queue if item in requested_set)

# Stealing only if all requested tests are still pending
if len(stolen) == len(requested_set):
self.torun.replace(
item for item in locked_queue if item not in requested_set
)
else:
self.torun.put(i)
stolen = []

self.sendevent("unscheduled", indices=stolen)
old_queue.put(Marker.QUEUE_REPLACED)

@pytest.hookimpl
def pytest_runtestloop(self, session: pytest.Session) -> bool:
self.log("entering main loop")
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()
while self.nextitem_index is not Marker.SHUTDOWN:
self.run_one_test()
if session.shouldfail or session.shouldstop:
Expand All @@ -179,7 +201,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
def run_one_test(self) -> None:
assert isinstance(self.nextitem_index, int)
self.item_index = self.nextitem_index
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()

items = self.session.items
item = items[self.item_index]
Expand Down
4 changes: 4 additions & 0 deletions testing/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def test_func4(): pass

worker.sendcommand("steal", indices=[1, 2])
ev = worker.popevent("unscheduled")
assert ev.kwargs["indices"] == []

worker.sendcommand("steal", indices=[2])
ev = worker.popevent("unscheduled")
assert ev.kwargs["indices"] == [2]

reports = [
Expand Down

0 comments on commit aa1324c

Please sign in to comment.