Skip to content

Commit

Permalink
add sd and hd settings
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Nov 14, 2023
1 parent 15c61f1 commit 33ef6c0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
31 changes: 27 additions & 4 deletions daras_ai_v2/img_model_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def controlnet_weight_setting(
)


def num_outputs_setting(selected_model: str = None):
def num_outputs_setting(selected_models: str | list[str] = None):
col1, col2 = st.columns(2, gap="medium")
with col1:
st.slider(
Expand All @@ -200,12 +200,35 @@ def num_outputs_setting(selected_model: str = None):
"""
)
with col2:
quality_setting(selected_model)
quality_setting(selected_models)


def quality_setting(selected_model=None):
if selected_model in [InpaintingModels.dall_e.name]:
def quality_setting(selected_models=None):
if not isinstance(selected_models, list):
selected_models = [selected_models]
if any(
[
selected_model in [InpaintingModels.dall_e.name]
for selected_model in selected_models
]
):
return
if any(
[
selected_model in [Text2ImgModels.dall_e_3.name]
for selected_model in selected_models
]
):
st.selectbox(
"""##### Quality""",
options=[
"standard, natural",
"hd, natural",
"standard, vivid",
"hd, vivid",
],
key="dalle_3_quality",
)
st.slider(
label="""
##### Quality
Expand Down
6 changes: 5 additions & 1 deletion daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def text2img(
prompt: str,
num_outputs: int,
num_inference_steps: int,
dalle_3_quality: str,
width: int,
height: int,
seed: int = 42,
Expand All @@ -277,11 +278,14 @@ def text2img(

client = OpenAI()
width, height = _get_dalle_3_img_size(width, height)
quality, style = dalle_3_quality.split(", ")
response = client.images.generate(
model=text2img_model_ids[Text2ImgModels[selected_model]],
n=num_outputs,
n=1, # num_outputs, not supported yet
prompt=prompt,
response_format="b64_json",
quality=quality,
style=style,
size=f"{width}x{height}",
)
out_imgs = [b64_img_decode(part.b64_json) for part in response.data]
Expand Down
5 changes: 4 additions & 1 deletion recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class CompareText2ImgPage(BasePage):
"seed": 42,
"sd_2_upscaling": False,
"image_guidance_scale": 1.2,
"dalle_3_quality": "standard, vivid",
}

class RequestModel(BaseModel):
Expand All @@ -51,6 +52,7 @@ class RequestModel(BaseModel):

num_outputs: int | None
quality: int | None
dalle_3_quality: str | None

guidance_scale: float | None
seed: int | None
Expand Down Expand Up @@ -152,7 +154,7 @@ def render_settings(self):

negative_prompt_setting()
output_resolution_setting()
num_outputs_setting()
num_outputs_setting(st.session_state.get("selected_models", []))
sd_2_upscaling_setting()
col1, col2 = st.columns(2)
with col1:
Expand All @@ -178,6 +180,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
prompt=request.text_prompt,
num_outputs=request.num_outputs,
num_inference_steps=request.quality,
dalle_3_quality=request.dalle_3_quality,
width=request.output_width,
height=request.output_height,
guidance_scale=request.guidance_scale,
Expand Down

0 comments on commit 33ef6c0

Please sign in to comment.