Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qr v2 #136

Closed
wants to merge 16 commits into from
16 changes: 11 additions & 5 deletions daras_ai_v2/img_model_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def controlnet_settings(
explanations = controlnet_model_explanations | extra_explanations

state_values = st.session_state.get("controlnet_conditioning_scale", [])
if len(models) != len(state_values):
state_values = state_values + [1] * len(models)
state_values = state_values[: len(models)]
new_values = []
st.write(
"""
Expand All @@ -149,17 +152,17 @@ def controlnet_settings(
`{high_explanation.format(high=int(CONTROLNET_CONDITIONING_SCALE_RANGE[1]))}`.
"""
)
for i, model in enumerate(sorted(models)):
for model, value in sorted(zip(models, state_values), key=lambda x: x[0]):
key = f"controlnet_conditioning_scale_{model}"
try:
st.session_state.setdefault(key, state_values[i])
except IndexError:
pass
if st.session_state.get("controlnet_overwrite"):
st.session_state[key] = value
st.session_state.setdefault(key, value)
new_values.append(
controlnet_weight_setting(
selected_controlnet_model=model, explanations=explanations, key=key
),
)
st.session_state["selected_controlnet_model"] = sorted(models)
st.session_state["controlnet_conditioning_scale"] = new_values


Expand All @@ -178,6 +181,9 @@ def controlnet_weight_setting(
min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0],
max_value=CONTROLNET_CONDITIONING_SCALE_RANGE[1],
step=0.05,
default_value_attr="value"
if st.session_state.get("controlnet_overwrite")
else "defaultValue",
)


Expand Down
13 changes: 8 additions & 5 deletions daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class ControlNetModels(Enum):
sd_controlnet_seg = "Image Segmentation"
sd_controlnet_tile = "Tiling"
sd_controlnet_brightness = "Brightness"
sd_controlnet_qrmonster = "QR Monster"


controlnet_model_explanations = {
Expand All @@ -119,6 +120,7 @@ class ControlNetModels(Enum):
ControlNetModels.sd_controlnet_seg: "Image segmentation",
ControlNetModels.sd_controlnet_tile: "Tiling: to preserve small details",
ControlNetModels.sd_controlnet_brightness: "Brightness: to increase contrast naturally",
ControlNetModels.sd_controlnet_qrmonster: "QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose",
}

controlnet_model_ids = {
Expand All @@ -132,6 +134,7 @@ class ControlNetModels(Enum):
ControlNetModels.sd_controlnet_seg: "lllyasviel/sd-controlnet-seg",
ControlNetModels.sd_controlnet_tile: "lllyasviel/control_v11f1e_sd15_tile",
ControlNetModels.sd_controlnet_brightness: "ioclab/control_v1p_sd15_brightness",
ControlNetModels.sd_controlnet_qrmonster: "monster-labs/control_v1p_sd15_qrcode_monster/v2",
}


