diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py
index 17f5ea1c1..7e0e35d22 100644
--- a/daras_ai_v2/img_model_settings_widgets.py
+++ b/daras_ai_v2/img_model_settings_widgets.py
@@ -181,7 +181,7 @@ def controlnet_weight_setting(
     )
 
 
-def num_outputs_setting(selected_model: str = None):
+def num_outputs_setting(selected_models: str | list[str] = None):
     col1, col2 = st.columns(2, gap="medium")
     with col1:
         st.slider(
@@ -200,12 +200,35 @@ def num_outputs_setting(selected_model: str = None):
             """
         )
     with col2:
-        quality_setting(selected_model)
+        quality_setting(selected_models)
 
 
-def quality_setting(selected_model=None):
-    if selected_model in [InpaintingModels.dall_e.name]:
+def quality_setting(selected_models=None):
+    if not isinstance(selected_models, list):
+        selected_models = [selected_models]
+    if any(
+        [
+            selected_model in [InpaintingModels.dall_e.name]
+            for selected_model in selected_models
+        ]
+    ):
         return
+    if any(
+        [
+            selected_model in [Text2ImgModels.dall_e_3.name]
+            for selected_model in selected_models
+        ]
+    ):
+        st.selectbox(
+            """##### Quality""",
+            options=[
+                "standard, natural",
+                "hd, natural",
+                "standard, vivid",
+                "hd, vivid",
+            ],
+            key="dalle_3_quality",
+        )
     st.slider(
         label="""
         ##### Quality
diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py
index ee3ce21de..7fd640c2b 100644
--- a/daras_ai_v2/stable_diffusion.py
+++ b/daras_ai_v2/stable_diffusion.py
@@ -261,6 +261,7 @@ def text2img(
     prompt: str,
     num_outputs: int,
     num_inference_steps: int,
+    dalle_3_quality: str,
     width: int,
     height: int,
     seed: int = 42,
@@ -277,11 +278,14 @@ def text2img(
         client = OpenAI()
 
         width, height = _get_dalle_3_img_size(width, height)
+        quality, style = dalle_3_quality.split(", ")
         response = client.images.generate(
             model=text2img_model_ids[Text2ImgModels[selected_model]],
-            n=num_outputs,
+            n=1,  # num_outputs, not supported yet
             prompt=prompt,
             response_format="b64_json",
+            quality=quality,
+            style=style,
             size=f"{width}x{height}",
         )
         out_imgs = [b64_img_decode(part.b64_json) for part in response.data]
diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py
index 05fd37057..9c3ec4124 100644
--- a/recipes/CompareText2Img.py
+++ b/recipes/CompareText2Img.py
@@ -40,6 +40,7 @@ class CompareText2ImgPage(BasePage):
         "seed": 42,
         "sd_2_upscaling": False,
         "image_guidance_scale": 1.2,
+        "dalle_3_quality": "standard, vivid",
     }
 
     class RequestModel(BaseModel):
@@ -51,6 +52,7 @@ class RequestModel(BaseModel):
 
         num_outputs: int | None
         quality: int | None
+        dalle_3_quality: str | None
 
         guidance_scale: float | None
         seed: int | None
@@ -152,7 +154,7 @@ def render_settings(self):
 
         negative_prompt_setting()
         output_resolution_setting()
-        num_outputs_setting()
+        num_outputs_setting(st.session_state.get("selected_models", []))
         sd_2_upscaling_setting()
         col1, col2 = st.columns(2)
         with col1:
@@ -178,6 +180,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
             prompt=request.text_prompt,
             num_outputs=request.num_outputs,
             num_inference_steps=request.quality,
+            dalle_3_quality=request.dalle_3_quality,
             width=request.output_width,
             height=request.output_height,
             guidance_scale=request.guidance_scale,