From b2d03499200174c6e48f9cd44010173fe75d1e4b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Sun, 26 Nov 2023 22:53:26 -0800 Subject: [PATCH 1/3] added reference image support via controlnets --- daras_ai_v2/stable_diffusion.py | 6 ++-- recipes/Img2Img.py | 2 +- recipes/QRCodeGenerator.py | 57 ++++++++++++++++++++++++++++++++- scripts/run_all_diffusion.py | 2 +- 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index f325de57e..d799b6353 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -402,7 +402,7 @@ def controlnet( scheduler: str = None, prompt: str, num_outputs: int = 1, - init_image: str, + init_images: list[str] | str, num_inference_steps: int = 50, negative_prompt: str = None, guidance_scale: float = 7.5, @@ -411,6 +411,8 @@ def controlnet( ): if isinstance(selected_controlnet_model, str): selected_controlnet_model = [selected_controlnet_model] + if isinstance(init_images, str): + init_images = [init_images] * len(selected_controlnet_model) prompt = add_prompt_prefix(prompt, selected_model) return call_sd_multi( "diffusion.controlnet", @@ -432,7 +434,7 @@ def controlnet( "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, - "image": [init_image] * len(selected_controlnet_model), + "image": init_images, "controlnet_conditioning_scale": controlnet_conditioning_scale, # "strength": prompt_strength, }, diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 03d4dd668..1f8a7d3ea 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -165,7 +165,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: selected_controlnet_model=request.selected_controlnet_model, prompt=request.text_prompt, num_outputs=request.num_outputs, - init_image=init_image, + init_images=init_image, num_inference_steps=request.quality, negative_prompt=request.negative_prompt, guidance_scale=request.guidance_scale, diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index e14f6fe20..567e19ec0 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -37,6 +37,7 @@ from recipes.EmailFaceInpainting import get_photo_for_email from recipes.SocialLookupEmail import get_profile_for_email from url_shortener.models import ShortenedURL +from daras_ai_v2.enum_selector_widget import enum_multiselect ATTEMPTS = 1 DEFAULT_QR_CODE_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f09c8cfa-5393-11ee-a837-02420a000190/ai%20art%20qr%20codes1%201.png.png" @@ -59,6 +60,12 @@ class QRCodeGeneratorPage(BasePage): obj_scale=0.65, obj_pos_x=0.5, obj_pos_y=0.5, + image_prompt_controlnet_models=[ + ControlNetModels.sd_controlnet_canny.name, + ControlNetModels.sd_controlnet_depth.name, + ControlNetModels.sd_controlnet_tile.name, + ], + inspiration_strength=0.3, ) def __init__(self, *args, **kwargs): @@ -75,6 +82,11 @@ class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None + image_prompt: str | None + image_prompt_controlnet_models: list[ + typing.Literal[tuple(e.name for e in ControlNetModels)], ... + ] | None + inspiration_strength: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None selected_controlnet_model: list[ @@ -129,6 +141,14 @@ def render_form_v2(self): key="text_prompt", placeholder="Bright sunshine coming through the cracks of a wet, cave wall of big rocks", ) + st.file_uploader( + """ + ### 🏞️ Reference Image [optional] + This image will be used as inspiration to blend with the QR Code. + """, + key="image_prompt", + accept=["image/*"], + ) qr_code_source_key = "__qr_code_source" if qr_code_source_key not in st.session_state: @@ -320,6 +340,30 @@ def render_settings(self): color=255, ) + if st.session_state.get("image_prompt"): + st.write("---") + st.write( + """ + ##### 🎨 Inspiration + Use this to control how the image prompt should influence the output. + """, + className="gui-input", + ) + st.slider( + "Inspiration Strength", + min_value=0.0, + max_value=1.0, + step=0.05, + key="inspiration_strength", + ) + enum_multiselect( + ControlNetModels, + label="Control Net Models", + key="image_prompt_controlnet_models", + checkboxes=False, + allow_none=False, + ) + def render_output(self): state = st.session_state self._render_outputs(state) @@ -376,11 +420,22 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["raw_images"] = raw_images = [] yield f"Running {Text2ImgModels[request.selected_model].value}..." + if isinstance(request.selected_controlnet_model, str): + request.selected_controlnet_model = [request.selected_controlnet_model] + init_images = [image] * len(request.selected_controlnet_model) + if request.image_prompt: + init_images += [request.image_prompt] * len( + request.image_prompt_controlnet_models + ) + request.selected_controlnet_model += request.image_prompt_controlnet_models + request.controlnet_conditioning_scale += [ + request.inspiration_strength + ] * len(request.image_prompt_controlnet_models) state["output_images"] = controlnet( selected_model=request.selected_model, selected_controlnet_model=request.selected_controlnet_model, prompt=request.text_prompt, - init_image=image, + init_images=init_images, num_outputs=request.num_outputs, num_inference_steps=request.quality, negative_prompt=request.negative_prompt, diff --git a/scripts/run_all_diffusion.py b/scripts/run_all_diffusion.py index eabb9b1f5..d631883df 100644 --- a/scripts/run_all_diffusion.py +++ b/scripts/run_all_diffusion.py @@ -108,7 +108,7 @@ selected_controlnet_model=controlnet_model.name, prompt=get_random_string(100, string.ascii_letters), num_outputs=4, - init_image=random_img, + init_images=random_img, num_inference_steps=1, guidance_scale=7, ), From 9ae130193b73049bab11571d16afe5936e8f9390 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Wed, 20 Dec 2023 13:33:18 -0800 Subject: [PATCH 2/3] renamed to image_prompt_strength --- recipes/QRCodeGenerator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 567e19ec0..89d5fa3de 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -65,7 +65,7 @@ class QRCodeGeneratorPage(BasePage): ControlNetModels.sd_controlnet_depth.name, ControlNetModels.sd_controlnet_tile.name, ], - inspiration_strength=0.3, + image_prompt_strength=0.3, ) def __init__(self, *args, **kwargs): @@ -86,7 +86,7 @@ class RequestModel(BaseModel): image_prompt_controlnet_models: list[ typing.Literal[tuple(e.name for e in ControlNetModels)], ... ] | None - inspiration_strength: float | None + image_prompt_strength: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None selected_controlnet_model: list[ @@ -354,7 +354,7 @@ def render_settings(self): min_value=0.0, max_value=1.0, step=0.05, - key="inspiration_strength", + key="image_prompt_strength", ) enum_multiselect( ControlNetModels, @@ -429,7 +429,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) request.selected_controlnet_model += request.image_prompt_controlnet_models request.controlnet_conditioning_scale += [ - request.inspiration_strength + request.image_prompt_strength ] * len(request.image_prompt_controlnet_models) state["output_images"] = controlnet( selected_model=request.selected_model, From 59dc9eac3aaf2402812a239f1db0e4485e8d8785 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Wed, 20 Dec 2023 13:57:09 -0800 Subject: [PATCH 3/3] scale and reposition setting --- recipes/QRCodeGenerator.py | 77 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 89d5fa3de..834ded310 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -66,6 +66,9 @@ class QRCodeGeneratorPage(BasePage): ControlNetModels.sd_controlnet_tile.name, ], image_prompt_strength=0.3, + image_prompt_scale=1.0, + image_prompt_pos_x=0.5, + image_prompt_pos_y=0.5, ) def __init__(self, *args, **kwargs): @@ -87,6 +90,9 @@ class RequestModel(BaseModel): typing.Literal[tuple(e.name for e in ControlNetModels)], ... ] | None image_prompt_strength: float | None + image_prompt_scale: float | None + image_prompt_pos_x: float | None + image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None selected_controlnet_model: list[ @@ -363,6 +369,55 @@ def render_settings(self): checkboxes=False, allow_none=False, ) + st.write( + """ + ##### ⌖ Positioning + Use this to control where the reference image is placed, and how big it should be. + """, + className="gui-input", + ) + col1, _ = st.columns(2) + with col1: + image_prompt_scale = st.slider( + "Scale", + min_value=0.1, + max_value=1.0, + step=0.05, + key="image_prompt_scale", + ) + col1, col2 = st.columns(2, responsive=False) + with col1: + image_prompt_pos_x = st.slider( + "Position X", + min_value=0.0, + max_value=1.0, + step=0.05, + key="image_prompt_pos_x", + ) + with col2: + image_prompt_pos_y = st.slider( + "Position Y", + min_value=0.0, + max_value=1.0, + step=0.05, + key="image_prompt_pos_y", + ) + + img_cv2 = mask_cv2 = bytes_to_cv2_img( + requests.get(st.session_state["image_prompt"]).content, + ) + repositioning_preview_widget( + img_cv2=img_cv2, + mask_cv2=mask_cv2, + obj_scale=image_prompt_scale, + pos_x=image_prompt_pos_x, + pos_y=image_prompt_pos_y, + out_size=( + st.session_state["output_width"], + st.session_state["output_height"], + ), + color=255, + ) def render_output(self): state = st.session_state @@ -424,6 +479,28 @@ def run(self, state: dict) -> typing.Iterator[str | None]: request.selected_controlnet_model = [request.selected_controlnet_model] init_images = [image] * len(request.selected_controlnet_model) if request.image_prompt: + if ( + request.image_prompt_scale != 1.0 + or request.image_prompt_pos_x != 0.5 + or request.image_prompt_pos_y != 0.5 + ): + # we only need to reposition if the user moved/scaled the image + image_prompt = bytes_to_cv2_img( + requests.get(request.image_prompt).content + ) + repositioned_image_prompt, _ = reposition_object( + orig_img=image_prompt, + orig_mask=image_prompt, + out_size=(request.output_width, request.output_height), + out_obj_scale=request.image_prompt_scale, + out_pos_x=request.image_prompt_pos_x, + out_pos_y=request.image_prompt_pos_y, + color=255, + ) + request.image_prompt = upload_file_from_bytes( + "repositioned_image_prompt.png", + cv2_img_to_bytes(repositioned_image_prompt), + ) init_images += [request.image_prompt] * len( request.image_prompt_controlnet_models )