Skip to content

Commit

Permalink
Merge pull request #380 from GooeyAI/sadtalker_settings_tweaks
Browse files Browse the repository at this point in the history
Sadtalker Settings Tweaks
  • Loading branch information
devxpy authored Jul 3, 2024
2 parents 15c5fc8 + 1e97905 commit a8b270f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 49 deletions.
42 changes: 33 additions & 9 deletions daras_ai_v2/lipsync_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,44 @@ class LipsyncModel(Enum):

class SadTalkerSettings(BaseModel):
still: bool = Field(
False, title="Still (fewer head motion, works with preprocess 'full')"
True, title="Still (fewer head motion, works with preprocess 'full')"
)
preprocess: typing.Literal["crop", "extcrop", "resize", "full", "extfull"] = Field(
"crop", title="Preprocess"
"resize",
title="Preprocess",
description="SadTalker only generates 512x512 output. 'crop' handles this by cropping the input to 512x512. 'resize' scales down the input to fit 512x512 and scales it back up after lipsyncing (does not work well for full person images, better for portraits). 'full' processes the cropped region and pastes it back into the original input. 'extcrop' and 'extfull' are similar to 'crop' and 'full' but with extended cropping.",
)
pose_style: int = Field(
0,
title="Pose Style",
description="Random seed 0-45 inclusive that affects how the pose is animated.",
)
expression_scale: float = Field(
1.0,
title="Expression Scale",
description="Scale the amount of expression motion. 1.0 is normal, 0.5 is very reduced, and 2.0 is quite a lot.",
)
ref_eyeblink: FieldHttpUrl = Field(
None,
title="Reference Eyeblink",
description="Optional reference video for eyeblinks to make the eyebrow movement more natural.",
)
ref_pose: FieldHttpUrl = Field(
None,
title="Reference Pose",
description="Optional reference video to pose the head.",
)
pose_style: int = Field(0, title="Pose Style")
expression_scale: float = Field(1.0, title="Expression Scale")
ref_eyeblink: FieldHttpUrl = Field(None, title="Reference Eyeblink")
ref_pose: FieldHttpUrl = Field(None, title="Reference Pose")
input_yaw: list[int] = Field(None, title="Input Yaw (comma separated)")
input_pitch: list[int] = Field(None, title="Input Pitch (comma separated)")
input_roll: list[int] = Field(None, title="Input Roll (comma separated)")
# enhancer: typing.Literal["gfpgan", "RestoreFormer"] =None
# background_enhancer: typing.Literal["realesrgan"] =None
input_yaw: list[int] = Field(
None, title="Input Yaw (comma separated)", deprecated=True
)
input_pitch: list[int] = Field(
None, title="Input Pitch (comma separated)", deprecated=True
)
input_roll: list[int] = Field(
None, title="Input Roll (comma separated)", deprecated=True
)


class LipsyncSettings(BaseModel):
Expand Down
44 changes: 4 additions & 40 deletions daras_ai_v2/lipsync_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def sadtalker_settings(settings: SadTalkerSettings):
**field_label_val(settings, "expression_scale"),
)

# st.selectbox("Face Enhancer", [None, "gfpgan", "RestoreFormer"], value=settings.enhancer)
# st.selectbox("Background Enhancer", [None, "realesrgan"], value=settings.background_enhancer)

settings.ref_eyeblink = (
st.file_uploader(
**field_label_val(settings, "ref_eyeblink"),
Expand All @@ -51,45 +48,12 @@ def sadtalker_settings(settings: SadTalkerSettings):
)

settings.ref_pose = (
st.file_uploader("Reference Pose", value=settings.ref_pose, accept=[".mp4"])
or None
)

input_yaw = st.text_input(
"Input Yaw (comma separated)",
value=", ".join(map(str, settings.input_yaw or [])),
)
try:
settings.input_yaw = (
list(map(int, filter(None, input_yaw.strip().split(",")))) or None
)
except ValueError:
settings.input_yaw = None
st.error("Please enter comma separated integers for Input Yaw")

input_pitch = st.text_input(
"Input Pitch (comma separated)",
value=", ".join(map(str, settings.input_pitch or [])),
)
try:
settings.input_pitch = (
list(map(int, filter(None, input_pitch.strip().split(",")))) or None
st.file_uploader(
**field_label_val(settings, "ref_pose"),
accept=[".mp4"],
)
except ValueError:
settings.input_pitch = None
st.error("Please enter comma separated integers for Input Pitch")

input_roll = st.text_input(
"Input Roll (comma separated)",
value=", ".join(map(str, settings.input_roll or [])),
or None
)
try:
settings.input_roll = (
list(map(int, filter(None, input_roll.strip().split(",")))) or None
)
except ValueError:
settings.input_roll = None
st.error("Please enter comma separated integers for Input Roll")


def wav2lip_settings():
Expand Down

0 comments on commit a8b270f

Please sign in to comment.