Skip to content

Commit

Permalink
Merge pull request #44 from dabapps/stopping-state
Browse files Browse the repository at this point in the history
Add STOPPING state
  • Loading branch information
j4mie authored Jan 6, 2022
2 parents d90e9fc + e870812 commit e2e5677
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 88 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,14 @@ Jobs have a `state` field which can have one of the following values:
* `NEW` (has been created, waiting for a worker process to run the next task)
* `READY` (has run a task before, awaiting a worker process to run the next task)
* `PROCESSING` (a task is currently being processed by a worker)
* `STOPPING` (the worker process has received a signal from the OS requesting it to exit)
* `COMPLETE` (all job tasks have completed successfully)
* `FAILED` (a job task failed)

#### State diagram

![state diagram](states.png)

### API

#### Model methods
Expand Down
123 changes: 62 additions & 61 deletions django_dbq/management/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,72 +14,13 @@
DEFAULT_QUEUE_NAME = "default"


def process_job(queue_name):
    """Grab the next available job for a given queue and run its next task.

    Does nothing and returns ``None`` when ``Job.objects.get_ready_or_none``
    finds no job for ``queue_name``. Otherwise the job is marked PROCESSING,
    its next task is imported and called with the job, and the job is saved
    as COMPLETE (no further task), READY (more tasks remain) or FAILED (the
    task raised).

    Args:
        queue_name: Name of the queue to take a job from.

    Raises:
        Exception: Whatever the final ``job.save()`` raises is logged and
            re-raised. An exception raised by the failure hook itself is
            not caught here.
    """
    # The PROCESSING state is written inside its own transaction, before the
    # task itself runs.
    with transaction.atomic():
        job = Job.objects.get_ready_or_none(queue_name)
        if not job:
            return

        logger.info(
            'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s',
            job.name,
            queue_name,
            job.pk,
            job.state,
            job.next_task,
        )
        job.state = Job.STATES.PROCESSING
        job.save()

    try:
        task_function = import_string(job.next_task)
        task_function(job)
        job.update_next_task()
        if not job.next_task:
            job.state = Job.STATES.COMPLETE
        else:
            job.state = Job.STATES.READY
    except Exception as exception:
        logger.exception("Job id=%s failed", job.pk)
        job.state = Job.STATES.FAILED

        failure_hook_name = job.get_failure_hook_name()
        if failure_hook_name:
            logger.info(
                "Running failure hook %s for job id=%s", failure_hook_name, job.pk
            )
            failure_hook_function = import_string(failure_hook_name)
            failure_hook_function(job, exception)
        else:
            logger.info("No failure hook for job id=%s", job.pk)

    logger.info(
        'Updating job: name="%s" id=%s state=%s next_task=%s',
        job.name,
        job.pk,
        job.state,
        job.next_task or "none",
    )

    try:
        job.save()
    except Exception:
        # Log with the traceback and without touching job.workspace: reading
        # a workspace key here could itself raise and hide the real error.
        logger.exception("Failed to save job: id=%s", job.pk)
        raise


class Worker:
def __init__(self, name, rate_limit_in_seconds):
self.queue_name = name
self.rate_limit_in_seconds = rate_limit_in_seconds
self.alive = True
self.last_job_finished = None
self.current_job = None
self.init_signals()

def init_signals(self):
Expand All @@ -93,6 +34,9 @@ def init_signals(self):

def shutdown(self, signum, frame):
self.alive = False
if self.current_job:
self.current_job.state = Job.STATES.STOPPING
self.current_job.save(update_fields=["state"])

def run(self):
while self.alive:
Expand All @@ -107,9 +51,66 @@ def process_job(self):
):
return

process_job(self.queue_name)
self._process_job()

self.last_job_finished = timezone.now()

