Skip to content

Commit

Permalink
fix: Add expires to report export child tasks (#19833)
Browse files Browse the repository at this point in the history
  • Loading branch information
webjunkie authored Jan 18, 2024
1 parent 7f174fa commit c869841
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
9 changes: 6 additions & 3 deletions ee/tasks/subscriptions/subscription_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import timedelta
import datetime
from typing import List, Tuple, Union
from django.conf import settings
import structlog
Expand Down Expand Up @@ -56,11 +56,14 @@ def generate_assets(
# Wait for all assets to be exported
tasks = [exporter.export_asset.si(asset.id) for asset in assets]
# run them one after the other, so we don't exhaust celery workers
parallel_job = chain(*tasks).apply_async()
exports_expire = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(
minutes=settings.PARALLEL_ASSET_GENERATION_MAX_TIMEOUT_MINUTES
)
parallel_job = chain(*tasks).apply_async(expires=exports_expire, retry=False)

wait_for_parallel_celery_group(
parallel_job,
max_timeout=timedelta(minutes=settings.PARALLEL_ASSET_GENERATION_MAX_TIMEOUT_MINUTES),
expires=exports_expire,
)

return insights, assets
14 changes: 7 additions & 7 deletions posthog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,20 +1261,20 @@ def sleep_time_generator() -> Generator[float, None, None]:


@async_to_sync
async def wait_for_parallel_celery_group(task: Any, max_timeout: Optional[datetime.timedelta] = None) -> Any:
async def wait_for_parallel_celery_group(task: Any, expires: Optional[datetime.datetime] = None) -> Any:
"""
Wait for a group of celery tasks to finish, but don't wait longer than max_timeout.
For parallel tasks, this is the only way to await the entire group.
"""
if not max_timeout:
max_timeout = datetime.timedelta(minutes=5)
default_expires = datetime.timedelta(minutes=5)

start_time = timezone.now()
if not expires:
expires = datetime.datetime.now(tz=datetime.timezone.utc) + default_expires

sleep_generator = sleep_time_generator()

while not task.ready():
if timezone.now() - start_time > max_timeout:
if datetime.datetime.now(tz=datetime.timezone.utc) > expires:
child_states = []
child: AsyncResult
children = task.children or []
Expand All @@ -1292,13 +1292,13 @@ async def wait_for_parallel_celery_group(task: Any, max_timeout: Optional[dateti

logger.error(
"Timed out waiting for celery task to finish",
task_id=task.id,
ready=task.ready(),
successful=task.successful(),
failed=task.failed(),
task_state=task.state,
child_states=child_states,
timeout=max_timeout,
start_time=start_time,
timeout=expires,
)
raise TimeoutError("Timed out waiting for celery task to finish")

Expand Down

0 comments on commit c869841

Please sign in to comment.