feat: Add wrap-up decorator for managing celery tasks (#20)

xmnlab authored Aug 18, 2024
1 parent 7ce37d1 commit 41c6fb3
Showing 5 changed files with 324 additions and 38 deletions.
82 changes: 80 additions & 2 deletions src/retsu/celery.py
@@ -2,14 +2,25 @@

from __future__ import annotations

from typing import Any, Optional
import logging
import time
import uuid

from functools import wraps
from typing import Any, Callable, Optional

import celery
import redis

from celery import chain, chord, group
from public import public

from retsu.core import MultiProcess, SingleProcess
from retsu.core import (
MultiProcess,
RandomSemaphoreManager,
SequenceSemaphoreManager,
SingleProcess,
)


class CeleryProcess:
@@ -121,3 +132,70 @@ class SingleCeleryProcess(CeleryProcess, SingleProcess):
"""Single Process for Celery."""

...


def limit_random_concurrent_tasks(
max_concurrent_tasks: int,
redis_client: redis.Redis,
) -> Callable[[Any], Any]:
"""Limit the number of concurrent Celery tasks."""

def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
semaphore_manager = RandomSemaphoreManager(
key=f"celery_task_semaphore_random_{func.__name__}",
max_concurrent_tasks=max_concurrent_tasks,
redis_client=redis_client,
)

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Acquire semaphore slot
acquired = semaphore_manager.acquire()
if not acquired:
logging.info(f"Task {func.__name__} is waiting for a slot...")
while not acquired:
time.sleep(0.01) # Polling interval
acquired = semaphore_manager.acquire()

try:
result = func(*args, **kwargs)
return result
finally:
# Release semaphore slot
semaphore_manager.release()

return wrapper

return decorator


def limit_sequence_concurrent_tasks(
max_concurrent_tasks: int,
redis_client: redis.Redis,
) -> Callable[[Any], Any]:
"""Limit the number of concurrent Celery tasks and maintain FIFO order."""

def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
semaphore_manager = SequenceSemaphoreManager(
key=f"celery_task_semaphore_sequence_{func.__name__}",
max_concurrent_tasks=max_concurrent_tasks,
redis_client=redis_client,
)

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
task_id = str(uuid.uuid4()) # Unique identifier for each task

# Acquire semaphore slot with FIFO order
acquired = semaphore_manager.acquire(task_id)
if acquired:
try:
result = func(*args, **kwargs)
return result
finally:
# Release semaphore slot
semaphore_manager.release()

return wrapper

return decorator
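
A minimal usage sketch of the new decorators, mirroring the pattern used in tests/celery_tasks.py below (the app/broker settings and the task body are illustrative assumptions, not part of this commit):

import redis
from celery import Celery

from retsu.celery import limit_random_concurrent_tasks

app = Celery("example", broker="redis://localhost:6379/0")
redis_client = redis.Redis(host="localhost", port=6379)

@app.task
@limit_random_concurrent_tasks(
    max_concurrent_tasks=2, redis_client=redis_client
)
def crunch(n: int) -> int:
    # At most two instances of this task run at once; additional calls
    # poll until a semaphore slot in Redis is released.
    return n * n

As in the tests, the limiter sits below @app.task so that Celery registers the already-wrapped callable.
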
86 changes: 85 additions & 1 deletion src/retsu/core.py
@@ -4,11 +4,12 @@

import logging
import multiprocessing as mp
import time
import warnings

from abc import abstractmethod
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from uuid import uuid4

import redis
@@ -177,3 +178,86 @@ def stop(self) -> None:

for task_name, process in self.tasks.items():
process.stop()


class RandomSemaphoreManager:
"""Manages a semaphore using Redis to limit concurrent tasks."""

def __init__(
self, key: str, max_concurrent_tasks: int, redis_client: redis.Redis
):
self.key: str = key
self.max_concurrent_tasks: int = max_concurrent_tasks
self.redis_client: redis.Redis = redis_client

def acquire(self) -> bool:
"""Try to acquire a semaphore slot."""
current_count_tmp = self.redis_client.get(self.key)
current_count = 0

if current_count_tmp is None:
self.redis_client.set(self.key, 0)
else:
# note: Argument 1 to "int" has incompatible type
# "Union[Awaitable[Any], Any]"; expected
# "Union[str, Buffer, SupportsInt, SupportsIndex, SupportsTrunc]"
current_count = int(current_count_tmp) # type: ignore

if current_count < self.max_concurrent_tasks:
self.redis_client.incr(self.key)
return True
return False

def release(self) -> None:
"""Release a semaphore slot."""
self.redis_client.decr(self.key)


class SequenceSemaphoreManager:
"""Manages a semaphore using Redis to limit concurrent tasks."""

def __init__(
self, key: str, max_concurrent_tasks: int, redis_client: redis.Redis
):
self.key: str = key
self.max_concurrent_tasks: int = max_concurrent_tasks
self.redis_client: redis.Redis = redis_client

def acquire(self, task_id: str) -> bool:
"""Try to acquire a semaphore slot and ensure FIFO order."""
task_bid = task_id.encode("utf8")
queue_name = f"{self.key}_queue"

# Add task to the queue
self.redis_client.rpush(queue_name, task_id)

while True:
# Get the list of current tasks in the queue
queue_tasks = self.redis_client.lrange(queue_name, 0, -1)
count_tmp = cast(bytes, self.redis_client.get(self.key) or b"0")
current_count = int(count_tmp)

# Check if the task is in the first `max_concurrent_tasks`
# in the queue
# mypy: Item "Awaitable[List[Any]]" of
# "Union[Awaitable[List[Any]], List[Any]]"
# has no attribute "index"
task_position = queue_tasks.index(task_bid) # type: ignore

if (
task_position < self.max_concurrent_tasks
and current_count < self.max_concurrent_tasks
):
# If a slot is available and the task is within the
# allowed concurrent limit
self.redis_client.incr(self.key)
return True

# If no slot is available or task is not in the allowed
# concurrent tasks, keep waiting
time.sleep(0.1)

def release(self) -> None:
"""Release a semaphore slot and remove the task from the queue."""
self.redis_client.decr(self.key)
self.redis_client.lpop(f"{self.key}_queue")
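
The managers can also be used directly; a sketch of the acquire/release contract for RandomSemaphoreManager (the Redis connection settings and key name are illustrative):

import redis

from retsu.core import RandomSemaphoreManager

redis_client = redis.Redis(host="localhost", port=6379)
semaphore = RandomSemaphoreManager(
    key="reports_semaphore",  # illustrative key
    max_concurrent_tasks=2,
    redis_client=redis_client,
)

if semaphore.acquire():  # non-blocking: returns False when both slots are taken
    try:
        pass  # rate-limited work goes here
    finally:
        semaphore.release()  # always decrement the Redis counter
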
43 changes: 40 additions & 3 deletions tests/celery_tasks.py
@@ -4,13 +4,18 @@

import os
import sys
import time

from datetime import datetime
from time import sleep

import redis

from celery import Celery
from retsu.celery import (
limit_random_concurrent_tasks,
limit_sequence_concurrent_tasks,
)

redis_host: str = os.getenv("RETSU_REDIS_HOST", "localhost")
redis_port: int = int(os.getenv("RETSU_REDIS_PORT", 6379))
@@ -33,10 +38,10 @@
worker_task_log_format=(
f"{LOG_FORMAT_PREFIX} %(task_name)s[%(task_id)s]: %(message)s"
),
task_annotations={"*": {"rate_limit": "10/s"}},
# task_annotations={"*": {"rate_limit": "10/s"}},
task_track_started=True,
task_time_limit=30 * 60,
task_soft_time_limit=30 * 60,
# task_time_limit=30 * 60,
# task_soft_time_limit=30 * 60,
worker_redirect_stdouts_level="DEBUG",
)

@@ -68,3 +73,35 @@ def task_sleep(seconds: int, task_id: str) -> int:
"""Sum two numbers, x and y, and sleep the same amount of the sum."""
sleep(seconds)
return int(datetime.now().timestamp())


@app.task # type: ignore
@limit_random_concurrent_tasks(
max_concurrent_tasks=2, redis_client=redis_client
)
def task_random_get_time(
request_id: int, start_time: float
) -> tuple[int, float]:
"""Limit simple task max concurrent."""
print(
f"[Random] Started task {request_id} after:",
time.time() - start_time,
)
sleep(1)
return request_id, time.time()


@app.task # type: ignore
@limit_sequence_concurrent_tasks(
max_concurrent_tasks=2, redis_client=redis_client
)
def task_sequence_get_time(
request_id: int, start_time: float
) -> tuple[int, float]:
"""Limit simple task max concurrent."""
print(
f"[Sequence] Started task {request_id} after:",
time.time() - start_time,
)
sleep(1)
return request_id, time.time()
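
A sketch of how these tasks could be exercised from a client, assuming the broker and a worker are already running (the number of calls and the timeout are illustrative):

import time

from tests.celery_tasks import task_random_get_time

start = time.time()
async_results = [task_random_get_time.delay(i, start) for i in range(6)]
finish_times = sorted(r.get(timeout=60)[1] for r in async_results)
# With max_concurrent_tasks=2 and sleep(1) per task, the six finish times
# should cluster into roughly three one-second waves.
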
58 changes: 26 additions & 32 deletions tests/conftest.py
@@ -2,56 +2,50 @@

from __future__ import annotations

import subprocess
import time
import logging

from typing import Generator
from typing import Any, Generator

import pytest
import redis

from celery.contrib.testing.worker import start_worker
from retsu.queues import get_redis_queue_config

from tests.celery_tasks import app as celery_app


def redis_flush() -> None:
"""Wipe-out redis database."""
logging.info("Wiping-out redis database.")
r = redis.Redis(**get_redis_queue_config()) # type: ignore
r.flushdb()


@pytest.fixture(scope="session")
def celery_worker_parameters() -> dict[str, Any]:
"""Parameters for the Celery worker."""
return {
"loglevel": "debug", # Set log level
"concurrency": 4, # Number of concurrent workers
"perform_ping_check": False,
"pool": "prefork",
}


@pytest.fixture(autouse=True, scope="session")
def setup() -> Generator[None, None, None]:
def setup(
celery_worker_parameters: dict[str, Any],
) -> Generator[None, None, None]:
"""Set up the services needed by the tests."""
try:
# # Run the `sugar build` command
# subprocess.run(["sugar", "build"], check=True)
# # Run the `sugar ext restart --options -d` command
# subprocess.run(
# ["sugar", "ext", "restart", "--options", "-d"], check=True
# )
# # Sleep for 5 seconds
# time.sleep(5)

# Clean Redis queues
logging.info("Clean Redis queues")
redis_flush()

# Start the Celery worker
celery_process = subprocess.Popen(
[
"celery",
"-A",
"tests.celery_tasks",
"worker",
"--loglevel=debug",
],
)

time.sleep(5)

yield
logging.info("Start the Celery worker")
with start_worker(celery_app, **celery_worker_parameters) as worker:
# Ensure worker is up and running
yield worker # Now you can use this worker in your tests

finally:
# Teardown: Terminate the Celery worker
celery_process.terminate()
celery_process.wait()
# subprocess.run(["sugar", "ext", "stop"], check=True)
pass
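
With the session-scoped worker started above, individual tests can call the Celery tasks directly; a hypothetical test sketch (not part of this commit):

from tests.celery_tasks import task_sleep


def test_task_sleep_returns_timestamp() -> None:
    # The autouse `setup` fixture has already started an in-process worker.
    result = task_sleep.delay(1, "task-1")
    assert isinstance(result.get(timeout=30), int)
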