def _process_job(self):
    """Take the next ready job from this worker's queue and run its next task.

    Returns ``None`` immediately when ``Job.objects.get_ready_or_none``
    finds no job. Otherwise the job is marked PROCESSING and stored on
    ``self.current_job`` (which ``shutdown()`` reads), its next task is
    imported and called with the job, and the job is saved as COMPLETE,
    READY or FAILED.

    ``self.current_job`` is always reset to ``None`` before this method
    exits, including when the failure hook or the final save raises.

    Raises:
        Exception: Whatever the final ``job.save()`` raises is logged and
            re-raised. An exception raised by the failure hook itself is
            not caught here.
    """
    # The PROCESSING state is written inside its own transaction, before the
    # task itself runs.
    with transaction.atomic():
        job = Job.objects.get_ready_or_none(self.queue_name)
        if not job:
            return

        logger.info(
            'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s',
            job.name,
            self.queue_name,
            job.pk,
            job.state,
            job.next_task,
        )
        job.state = Job.STATES.PROCESSING
        job.save()
        self.current_job = job

    try:
        try:
            task_function = import_string(job.next_task)
            task_function(job)
            job.update_next_task()
            if not job.next_task:
                job.state = Job.STATES.COMPLETE
            else:
                job.state = Job.STATES.READY
        except Exception as exception:
            logger.exception("Job id=%s failed", job.pk)
            job.state = Job.STATES.FAILED

            failure_hook_name = job.get_failure_hook_name()
            if failure_hook_name:
                logger.info(
                    "Running failure hook %s for job id=%s",
                    failure_hook_name,
                    job.pk,
                )
                failure_hook_function = import_string(failure_hook_name)
                failure_hook_function(job, exception)
            else:
                logger.info("No failure hook for job id=%s", job.pk)

        logger.info(
            'Updating job: name="%s" id=%s state=%s next_task=%s',
            job.name,
            job.pk,
            job.state,
            job.next_task or "none",
        )

        try:
            job.save()
        except Exception:
            logger.exception("Failed to save job: id=%s", job.pk)
            raise
    finally:
        # Never leave a stale job reference behind: shutdown() would
        # otherwise overwrite the state of a job this worker no longer runs.
        self.current_job = None


class Command(BaseCommand):

Expand Down
30 changes: 30 additions & 0 deletions django_dbq/migrations/0006_alter_job_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 3.2rc1 on 2021-11-29 04:48

from django.db import migrations, models


class Migration(migrations.Migration):
    """Redeclare ``job.state`` with a choices list that includes STOPPING."""

    # Applied after the migration named 0005_job_run_after in this app.
    dependencies = [
        ("django_dbq", "0005_job_run_after"),
    ]

    operations = [
        migrations.AlterField(
            model_name="job",
            name="state",
            field=models.CharField(
                # (stored value, display label) pairs for the state field.
                choices=[
                    ("NEW", "New"),
                    ("READY", "Ready"),
                    ("PROCESSING", "Processing"),
                    ("STOPPING", "Stopping"),
                    ("FAILED", "Failed"),
                    ("COMPLETE", "Complete"),
                ],
                db_index=True,
                default="NEW",
                max_length=20,
            ),
        ),
    ]
7 changes: 6 additions & 1 deletion django_dbq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def delete_old(self):
"""
Delete all jobs older than DELETE_JOBS_AFTER_HOURS
"""
delete_jobs_in_states = [Job.STATES.FAILED, Job.STATES.COMPLETE]
delete_jobs_in_states = [
Job.STATES.FAILED,
Job.STATES.COMPLETE,
Job.STATES.STOPPING,
]
delete_jobs_created_before = timezone.now() - datetime.timedelta(
hours=DELETE_JOBS_AFTER_HOURS
)
Expand Down Expand Up @@ -82,6 +86,7 @@ class STATES(TextChoices):
NEW = "NEW"
READY = "READY"
PROCESSING = "PROCESSING"
STOPPING = "STOPPING"
FAILED = "FAILED"
COMPLETE = "COMPLETE"

Expand Down
68 changes: 42 additions & 26 deletions django_dbq/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.test.utils import override_settings
from django.utils import timezone

from django_dbq.management.commands.worker import process_job, Worker
from django_dbq.management.commands.worker import Worker
from django_dbq.models import Job

from io import StringIO
Expand Down Expand Up @@ -123,41 +123,53 @@ def test_queue_depth_for_queue_with_zero_jobs(self):

@freezegun.freeze_time()
@mock.patch("django_dbq.management.commands.worker.sleep")
@mock.patch("django_dbq.management.commands.worker.process_job")
class WorkerProcessProcessJobTestCase(TestCase):
def setUp(self):
super().setUp()
self.MockWorker = mock.MagicMock()
self.MockWorker.queue_name = "default"
self.MockWorker.rate_limit_in_seconds = 5
self.MockWorker.last_job_finished = None
self.mock_worker = mock.MagicMock()
self.mock_worker.queue_name = "default"
self.mock_worker.rate_limit_in_seconds = 5
self.mock_worker.last_job_finished = None

