Merge pull request ansible#12582 from AlanCoding/clean_and_forget

Move reaper logic into worker, avoiding bottlenecks

AlanCoding authored Aug 17, 2022
2 parents 56df3f0 + e0c59d1 commit a72da3b

Showing 5 changed files with 43 additions and 28 deletions.
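
At a high level, this change removes the reaper calls from the dispatcher main process's periodic cleanup() and moves them into the cluster_node_heartbeat task, which runs in a worker. To give that task the context it needs, the dispatcher gains a small bind_kwargs mechanism: a task may declare kwargs (here dispatch_time and worker_tasks) that the main dispatcher process fills in at the moment it routes the message. A minimal usage sketch of the new decorator option, adapted from the docstring added below (it assumes a broker and default queue are configured as usual):

    from awx.main.dispatch.publish import task

    # Ask the dispatcher to inject the dispatch timestamp at publish time.
    @task(bind_kwargs=['dispatch_time'])
    def print_time(dispatch_time=None):
        print(f"Time I was dispatched: {dispatch_time}")

    # Published like any other task; the dispatcher adds dispatch_time
    # to kwargs before handing the message to a worker.
    print_time.apply_async()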
28 changes: 14 additions & 14 deletions awx/main/dispatch/pool.py
@@ -16,6 +16,7 @@
 from django.conf import settings
 from django.db import connection as django_connection, connections
 from django.core.cache import cache as django_cache
+from django.utils.timezone import now as tz_now
 from django_guid import set_guid
 from jinja2 import Template
 import psutil
@@ -377,8 +378,6 @@ def cleanup(self):
         1. Discover worker processes that exited, and recover messages they
            were handling.
         2. Clean up unnecessary, idle workers.
-        3. Check to see if the database says this node is running any tasks
-           that aren't actually running. If so, reap them.
 
         IMPORTANT: this function is one of the few places in the dispatcher
         (aside from setting lookups) where we talk to the database. As such,
@@ -437,18 +436,17 @@ def cleanup(self):
             idx = random.choice(range(len(self.workers)))
             self.write(idx, m)
 
-        # if the database says a job is running or queued on this node, but it's *not*,
-        # then reap it
-        running_uuids = []
-        for worker in self.workers:
-            worker.calculate_managed_tasks()
-            running_uuids.extend(list(worker.managed_tasks.keys()))
-
-        # if we are not in the dangerous situation of queue backup then clear old waiting jobs
-        if self.workers and max(len(w.managed_tasks) for w in self.workers) <= 1:
-            reaper.reap_waiting(excluded_uuids=running_uuids)
-
-        reaper.reap(excluded_uuids=running_uuids)
+    def add_bind_kwargs(self, body):
+        bind_kwargs = body.pop('bind_kwargs', [])
+        body.setdefault('kwargs', {})
+        if 'dispatch_time' in bind_kwargs:
+            body['kwargs']['dispatch_time'] = tz_now().isoformat()
+        if 'worker_tasks' in bind_kwargs:
+            worker_tasks = {}
+            for worker in self.workers:
+                worker.calculate_managed_tasks()
+                worker_tasks[worker.pid] = list(worker.managed_tasks.keys())
+            body['kwargs']['worker_tasks'] = worker_tasks
 
     def up(self):
         if self.full:
@@ -463,6 +461,8 @@ def write(self, preferred_queue, body):
         if 'guid' in body:
             set_guid(body['guid'])
         try:
+            if isinstance(body, dict) and body.get('bind_kwargs'):
+                self.add_bind_kwargs(body)
             # when the cluster heartbeat occurs, clean up internally
             if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
                 self.cleanup()
12 changes: 11 additions & 1 deletion awx/main/dispatch/publish.py
@@ -50,13 +50,21 @@ def snooze():
     @task(queue='tower_broadcast')
     def announce():
         print("Run this everywhere!")
 
+    # The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs
+    @task(bind_kwargs=['dispatch_time'])
+    def print_time(dispatch_time=None):
+        print(f"Time I was dispatched: {dispatch_time}")
     """
 
-    def __init__(self, queue=None):
+    def __init__(self, queue=None, bind_kwargs=None):
         self.queue = queue
+        self.bind_kwargs = bind_kwargs
 
     def __call__(self, fn=None):
         queue = self.queue
+        bind_kwargs = self.bind_kwargs
 
         class PublisherMixin(object):
@@ -80,6 +88,8 @@ def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
                 guid = get_guid()
                 if guid:
                     obj['guid'] = guid
+                if bind_kwargs:
+                    obj['bind_kwargs'] = bind_kwargs
                 obj.update(**kw)
                 if callable(queue):
                     queue = queue()
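
For context, apply_async builds a message dict (obj) that already carries fields such as the task name, uuid, args/kwargs, and an optional guid; the two added lines simply attach the bind_kwargs list to that dict. On the receiving side, the pool's write() (see pool.py above) pops bind_kwargs off the body and resolves it before the message reaches a worker. A sketch of what a published heartbeat message might look like (field names other than task, guid, and bind_kwargs are inferred from the surrounding code, not shown in this diff):

    message = {
        'task': 'awx.main.tasks.system.cluster_node_heartbeat',
        'uuid': '<celery_task_id>',
        'args': [],
        'kwargs': {},
        'bind_kwargs': ['dispatch_time', 'worker_tasks'],
    }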
7 changes: 4 additions & 3 deletions awx/main/dispatch/reaper.py
@@ -55,7 +55,7 @@ def reap_job(j, status, job_explanation=None):
     logger.error(f'{j.log_format} is no longer {status_before}; reaping')
 
 
-def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None):
+def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
     """
     Reap all jobs in waiting for this instance.
     """
@@ -69,8 +69,9 @@ def reap_waiting(instance=None, status='failed', job_explanation=None, grace_per
     except RuntimeError as e:
         logger.warning(f'Local instance is not registered, not running reaper: {e}')
         return
-    now = tz_now()
-    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=now - timedelta(seconds=grace_period), controller_node=me.hostname)
+    if ref_time is None:
+        ref_time = tz_now()
+    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=me.hostname)
     if excluded_uuids:
         jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
     for j in jobs:
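
The new ref_time parameter matters because the heartbeat task may execute some time after it was dispatched. Anchoring the grace-period cutoff to the dispatch timestamp, rather than calling tz_now() at execution, keeps reaping conservative: a job that entered 'waiting' while the heartbeat message sat in the queue cannot look older than the grace period. A small sketch of the cutoff arithmetic (the helper name is illustrative, not part of the diff):

    from datetime import datetime, timedelta

    def waiting_cutoff(grace_period_seconds, ref_time=None):
        # Jobs still in 'waiting' whose modified time predates this cutoff
        # are treated as stuck; an earlier ref_time means an earlier cutoff,
        # so fewer jobs qualify for reaping.
        if ref_time is None:
            ref_time = datetime.now()  # the real code uses Django's tz-aware now()
        return ref_time - timedelta(seconds=grace_period_seconds)

    # e.g. a heartbeat dispatched at 12:00 with a 600s grace period
    # only reaps waiting jobs untouched since before 11:50.
    print(waiting_cutoff(600, datetime.fromisoformat('2022-08-17T12:00:00')))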
14 changes: 12 additions & 2 deletions awx/main/tasks/system.py
@@ -10,6 +10,7 @@
 import shutil
 import time
 from distutils.version import LooseVersion as Version
+from datetime import datetime
 
 # Django
 from django.conf import settings
@@ -482,8 +483,8 @@ def inspect_execution_nodes(instance_list):
             execution_node_health_check.apply_async([hostname])
 
 
-@task(queue=get_local_queuename)
-def cluster_node_heartbeat():
+@task(queue=get_local_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
+def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
     logger.debug("Cluster node heartbeat task.")
     nowtime = now()
     instance_list = list(Instance.objects.all())
@@ -562,6 +563,15 @@ def cluster_node_heartbeat():
         else:
             logger.exception('Error marking {} as lost'.format(other_inst.hostname))
 
+    # Run local reaper
+    if worker_tasks is not None:
+        active_task_ids = []
+        for task_list in worker_tasks.values():
+            active_task_ids.extend(task_list)
+        reaper.reap(instance=this_inst, excluded_uuids=active_task_ids)
+        if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
+            reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
+
 
 @task(queue=get_local_queuename)
 def awx_receptor_workunit_reaper():
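
To make the worker_tasks contract concrete: the dispatcher sends a mapping of worker pid to the list of task uuids that worker was managing at dispatch time. Those uuids are excluded from reaping so in-flight work is never touched, and reap_waiting only runs when no worker holds more than one task, since a backed-up queue would make 'waiting' ages misleading. An illustration with made-up values:

    worker_tasks = {
        4301: ['55bbb2d5-...'],   # uuid of the task this worker is running
        4302: [],                 # idle worker
    }
    active_task_ids = [uuid for tasks in worker_tasks.values() for uuid in tasks]
    # Reap stuck 'waiting' jobs only when no worker has a backlog.
    safe_to_reap_waiting = max(len(tasks) for tasks in worker_tasks.values()) <= 1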
10 changes: 2 additions & 8 deletions awx/main/tests/functional/test_dispatch.py
@@ -199,10 +199,7 @@ def test_scale_down(self):
         assert len(self.pool) == 10
 
         # cleanup should scale down to 8 workers
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-        reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 2
 
     def test_max_scale_up(self):
@@ -250,10 +247,7 @@ def test_lost_worker_autoscale(self):
         time.sleep(1)  # wait a moment for sigterm
 
         # clean up and the dead worker
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-        reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 1
         assert self.pool.workers[0].pid == alive_pid
