Skip to content

Commit

Permalink
Remove dispatching in TaskCollection
Browse files (browse the repository at this point in the history)
  • Loading branch information
fjetter committed Oct 23, 2024
1 parent 48509b3 commit c723c21
Showing 1 changed file with 27 additions and 38 deletions.
65 changes: 27 additions & 38 deletions distributed/scheduler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -960,17 +960,6 @@ def __init__(self, name: str):
self.states = dict.fromkeys(ALL_TASK_STATES, 0)
self._types = defaultdict(int)

def add(self, other: TaskState) -> None:
self.states[other.state] += 1

def add_duration(self, action: str, start: float, stop: float) -> None:
duration_us = self._calculate_duration_us(start, stop)
self._duration_us += duration_us
self._all_durations_us[action] += duration_us

def add_type(self, typename: str) -> None:
self._types[typename] += 1

@property
def all_durations(self) -> defaultdict[str, float]:
"""Cumulative duration of all completed actions of tasks belonging to this collection, by action"""
Expand All @@ -987,18 +976,11 @@ def duration(self) -> float:
"""The total amount of time spent on all tasks belonging to this collection"""
return self._duration_us / 1e6

def transition(self, old: TaskStateState, new: TaskStateState) -> None:
self.states[old] -= 1
self.states[new] += 1

@property
def types(self) -> Set[str]:
"""The result types of this collection"""
return self._types.keys()

def update_nbytes(self, diff: int) -> None:
self.nbytes_total += diff

@staticmethod
def _calculate_duration_us(start: float, stop: float) -> int:
return max(round((stop - start) * 1e6), 0)
Expand Down Expand Up @@ -1033,7 +1015,7 @@ class TaskPrefix(TaskCollection):
__slots__ = tuple(__annotations__)

def __init__(self, name: str):
super().__init__(name)
TaskCollection.__init__(self, name)
self.state_counts = defaultdict(int)
task_durations = dask.config.get("distributed.scheduler.default-task-durations")
if self.name in task_durations:
Expand All @@ -1050,19 +1032,18 @@ def add_exec_time(self, duration: float) -> None:
self.duration_average = -1

def add_duration(self, action: str, start: float, stop: float) -> None:
super().add_duration(action, start, stop)
duration_s = self._calculate_duration_us(start, stop) / 1e6
duration_us = max(round((stop - start) * 1e6), 0)
self._duration_us += duration_us
self._all_durations_us[action] += duration_us

duration_s = max(round((stop - start) * 1e6), 0) / 1e6
if action == "compute":
old = self.duration_average
if old < 0:
self.duration_average = duration_s
else:
self.duration_average = 0.5 * duration_s + 0.5 * old

def transition(self, old: TaskStateState, new: TaskStateState) -> None:
super().transition(old, new)
self.state_counts[new] += 1

def add_group(self, tg: TaskGroup) -> None:
self._groups[tg] = None

Expand Down Expand Up @@ -1149,7 +1130,7 @@ class TaskGroup(TaskCollection):
__slots__ = tuple(__annotations__)

def __init__(self, name: str, prefix: TaskPrefix):
super().__init__(name)
TaskCollection.__init__(self, name)
self.dependencies = set()
self.start = 0.0
self.stop = 0.0
Expand All @@ -1160,7 +1141,10 @@ def __init__(self, name: str, prefix: TaskPrefix):
prefix.add_group(self)

def add_duration(self, action: str, start: float, stop: float) -> None:
super().add_duration(action, start, stop)
duration_us = max(round((stop - start) * 1e6), 0)
self._duration_us += duration_us
self._all_durations_us[action] += duration_us

if action == "compute":
if self.stop < stop:
self.stop = stop
Expand All @@ -1169,21 +1153,17 @@ def add_duration(self, action: str, start: float, stop: float) -> None:
self.prefix.add_duration(action, start, stop)

def add(self, other: TaskState) -> None:
super().add(other)
self.prefix.add(other)
self.states[other.state] += 1
self.prefix.states[other.state] += 1
other.group = self

def add_type(self, typename: str) -> None:
super().add_type(typename)
self.prefix.add_type(typename)

def transition(self, old: TaskStateState, new: TaskStateState) -> None:
super().transition(old, new)
self.prefix.transition(old, new)
self._types[typename] += 1
self.prefix._types[typename] += 1

def update_nbytes(self, diff: int) -> None:
super().update_nbytes(diff)
self.prefix.update_nbytes(diff)
self.nbytes_total += diff
self.prefix.nbytes_total += diff

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -1506,7 +1486,16 @@ def state(self) -> TaskStateState:

@state.setter
def state(self, value: TaskStateState) -> None:
self.group.transition(self._state, value)
# Note: It would be cleaner to move this to the subclasses but the
# function dispatch is adding notable overhead and this setter is called
# *very* often
gr_st = self.group.states
gr_st[self._state] -= 1
gr_st[value] += 1
pf = self.prefix
pf.states[self._state] -= 1
pf.states[value] += 1
pf.state_counts[value] += 1
self._state = value

def add_dependency(self, other: TaskState) -> None:
Expand Down

0 comments on commit c723c21

Please sign in to comment.