def test_process_job_no_previous_job_run(self, mock_process_job, mock_sleep):
Worker.process_job(self.MockWorker)
def test_process_job_no_previous_job_run(self, mock_sleep):
Worker.process_job(self.mock_worker)
self.assertEqual(mock_sleep.call_count, 1)
self.assertEqual(mock_process_job.call_count, 1)
self.assertEqual(self.MockWorker.last_job_finished, timezone.now())
self.assertEqual(self.mock_worker._process_job.call_count, 1)
self.assertEqual(self.mock_worker.last_job_finished, timezone.now())

def test_process_job_previous_job_too_soon(self, mock_process_job, mock_sleep):
self.MockWorker.last_job_finished = timezone.now() - timezone.timedelta(
def test_process_job_previous_job_too_soon(self, mock_sleep):
self.mock_worker.last_job_finished = timezone.now() - timezone.timedelta(
seconds=2
)
Worker.process_job(self.MockWorker)
Worker.process_job(self.mock_worker)
self.assertEqual(mock_sleep.call_count, 1)
self.assertEqual(mock_process_job.call_count, 0)
self.assertEqual(self.mock_worker._process_job.call_count, 0)
self.assertEqual(
self.MockWorker.last_job_finished,
self.mock_worker.last_job_finished,
timezone.now() - timezone.timedelta(seconds=2),
)

def test_process_job_previous_job_long_time_ago(self, mock_process_job, mock_sleep):
self.MockWorker.last_job_finished = timezone.now() - timezone.timedelta(
def test_process_job_previous_job_long_time_ago(self, mock_sleep):
self.mock_worker.last_job_finished = timezone.now() - timezone.timedelta(
seconds=7
)
Worker.process_job(self.MockWorker)
Worker.process_job(self.mock_worker)
self.assertEqual(mock_sleep.call_count, 1)
self.assertEqual(mock_process_job.call_count, 1)
self.assertEqual(self.MockWorker.last_job_finished, timezone.now())
self.assertEqual(self.mock_worker._process_job.call_count, 1)
self.assertEqual(self.mock_worker.last_job_finished, timezone.now())


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
class ShutdownTestCase(TestCase):
    """Checks that Worker.shutdown persists the STOPPING state on its job."""

    def test_shutdown_sets_state_to_stopping(self):
        in_flight = Job.objects.create(name="testjob")
        worker = Worker("default", 1)
        worker.current_job = in_flight

        # The handler ignores both arguments, so placeholders are enough.
        worker.shutdown(None, None)

        # Reload from the database so the assertion sees the stored value.
        in_flight.refresh_from_db()
        stored_state = in_flight.state
        self.assertEqual(stored_state, Job.STATES.STOPPING)


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down Expand Up @@ -267,7 +279,7 @@ def test_task_sequence(self):
class ProcessJobTestCase(TestCase):
def test_process_job(self):
job = Job.objects.create(name="testjob")
process_job("default")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.COMPLETE)

Expand All @@ -276,7 +288,7 @@ def test_process_job_wrong_queue(self):
Processing a different queue shouldn't touch our other job
"""
job = Job.objects.create(name="testjob", queue_name="lol")
process_job("default")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.NEW)

Expand Down Expand Up @@ -315,7 +327,7 @@ def test_creation_hook_only_runs_on_create(self):
class JobFailureHookTestCase(TestCase):
def test_failure_hook(self):
job = Job.objects.create(name="testjob")
process_job("default")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "failure hook ran")
Expand All @@ -334,14 +346,18 @@ def test_delete_old_jobs(self):
j2.created = two_days_ago
j2.save()

j3 = Job.objects.create(name="testjob", state=Job.STATES.NEW)
j3 = Job.objects.create(name="testjob", state=Job.STATES.STOPPING)
j3.created = two_days_ago
j3.save()

j4 = Job.objects.create(name="testjob", state=Job.STATES.COMPLETE)
j4 = Job.objects.create(name="testjob", state=Job.STATES.NEW)
j4.created = two_days_ago
j4.save()

j5 = Job.objects.create(name="testjob", state=Job.STATES.COMPLETE)

Job.objects.delete_old()

self.assertEqual(Job.objects.count(), 2)
self.assertTrue(j3 in Job.objects.all())
self.assertTrue(j4 in Job.objects.all())
self.assertTrue(j5 in Job.objects.all())
Binary file added states.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e2e5677

Please sign in to comment.