Skip to content

Commit

Permalink
Fix task_id caching for img2img
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Feb 5, 2024
1 parent 37f37f7 commit 6d2c86e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
4 changes: 4 additions & 0 deletions daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def instruct_pix2pix(
guidance_scale: float,
image_guidance_scale: float,
seed: int = 42,
task_id: str | None = None,
):
return call_sd_multi(
"diffusion.instruct_pix2pix",
Expand All @@ -256,6 +257,7 @@ def instruct_pix2pix(
"image": images,
"image_guidance_scale": image_guidance_scale,
},
task_id=task_id,
)


Expand Down Expand Up @@ -364,6 +366,7 @@ def img2img(
negative_prompt: str = None,
guidance_scale: float,
seed: int = 42,
task_id: str | None = None,
):
prompt_strength = prompt_strength or 0.7
assert 0 <= prompt_strength <= 0.9, "Prompt Strength must be in range [0, 0.9]"
Expand Down Expand Up @@ -409,6 +412,7 @@ def img2img(
"image": [init_image],
"strength": prompt_strength,
},
task_id=task_id,
)
return [
upload_file_from_bytes(f"gooey.ai - {prompt}.png", sd_img_bytes)
Expand Down
1 change: 0 additions & 1 deletion recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import hashlib
import typing

from pydantic import BaseModel
Expand Down
8 changes: 7 additions & 1 deletion recipes/Img2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from daras_ai_v2.crypto import hash_together
from daras_ai_v2.img_model_settings_widgets import img_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.query_params import gooey_get_query_params
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.stable_diffusion import (
InpaintingModels,
Img2ImgModels,
Expand Down Expand Up @@ -155,6 +157,8 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
if not self.request.user.disable_safety_checker:
safety_checker(text=request.text_prompt, image=request.input_image)

_, run_id, uid = extract_query_params(gooey_get_query_params())
task_id = hash_together(run_id, uid) if run_id and uid else None
if request.selected_model == Img2ImgModels.instruct_pix2pix.name:
state["output_images"] = instruct_pix2pix(
prompt=request.text_prompt,
Expand All @@ -165,6 +169,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
seed=request.seed,
images=[init_image],
image_guidance_scale=request.image_guidance_scale,
task_id=task_id,
)
elif request.selected_controlnet_model:
state["output_images"] = controlnet(
Expand All @@ -178,6 +183,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
guidance_scale=request.guidance_scale,
seed=request.seed,
controlnet_conditioning_scale=request.controlnet_conditioning_scale,
task_id=task_id,
)
else:
state["output_images"] = img2img(
Expand All @@ -191,7 +197,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
negative_prompt=request.negative_prompt,
guidance_scale=request.guidance_scale,
seed=request.seed,
task_id=hash_together(run_id, uid) if run_id and uid else None,
task_id=task_id,
)

def preview_description(self, state: dict) -> str:
Expand Down

0 comments on commit 6d2c86e

Please sign in to comment.