Skip to content

Commit

Permalink
added non square resolutions for dalle3
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Nov 14, 2023
1 parent b849a6f commit 15c61f1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
20 changes: 18 additions & 2 deletions daras_ai_v2/img_model_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ def quality_setting(selected_model=None):
)


RESOLUTIONS = {
RESOLUTIONS: dict[int, dict[str, str]] = {
256: {
"256, 256": "square",
},
512: {
"512, 512": "square",
"576, 448": "A4",
Expand All @@ -247,6 +250,7 @@ def quality_setting(selected_model=None):
"1536, 512": "smartphone",
"1792, 512": "cinema",
"2048, 512": "panorama",
"1792, 1024": "wide",
},
}
LANDSCAPE = "Landscape"
Expand Down Expand Up @@ -281,12 +285,19 @@ def output_resolution_setting():
st.session_state.get("selected_model", st.session_state.get("selected_models"))
or ""
)
allowed_shapes = RESOLUTIONS[st.session_state["__pixels"]].values()
if not isinstance(selected_models, list):
selected_models = [selected_models]
if "jack_qiao" in selected_models or "sd_1_4" in selected_models:
pixel_options = [512]
elif selected_models == ["deepfloyd_if"]:
pixel_options = [1024]
elif selected_models == ["dall_e"]:
pixel_options = [256, 512, 1024]
allowed_shapes = ["square"]
elif selected_models == ["dall_e_3"]:
pixel_options = [1024]
allowed_shapes = ["square", "wide"]
else:
pixel_options = [512, 768]

Expand All @@ -298,11 +309,16 @@ def output_resolution_setting():
options=pixel_options,
)
with col2:
res_options = [
key
for key, val in RESOLUTIONS[pixels or pixel_options[0]].items()
if val in allowed_shapes
]
res = st.selectbox(
"##### Resolution",
key="__res",
format_func=lambda r: f"{r.split(', ')[0]} x {r.split(', ')[1]} ({RESOLUTIONS[pixels][r]})",
options=list(RESOLUTIONS[pixels].keys()),
options=res_options,
)
res = tuple(map(int, res.split(", ")))

Expand Down
14 changes: 13 additions & 1 deletion daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,21 @@ def text2img(
negative_prompt: str = None,
scheduler: str = None,
):
_resolution_check(width, height, max_size=(1024, 1024))
if selected_model != Text2ImgModels.dall_e_3.name:
_resolution_check(width, height, max_size=(1024, 1024))

match selected_model:
case Text2ImgModels.dall_e_3.name:
from openai import OpenAI

client = OpenAI()
width, height = _get_dalle_3_img_size(width, height)
response = client.images.generate(
model=text2img_model_ids[Text2ImgModels[selected_model]],
n=num_outputs,
prompt=prompt,
response_format="b64_json",
size=f"{width}x{height}",
)
out_imgs = [b64_img_decode(part.b64_json) for part in response.data]
case Text2ImgModels.dall_e.name:
Expand Down Expand Up @@ -332,6 +335,15 @@ def _get_dalle_img_size(width: int, height: int) -> int:
return edge


def _get_dalle_3_img_size(width: int, height: int) -> tuple[int, int]:
if height == width:
return 1024, 1024
elif width < height:
return 1024, 1792
else:
return 1792, 1024


def img2img(
*,
selected_model: str,
Expand Down

0 comments on commit 15c61f1

Please sign in to comment.