Skip to content

Commit

Permalink
Merge branch 'master' into Remix-Share
Browse files Browse the repository at this point in the history
  • Loading branch information
clr-li committed Feb 15, 2024
2 parents b4ad428 + 4b3bbc9 commit f56c91e
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 10 deletions.
3 changes: 2 additions & 1 deletion bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SlackBot,
)
from daras_ai_v2.vector_search import references_as_prompt
from gooeysite.bg_db_conn import get_celery_result_db_safe
from recipes.VideoBots import ReplyButton


Expand Down Expand Up @@ -57,7 +58,7 @@ def msg_analysis(msg_id: int):
Message.objects.filter(id=msg_id).update(analysis_run=sr)

# wait for the result
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
sr.refresh_from_db()
# if failed, raise error
if sr.error_msg:
Expand Down
4 changes: 3 additions & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware, next_db_safe


@app.task
Expand All @@ -40,6 +41,7 @@ def gui_runner(
error_msg = None
set_query_params(query_params or {})

@db_middleware
def save(done=False):
if done:
# clear run status
Expand Down Expand Up @@ -81,7 +83,7 @@ def save(done=False):
start_time = time()
try:
# advance the generator (to further progress of run())
yield_val = next(gen)
yield_val = next_db_safe(gen)
# increment total time taken after every iteration
run_time += time() - start_time
continue
Expand Down
4 changes: 2 additions & 2 deletions daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT
from daras_ai_v2.vector_search import doc_url_to_file_metadata
from gooey_ui.pubsub import realtime_subscribe
from gooeysite.bg_db_conn import db_middleware
from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe
from recipes.VideoBots import VideoBotsPage, ReplyButton
from routers.api import submit_api_call

Expand Down Expand Up @@ -392,7 +392,7 @@ def _process_and_send_msg(
break # we're done streaming, abort

# wait for the celery task to finish
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
# get the final state from db
state = page.run_doc_sr(run_id, uid).to_dict()
# check for errors
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/gpu_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from daras_ai.image_input import storage_blob_for
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import raise_for_status
from gooeysite.bg_db_conn import get_celery_result_db_safe


class GpuEndpoints:
Expand Down Expand Up @@ -159,7 +160,7 @@ def call_celery_task(
task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue
)
s = time()
ret = result.get(disable_sync_subtasks=False)
ret = get_celery_result_db_safe(result)
record_cost_auto(
model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000)
)
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from daras_ai_v2.azure_image_moderation import is_image_nsfw
from daras_ai_v2.functional import flatten
from daras_ai_v2 import settings
from gooeysite.bg_db_conn import get_celery_result_db_safe
from recipes.CompareLLM import CompareLLMPage


Expand Down Expand Up @@ -31,7 +32,7 @@ def safety_checker_text(text_input: str):
)

# wait for checker
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
sr.refresh_from_db()
# if checker failed, raise error
if sr.error_msg:
Expand Down
16 changes: 15 additions & 1 deletion gooeysite/bg_db_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@

from django.db import reset_queries, close_old_connections

if typing.TYPE_CHECKING:
import celery.result

F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any])


def db_middleware(fn: F) -> F:
"""Decorator to ensure the `fn` runs safely in a background task with a new database connection."""
"""
Decorator to ensure the `fn` runs safely in a background task with a new database connection.
Workaround for https://code.djangoproject.com/ticket/24810
"""

@wraps(fn)
def wrapper(*args, **kwargs):
Expand All @@ -19,3 +25,11 @@ def wrapper(*args, **kwargs):
close_old_connections()

return wrapper


next_db_safe = db_middleware(next)


@db_middleware
def get_celery_result_db_safe(result: "celery.result.AsyncResult") -> typing.Any:
return result.get(disable_sync_subtasks=False)
5 changes: 3 additions & 2 deletions recipes/BulkRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
doc_url_to_metadata,
download_content_bytes,
)
from gooeysite.bg_db_conn import get_celery_result_db_safe
from recipes.DocSearch import render_documents

DEFAULT_BULK_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d80fd4d8-93fa-11ee-bc13-02420a0001cc/Bulk%20Runner.jpg.png"
Expand Down Expand Up @@ -301,7 +302,7 @@ def run_v2(
result, sr = sr.submit_api_call(
current_user=self.request.user, request_body=request_body
)
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
sr.refresh_from_db()

run_time = datetime.timedelta(
Expand Down Expand Up @@ -366,7 +367,7 @@ def run_v2(
result, sr = sr.submit_api_call(
current_user=self.request.user, request_body=request_body
)
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
sr.refresh_from_db()
response.eval_runs.append(sr.get_app_url())

Expand Down
3 changes: 2 additions & 1 deletion routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
BasePage,
StateKeys,
)
from gooeysite.bg_db_conn import get_celery_result_db_safe

app = APIRouter()

Expand Down Expand Up @@ -383,7 +384,7 @@ def build_api_response(
}
else:
# wait for the result
result.get(disable_sync_subtasks=False)
get_celery_result_db_safe(result)
state = self.run_doc_sr(run_id, uid).to_dict()
# check for errors
err_msg = state.get(StateKeys.error_msg)
Expand Down

0 comments on commit f56c91e

Please sign in to comment.