Skip to content

Commit

Permalink
Set unique task ID for GPU tasks -- img2img, qr-code, and deforum
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Feb 5, 2024
1 parent 4783f7d commit 37f37f7
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 15 deletions.
4 changes: 4 additions & 0 deletions daras_ai_v2/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,7 @@ def get_random_api_key() -> str:

def get_random_string(length: int, allowed_chars: str) -> str:
return "".join(secrets.choice(allowed_chars) for _ in range(length))


def hash_together(*args, **kwargs) -> str:
return hashlib.md5(f"{args}{kwargs}".encode()).hexdigest()
74 changes: 61 additions & 13 deletions daras_ai_v2/gpu_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import datetime
import os
import typing
import uuid

import requests
from furl import furl

from daras_ai.image_input import storage_blob_for
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.redis_cache import redis_cache_decorator


class GpuEndpoints:
Expand Down Expand Up @@ -41,6 +43,7 @@ def call_sd_multi(
endpoint: str,
pipeline: dict,
inputs: dict,
task_id: str | None = None,
) -> typing.List[str]:
prompt = inputs["prompt"]
num_images_per_prompt = inputs["num_images_per_prompt"]
Expand All @@ -54,6 +57,7 @@ def call_sd_multi(
content_type="image/png",
filename=f"gooey.ai - {prompt}.png",
num_outputs=num_outputs,
task_id=task_id,
)