Expand Down Expand Up @@ -416,7 +419,7 @@ def controlnet(
return call_sd_multi(
"diffusion.controlnet",
pipeline={
"model_id": text2img_model_ids[Text2ImgModels[selected_model]],
"model_id": img2img_model_ids[Img2ImgModels[selected_model]],
"seed": seed,
"scheduler": Schedulers[scheduler].label
if scheduler
Expand All @@ -442,13 +445,13 @@ def controlnet(

def add_prompt_prefix(prompt: str, selected_model: str) -> str:
match selected_model:
case Text2ImgModels.openjourney.name:
case Text2ImgModels.openjourney.name | Img2ImgModels.openjourney.name:
prompt = "mdjrny-v4 style " + prompt
case Text2ImgModels.analog_diffusion.name:
case Text2ImgModels.analog_diffusion.name | Img2ImgModels.analog_diffusion.name:
prompt = "analog style " + prompt
case Text2ImgModels.protogen_5_3.name:
case Text2ImgModels.protogen_5_3.name | Img2ImgModels.protogen_5_3.name:
prompt = "modelshoot style " + prompt
case Text2ImgModels.dreamlike_2.name:
case Text2ImgModels.dreamlike_2.name | Img2ImgModels.dreamlike_2.name:
prompt = "photo, " + prompt
return prompt

Expand Down
2 changes: 2 additions & 0 deletions gooey_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def slider(
help: str = None,
*,
disabled: bool = False,
default_value_attr: str = "defaultValue",
) -> float:
value = _input_widget(
input_type="range",
Expand All @@ -570,6 +571,7 @@ def slider(
min=min_value,
max=max_value,
step=_step_value(min_value, max_value, step),
default_value_attr=default_value_attr,
)
return value or 0

Expand Down
166 changes: 150 additions & 16 deletions recipes/QRCodeGenerator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import numpy as np
from PIL import Image, ImageOps, ImageEnhance
import qrcode
import requests
from django.core.exceptions import ValidationError
Expand All @@ -25,7 +26,6 @@
)
from daras_ai_v2.repositioning import reposition_object, repositioning_preview_widget
from daras_ai_v2.stable_diffusion import (
Text2ImgModels,
controlnet,
ControlNetModels,
Img2ImgModels,
Expand All @@ -35,6 +35,93 @@

ATTEMPTS = 1

PRESETS = {
"Reliable": {
"description": "If you just want something tried and tested, this is our original defaults.",
"state_update": {
"negative_prompt": "ugly, disfigured, low quality, blurry, nsfw, text, words",
"controlnet_conditioning_scale": [0.45, 0.35],
"guidance_scale": 9,
"num_outputs": 2,
"obj_pos_x": 0.5,
"obj_pos_y": 0.5,
"obj_scale": 0.65,
"output_height": 512,
"output_width": 512,
"quality": 70,
"scheduler": Schedulers.euler_ancestral.name,
"selected_controlnet_model": [
ControlNetModels.sd_controlnet_brightness.name,
ControlNetModels.sd_controlnet_tile.name,
],
"selected_model": Img2ImgModels.dream_shaper.name,
},
},
"Creative": {
"description": "Stunning QR Codes with a creative flair that may not always be readable and could end up weird.",
"state_update": {
"negative_prompt": "ugly, disfigured, low quality, blurry, nsfw, text, words, multiple heads",
"controlnet_conditioning_scale": [1.4],
"guidance_scale": 4,
"num_outputs": 4,
"obj_pos_x": 0.5,
"obj_pos_y": 0.5,
"obj_scale": 0.65,
"output_height": 768,
"output_width": 768,
"quality": 40,
"scheduler": Schedulers.euler_ancestral.name,
"selected_controlnet_model": [
ControlNetModels.sd_controlnet_qrmonster.name,
],
"selected_model": Img2ImgModels.dream_shaper.name,
},
},
"Beautiful": {
"description": "The best mix of reliability and creativity. Produces some of the best results for most purposes.",
"state_update": {
"negative_prompt": "ugly, disfigured, low quality, blurry, nsfw, text, words, multiple heads, many",
"controlnet_conditioning_scale": [0.25, 1.4],
"guidance_scale": 9,
"num_outputs": 1,
"obj_pos_x": 0.5,
"obj_pos_y": 0.5,
"obj_scale": 0.65,
"output_height": 768,
"output_width": 768,
"quality": 70,
"scheduler": Schedulers.euler_ancestral.name,
"selected_controlnet_model": [
ControlNetModels.sd_controlnet_brightness.name,
ControlNetModels.sd_controlnet_qrmonster.name,
],
"selected_model": Img2ImgModels.dream_shaper.name,
},
},
"3D": {
"description": "Uses depth information to make the QR Code appear more 3D. This is experimental and may not always work.",
"state_update": {
"negative_prompt": "ugly, disfigured, low quality, blurry, nsfw, text, words, multiple heads, many",
"controlnet_conditioning_scale": [0.35, 0.3, 0.3],
"guidance_scale": 8,
"num_outputs": 1,
"obj_pos_x": 0.5,
"obj_pos_y": 0.5,
"obj_scale": 0.8,
"output_height": 512,
"output_width": 512,
"quality": 100,
"scheduler": Schedulers.euler_ancestral.name,
"selected_controlnet_model": [
ControlNetModels.sd_controlnet_brightness.name,
ControlNetModels.sd_controlnet_depth.name,
ControlNetModels.sd_controlnet_tile.name,
],
"selected_model": Img2ImgModels.dream_shaper.name,
},
},
}


class QRCodeGeneratorPage(BasePage):
title = "AI Art QR Code"
Expand All @@ -45,6 +132,7 @@ class QRCodeGeneratorPage(BasePage):
obj_scale=0.65,
obj_pos_x=0.5,
obj_pos_y=0.5,
color=255,
)

def __init__(self, *args, **kwargs):
Expand All @@ -60,7 +148,7 @@ class RequestModel(BaseModel):
text_prompt: str
negative_prompt: str | None

selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None
selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None
selected_controlnet_model: list[
typing.Literal[tuple(e.name for e in ControlNetModels)], ...
] | None
Expand All @@ -80,6 +168,7 @@ class RequestModel(BaseModel):
obj_scale: float | None
obj_pos_x: float | None
obj_pos_y: float | None
color: int | None

class ResponseModel(BaseModel):
output_images: list[str]
Expand Down Expand Up @@ -212,6 +301,24 @@ def render_settings(self):
"""
)

preset_match = None
st.session_state["controlnet_overwrite"] = False
for preset, preset_value in PRESETS.items():
key = "preset_" + preset
if st.button(preset, key=key, disabled=st.session_state.get(key, False)):
preset_match = preset_value
st.session_state.update(preset_value["state_update"])
st.session_state["controlnet_overwrite"] = True
if not preset_match:
st.button("Custom", key="preset_Custom", disabled=True)
st.caption(
"For the tech savvy and the curious. Here you can play around with settings to create something truly unique."
)
else:
st.caption(
preset_match["description"],
)

img_model_settings(
Img2ImgModels,
show_scheduler=True,
Expand Down Expand Up @@ -265,6 +372,9 @@ def render_settings(self):
img_cv2 = mask_cv2 = np.array(
qrcode.QRCode(border=0).make_image().convert("RGB")
)
color = st.slider(
"`Grayscale background`", min_value=0, max_value=255, key="color"
)
repositioning_preview_widget(
img_cv2=img_cv2,
mask_cv2=mask_cv2,
Expand All @@ -275,7 +385,7 @@ def render_settings(self):
st.session_state["output_width"],
st.session_state["output_height"],
),
color=255,
color=color,
)

def render_output(self):
Expand Down Expand Up @@ -328,7 +438,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:

state["raw_images"] = raw_images = []

yield f"Running {Text2ImgModels[request.selected_model].value}..."
yield f"Running {Img2ImgModels[request.selected_model].value}..."
state["output_images"] = controlnet(
selected_model=request.selected_model,
selected_controlnet_model=request.selected_controlnet_model,
Expand Down Expand Up @@ -364,12 +474,10 @@ def preview_description(self, state: dict) -> str:
"""

def get_raw_price(self, state: dict) -> int:
selected_model = state.get("selected_model", Text2ImgModels.dream_shaper.name)
total = 5
selected_model = state.get("selected_model", Img2ImgModels.dream_shaper.name)
total = 30
match selected_model:
case Text2ImgModels.deepfloyd_if.name:
total += 3
case Text2ImgModels.dall_e.name:
case Img2ImgModels.dall_e.name:
total += 10
return total * state.get("num_outputs", 1)

Expand Down Expand Up @@ -412,7 +520,7 @@ def generate_and_upload_qr_code(
out_obj_scale=request.obj_scale,
out_pos_x=request.obj_pos_x,
out_pos_y=request.obj_pos_y,
color=255,
color=request.color,
)

img_url = upload_file_from_bytes("cleaned_qr.png", cv2_img_to_bytes(img_cv2))
Expand All @@ -432,14 +540,40 @@ def download_qr_code_data(url: str) -> str:
return extract_qr_code_data(img)


def extract_qr_code_data(img: np.ndarray) -> str:
decoded = pyzbar.decode(img)
if not (decoded and decoded[0]):
def extract_qr_code_data(image: np.ndarray) -> str:
# cycle through different sizes and sharpnesses, etc. for a more robust qr code extraction
image = Image.fromarray(image)
found_qr_code = False
found_qr_code_data = False
x, y = image.size
for scalar in [0.1, 0.2, 0.5, 1]:
image_scaled = image.resize((int(round(x * scalar)), int(round(y * scalar))))
for sharpness in [0.1, 0.5, 1, 2]:
image_scaled_sharp = ImageEnhance.Sharpness(image_scaled).enhance(sharpness)
image_autocontrast = ImageOps.autocontrast(image_scaled_sharp)
image_inverted = ImageOps.invert(image_autocontrast)
image_grayscale = ImageOps.grayscale(image_autocontrast)
image_grayscale_inverted = ImageOps.grayscale(image_inverted)
for img in [
image_autocontrast,
image_inverted,
image_grayscale,
image_grayscale_inverted,
]:
decoded = pyzbar.decode(img)
if decoded and decoded[0]:
found_qr_code = True
info = decoded[0].data.decode()
if info:
found_qr_code_data = True
print(f"QR code data: {info}")
return info

if not found_qr_code:
raise InvalidQRCode("No QR code found in image")
info = decoded[0].data.decode()
if not info:

if not found_qr_code_data:
raise InvalidQRCode("No data found in QR code")
return info


class InvalidQRCode(AssertionError):
Expand Down