diff --git a/bots/tasks.py b/bots/tasks.py index 5dda17e90..a6c3cebe8 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -19,6 +19,7 @@ SlackBot, ) from daras_ai_v2.vector_search import references_as_prompt +from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.VideoBots import ReplyButton @@ -57,7 +58,7 @@ def msg_analysis(msg_id: int): Message.objects.filter(id=msg_id).update(analysis_run=sr) # wait for the result - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) sr.refresh_from_db() # if failed, raise error if sr.error_msg: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 5e1690b31..8d13a6f7d 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -16,6 +16,7 @@ from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params +from gooeysite.bg_db_conn import db_middleware, next_db_safe @app.task @@ -40,6 +41,7 @@ def gui_runner( error_msg = None set_query_params(query_params or {}) + @db_middleware def save(done=False): if done: # clear run status @@ -81,7 +83,7 @@ def save(done=False): start_time = time() try: # advance the generator (to further progress of run()) - yield_val = next(gen) + yield_val = next_db_safe(gen) # increment total time taken after every iteration run_time += time() - start_time continue diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 30acf07a3..ee977b294 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -26,7 +26,7 @@ from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT from daras_ai_v2.vector_search import doc_url_to_file_metadata from gooey_ui.pubsub import realtime_subscribe -from gooeysite.bg_db_conn import db_middleware +from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe from recipes.VideoBots import VideoBotsPage, ReplyButton from routers.api import submit_api_call @@ -392,7 +392,7 @@ def _process_and_send_msg( break # we're done streaming, abort # wait for the celery task to finish - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) # get the final state from db state = page.run_doc_sr(run_id, uid).to_dict() # check for errors diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index 752879f8d..329a33fbf 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -10,6 +10,7 @@ from daras_ai.image_input import storage_blob_for from daras_ai_v2 import settings from daras_ai_v2.exceptions import raise_for_status +from gooeysite.bg_db_conn import get_celery_result_db_safe class GpuEndpoints: @@ -159,7 +160,7 @@ def call_celery_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) s = time() - ret = result.get(disable_sync_subtasks=False) + ret = get_celery_result_db_safe(result) record_cost_auto( model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) ) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 541e41b97..3f6859164 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -2,6 +2,7 @@ from daras_ai_v2.azure_image_moderation import is_image_nsfw from daras_ai_v2.functional import flatten from daras_ai_v2 import settings +from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.CompareLLM import CompareLLMPage @@ -31,7 +32,7 @@ def safety_checker_text(text_input: str): ) # wait for checker - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) sr.refresh_from_db() # if checker failed, raise error if sr.error_msg: diff --git a/gooeysite/bg_db_conn.py b/gooeysite/bg_db_conn.py index 9e6a6d3db..9c7680df9 100644 --- a/gooeysite/bg_db_conn.py +++ b/gooeysite/bg_db_conn.py @@ -3,11 +3,17 @@ from django.db import reset_queries, close_old_connections +if typing.TYPE_CHECKING: + import celery.result + F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any]) def db_middleware(fn: F) -> F: - """Decorator to ensure the `fn` runs safely in a background task with a new database connection.""" + """ + Decorator to ensure the `fn` runs safely in a background task with a new database connection. + Workaround for https://code.djangoproject.com/ticket/24810 + """ @wraps(fn) def wrapper(*args, **kwargs): @@ -19,3 +25,11 @@ def wrapper(*args, **kwargs): close_old_connections() return wrapper + + +next_db_safe = db_middleware(next) + + +@db_middleware +def get_celery_result_db_safe(result: "celery.result.AsyncResult") -> typing.Any: + return result.get(disable_sync_subtasks=False) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index 15710c410..c4ef5a3c0 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -19,6 +19,7 @@ doc_url_to_metadata, download_content_bytes, ) +from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.DocSearch import render_documents DEFAULT_BULK_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d80fd4d8-93fa-11ee-bc13-02420a0001cc/Bulk%20Runner.jpg.png" @@ -301,7 +302,7 @@ def run_v2( result, sr = sr.submit_api_call( current_user=self.request.user, request_body=request_body ) - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) sr.refresh_from_db() run_time = datetime.timedelta( @@ -366,7 +367,7 @@ def run_v2( result, sr = sr.submit_api_call( current_user=self.request.user, request_body=request_body ) - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) sr.refresh_from_db() response.eval_runs.append(sr.get_app_url()) diff --git a/routers/api.py b/routers/api.py index f7db3e8e0..bc9a0303e 100644 --- a/routers/api.py +++ b/routers/api.py @@ -30,6 +30,7 @@ BasePage, StateKeys, ) +from gooeysite.bg_db_conn import get_celery_result_db_safe app = APIRouter() @@ -383,7 +384,7 @@ def build_api_response( } else: # wait for the result - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) state = self.run_doc_sr(run_id, uid).to_dict() # check for errors err_msg = state.get(StateKeys.error_msg)