# deepfloyd
Expand Down Expand Up @@ -86,7 +90,7 @@ def call_gooey_gpu(
pipeline["upload_urls"] = [
blob.generate_signed_url(
version="v4",
# This URL is valid for 15 minutes
# This URL is valid for 12 hours
expiration=datetime.timedelta(hours=12),
# Allow PUT requests using this URL.
method="PUT",
Expand All @@ -102,17 +106,13 @@ def call_gooey_gpu(
return [blob.public_url for blob in blobs]


def call_celery_task_outfile(
task_name: str,
*,
pipeline: dict,
inputs: dict,
content_type: str,
def create_storage_blob_urls(
filename: str,
num_outputs: int = 1,
):
content_type: str,
num_outputs: int,
) -> tuple[list[str], list[str]]:
blobs = [storage_blob_for(filename) for i in range(num_outputs)]
pipeline["upload_urls"] = [
signed_urls = [
blob.generate_signed_url(
version="v4",
# This URL is valid for 15 minutes
Expand All @@ -123,8 +123,50 @@ def call_celery_task_outfile(
)
for blob in blobs
]
call_celery_task(task_name, pipeline=pipeline, inputs=inputs)
return [blob.public_url for blob in blobs]
public_urls = [blob.public_url for blob in blobs]
return signed_urls, public_urls


@redis_cache_decorator
def get_or_create_storage_blob_urls(
task_id: str,
filename: str,
content_type: str,
num_outputs: int,
) -> tuple[list[str], list[str]]:
# task_id has no purpose other than to serve as the caching key
assert task_id

# cache decorator makes it fetch from cache if it exists
return create_storage_blob_urls(filename, content_type, num_outputs)


def call_celery_task_outfile(
task_name: str,
*,
pipeline: dict,
inputs: dict,
content_type: str,
filename: str,
num_outputs: int = 1,
task_id: str | None = None,
):
if task_id:
signed_urls, public_urls = get_or_create_storage_blob_urls(
task_id=task_id,
filename=filename,
content_type=content_type,
num_outputs=num_outputs,
)
else:
signed_urls, public_urls = create_storage_blob_urls(
filename=filename,
content_type=content_type,
num_outputs=num_outputs,
)
pipeline["upload_urls"] = signed_urls
call_celery_task(task_name, pipeline=pipeline, inputs=inputs, task_id=task_id)
return public_urls


_app = None
Expand All @@ -148,9 +190,15 @@ def call_celery_task(
pipeline: dict,
inputs: dict,
queue_prefix: str = "gooey-gpu",
task_id: str | None = None,
):
queue = os.path.join(queue_prefix, pipeline["model_id"].strip()).strip("/")
task_id = task_id or str(uuid.uuid4())
result = get_celery().send_task(
task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue
task_name,
kwargs=dict(pipeline=pipeline, inputs=inputs),
queue=queue,
task_id=task_id,
)
print(f"{task_id=} {queue=} {result=}")
return result.get(disable_sync_subtasks=False)
2 changes: 2 additions & 0 deletions daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def controlnet(
guidance_scale: float = 7.5,
seed: int = 42,
controlnet_conditioning_scale: typing.List[float] | float = 1.0,
task_id: str | None = None,
):
if isinstance(selected_controlnet_model, str):
selected_controlnet_model = [selected_controlnet_model]
Expand Down Expand Up @@ -459,6 +460,7 @@ def controlnet(
"controlnet_conditioning_scale": controlnet_conditioning_scale,
# "strength": prompt_strength,
},
task_id=task_id,
)


Expand Down
1 change: 1 addition & 0 deletions recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import typing

from pydantic import BaseModel
Expand Down
7 changes: 5 additions & 2 deletions recipes/DeforumSD.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
import uuid
from datetime import datetime, timedelta

from django.db.models import TextChoices
from pydantic import BaseModel
Expand All @@ -9,11 +8,13 @@
import gooey_ui as st
from bots.models import Workflow
from daras_ai_v2.base import BasePage
from daras_ai_v2.crypto import hash_together
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.gpu_server import call_celery_task_outfile
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.query_params import gooey_get_query_params
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.safety_checker import safety_checker
from daras_ai_v2.tabs_widget import MenuTabs

DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png"

Expand Down Expand Up @@ -455,6 +456,7 @@ def run(self, state: dict):
if not self.request.user.disable_safety_checker:
safety_checker(text=self.preview_input(state))

_, run_id, uid = extract_query_params(gooey_get_query_params())
state["output_video"] = call_celery_task_outfile(
"deforum",
pipeline=dict(
Expand All @@ -478,4 +480,5 @@ def run(self, state: dict):
),
content_type="video/mp4",
filename=f"gooey.ai animation {request.animation_prompts}.mp4",
task_id=hash_together(run_id, uid) if run_id and uid else None,
)[0]
2 changes: 2 additions & 0 deletions recipes/Img2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gooey_ui as st
from bots.models import Workflow
from daras_ai_v2.base import BasePage
from daras_ai_v2.crypto import hash_together
from daras_ai_v2.img_model_settings_widgets import img_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.stable_diffusion import (
Expand Down Expand Up @@ -190,6 +191,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
negative_prompt=request.negative_prompt,
guidance_scale=request.guidance_scale,
seed=request.seed,
task_id=hash_together(run_id, uid) if run_id and uid else None,
)

def preview_description(self, state: dict) -> str:
Expand Down
6 changes: 6 additions & 0 deletions recipes/QRCodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
cv2_img_to_bytes,
)
from daras_ai_v2.base import BasePage
from daras_ai_v2.crypto import hash_together
from daras_ai_v2.descriptions import prompting101
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.img_model_settings_widgets import (
output_resolution_setting,
img_model_settings,
)
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.query_params import gooey_get_query_params
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.repositioning import reposition_object, repositioning_preview_widget
from daras_ai_v2.stable_diffusion import (
Text2ImgModels,
Expand Down Expand Up @@ -498,6 +501,8 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
request.controlnet_conditioning_scale += [
request.image_prompt_strength
] * len(request.image_prompt_controlnet_models)

_, run_id, uid = extract_query_params(gooey_get_query_params())
state["output_images"] = controlnet(
selected_model=request.selected_model,
selected_controlnet_model=request.selected_controlnet_model,
Expand All @@ -510,6 +515,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
seed=request.seed,
controlnet_conditioning_scale=request.controlnet_conditioning_scale,
scheduler=request.scheduler,
task_id=hash_together(run_id, uid) if run_id and uid else None,
)

# TODO: properly detect bad qr code
Expand Down

0 comments on commit 37f37f7

Please sign in to comment.