Skip to content

Commit

Permalink
Don't pile up context_meter callbacks (#7961)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jul 5, 2023
1 parent fca4b35 commit 67e073f
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 24 deletions.
57 changes: 44 additions & 13 deletions distributed/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ class ContextMeter:
A->B comms: network-write 0.567 seconds
"""

_callbacks: ContextVar[list[Callable[[Hashable, float, str], None]]]
_callbacks: ContextVar[dict[Hashable, Callable[[Hashable, float, str], None]]]

def __init__(self):
self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default=[])
self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default={})

def __reduce__(self):
assert self is context_meter, "Found copy of singleton"
Expand All @@ -204,13 +204,28 @@ def _unpickle_singleton():

@contextmanager
def add_callback(
self, callback: Callable[[Hashable, float, str], None]
self,
callback: Callable[[Hashable, float, str], None],
*,
key: Hashable | None = None,
) -> Iterator[None]:
"""Add a callback when entering the context and remove it when exiting it.
The callback must accept the same parameters as :meth:`digest_metric`.
Parameters
----------
callback: Callable
``f(label, value, unit)`` to be executed
key: Hashable, optional
Unique key for the callback. If two nested calls to ``add_callback`` use the
same key, suppress the outermost callback.
"""
if key is None:
key = object()
cbs = self._callbacks.get()
tok = self._callbacks.set(cbs + [callback])
cbs = cbs.copy()
cbs[key] = callback
tok = self._callbacks.set(cbs)
try:
yield
finally:
Expand All @@ -221,7 +236,7 @@ def digest_metric(self, label: Hashable, value: float, unit: str) -> None:
metric.
"""
cbs = self._callbacks.get()
for cb in cbs:
for cb in cbs.values():
cb(label, value, unit)

@contextmanager
Expand All @@ -234,9 +249,10 @@ def meter(
) -> Iterator[MeterOutput]:
"""Convenience context manager or decorator which calls func() before and after
the wrapped code, calculates the delta, and finally calls :meth:`digest_metric`.
It also subtracts any other calls to :meth:`meter` or :meth:`digest_metric` with
the same unit performed within the context, so that the total is strictly
additive.
If unit=='seconds', it also subtracts any other calls to :meth:`meter` or
:meth:`digest_metric` with the same unit performed within the context, so that
the total is strictly additive.
Parameters
----------
Expand All @@ -256,10 +272,19 @@ def meter(
nested calls to :meth:`meter`, then delta (for seconds only) is reduced by the
inner metrics, to a minimum of ``floor``.
"""
if unit != "seconds":
try:
with meter(func, floor=floor) as m:
yield m
finally:
self.digest_metric(label, m.delta, unit)
return

# If unit=="seconds", subtract time metered from the sub-contexts
offsets = []

def callback(label2: Hashable, value2: float, unit2: str) -> None:
if unit2 == unit == "seconds":
if unit2 == unit:
# This must be threadsafe to support callbacks invoked from
# distributed.utils.offload; '+=' on a float would not be threadsafe!
offsets.append(value2)
Expand Down Expand Up @@ -316,14 +341,20 @@ def __init__(self, func: Callable[[], float] = timemod.perf_counter):
self.start = func()
self.metrics = []

def _callback(self, label: Hashable, value: float, unit: str) -> None:
self.metrics.append((label, value, unit))

@contextmanager
def record(self) -> Iterator[None]:
def record(self, *, key: Hashable | None = None) -> Iterator[None]:
"""Ingest metrics logged with :meth:`ContextMeter.digest_metric` or
:meth:`ContextMeter.meter` and temporarily store them in :ivar:`metrics`.
Parameters
----------
key: Hashable, optional
See :meth:`ContextMeter.add_callback`
"""
with context_meter.add_callback(
lambda label, value, unit: self.metrics.append((label, value, unit))
):
with context_meter.add_callback(self._callback, key=key):
yield

