Skip to content

Commit

Permalink
add max_frames setting on sadtalker and wav2lip
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Aug 13, 2024
1 parent 0773b08 commit 1fe2552
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion common/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ def postprocess_new(

records.append(
{
"generated_text": text[prompt_length:],
"usage": {
"completion_tokens": len(sequence) - prompt_tokens,
"prompt_tokens": prompt_tokens,
"total_tokens": len(sequence),
},
"generated_text": text[prompt_length:],
}
)

Expand Down
5 changes: 4 additions & 1 deletion retro/sadtalker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SadtalkerInput(BaseModel):
still: bool = (
False # can crop back to the original videos for the full body aniamtion
)
max_frames: typing.Optional[int] = None


try:
Expand Down Expand Up @@ -291,6 +292,7 @@ def sadtalker(
background_enhancer=inputs.background_enhancer,
preprocess=pipeline.preprocess,
img_size=pipeline.size,
max_frames=inputs.max_frames,
)

with open(result_path, "rb") as f:
Expand All @@ -315,6 +317,7 @@ def animate_from_coeff_generate(
background_enhancer,
preprocess,
img_size,
max_frames,
):

source_image = x["source_image"].type(torch.FloatTensor)
Expand Down Expand Up @@ -392,7 +395,7 @@ def animate_from_coeff_generate(
roll_c_seq,
):
for out_image in batch:
if out_meta.num_frames >= frame_num:
if max_frames and out_meta.num_frames > max_frames:
break
out_image = img_as_ubyte(
out_image.data.cpu().numpy().transpose([1, 2, 0]).astype(np.float32)
Expand Down
3 changes: 3 additions & 0 deletions retro/wav2lip.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Wav2LipInputs(BaseModel):
default=480,
)
batch_size: int = 256
max_frames: typing.Optional[int] = None


@app.task(name="wav2lip")
Expand Down Expand Up @@ -173,6 +174,8 @@ def main(model, detector, outfile: str, inputs: Wav2LipInputs):

mel_chunks = get_mel_chunks(inputs.audio, fps)
for idx in tqdm(range(0, len(mel_chunks), inputs.batch_size)):
if inputs.max_frames and idx >= inputs.max_frames:
break
if is_static:
frame_batch = [frame.copy()] * inputs.batch_size
else:
Expand Down

0 comments on commit 1fe2552

Please sign in to comment.