def finalize(
Expand Down
65 changes: 57 additions & 8 deletions distributed/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,34 @@ def test_meter_floor(kwargs, delta):


def test_context_meter():
it = iter([123, 124])
it = iter([123, 124, 125, 126])
cbs = []

with metrics.context_meter.add_callback(lambda l, v, u: cbs.append((l, v, u))):
with metrics.context_meter.meter("m1", func=lambda: next(it)) as m:
assert m.start == 123
assert math.isnan(m.stop)
assert math.isnan(m.delta)
with metrics.context_meter.meter("m1", func=lambda: next(it)) as m1:
assert m1.start == 123
assert math.isnan(m1.stop)
assert math.isnan(m1.delta)
with metrics.context_meter.meter("m2", func=lambda: next(it), unit="foo") as m2:
assert m2.start == 125
assert math.isnan(m2.stop)
assert math.isnan(m2.delta)

metrics.context_meter.digest_metric("m1", 2, "seconds")
metrics.context_meter.digest_metric("m1", 1, "foo")

# Not recorded - out of context
metrics.context_meter.digest_metric("m1", 123, "foo")

assert m.start == 123
assert m.stop == 124
assert m.delta == 1
assert m1.start == 123
assert m1.stop == 124
assert m1.delta == 1
assert m2.start == 125
assert m2.stop == 126
assert m2.delta == 1
assert cbs == [
("m1", 1, "seconds"),
("m2", 1, "foo"),
("m1", 2, "seconds"),
("m1", 1, "foo"),
]
Expand Down Expand Up @@ -199,3 +208,43 @@ def test_delayed_metrics_ledger():
("foo", 10, "bytes"),
("other", 20, "seconds"),
]


def test_context_meter_keyed():
cbs = []

def cb(tag, key):
return metrics.context_meter.add_callback(
lambda l, v, u: cbs.append((tag, l)), key=key
)

with cb("x", key="x"), cb("y", key="y"):
metrics.context_meter.digest_metric("l1", 1, "u")
with cb("z", key="x"):
metrics.context_meter.digest_metric("l2", 2, "u")
metrics.context_meter.digest_metric("l3", 3, "u")

assert cbs == [
("x", "l1"),
("y", "l1"),
("z", "l2"),
("y", "l2"),
("x", "l3"),
("y", "l3"),
]


def test_delayed_metrics_ledger_keyed():
l1 = metrics.DelayedMetricsLedger()
l2 = metrics.DelayedMetricsLedger()
l3 = metrics.DelayedMetricsLedger()

with l1.record(key="x"), l2.record(key="y"):
metrics.context_meter.digest_metric("l1", 1, "u")
with l3.record(key="x"):
metrics.context_meter.digest_metric("l2", 2, "u")
metrics.context_meter.digest_metric("l3", 3, "u")

assert l1.metrics == [("l1", 1, "u"), ("l3", 3, "u")]
assert l2.metrics == [("l1", 1, "u"), ("l2", 2, "u"), ("l3", 3, "u")]
assert l3.metrics == [("l2", 2, "u")]
19 changes: 19 additions & 0 deletions distributed/tests/test_worker_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,22 @@ async def test_new_metrics_during_heartbeat(c, s, a):
assert a.digests_total["execute", span.id, "x", "test", "test"] == n
assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n
assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n


@gen_cluster(
client=True,
nthreads=[("", 1)],
config={"distributed.scheduler.worker-saturation": float("inf")},
)
async def test_delayed_ledger_is_not_reentrant(c, s, a):
"""https://github.com/dask/distributed/issues/7949
Test that, when there's a long chain of task done -> task start events,
the callbacks added by the delayed ledger don't pile up on top of each other.
"""

def f(_):
return len(context_meter._callbacks.get())

out = await c.gather(c.map(f, range(1000)))
assert max(out) < 10
9 changes: 6 additions & 3 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,7 +3635,7 @@ def _start_async_instruction( # type: ignore[valid-type]

@wraps(func)
async def wrapper() -> StateMachineEvent:
with ledger.record():
with ledger.record(key="async-instruction"):
return await func(*args, **kwargs)

task = asyncio.create_task(wrapper(), name=task_name)
Expand Down Expand Up @@ -3664,8 +3664,11 @@ def _finish_async_instruction(
logger.exception("async instruction handlers should never raise!")
raise

with ledger.record():
# Capture metric events in _transition_to_memory()
# Capture metric events in _transition_to_memory()
# As this may trigger calls to _start_async_instruction for more tasks,
# make sure we don't endlessly pile up context_meter callbacks by specifying
# the same key as in _start_async_instruction.
with ledger.record(key="async-instruction"):
self.handle_stimulus(stim)

self._finalize_metrics(stim, ledger, span_id)
Expand Down

0 comments on commit 67e073f

Please sign in to comment.