From 1fe2552e36bf296a999d55c59a5186d75201008d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 14 Aug 2024 01:10:45 +0530 Subject: [PATCH] add max_frames setting on sadtalker and wav2lip --- common/llms.py | 2 +- retro/sadtalker.py | 5 ++++- retro/wav2lip.py | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/common/llms.py b/common/llms.py index 49ec697..62f341e 100644 --- a/common/llms.py +++ b/common/llms.py @@ -107,12 +107,12 @@ def postprocess_new( records.append( { - "generated_text": text[prompt_length:], "usage": { "completion_tokens": len(sequence) - prompt_tokens, "prompt_tokens": prompt_tokens, "total_tokens": len(sequence), }, + "generated_text": text[prompt_length:], } ) diff --git a/retro/sadtalker.py b/retro/sadtalker.py index e725bdd..bba2cba 100644 --- a/retro/sadtalker.py +++ b/retro/sadtalker.py @@ -83,6 +83,7 @@ class SadtalkerInput(BaseModel): still: bool = ( False # can crop back to the original videos for the full body aniamtion ) + max_frames: typing.Optional[int] = None try: @@ -291,6 +292,7 @@ def sadtalker( background_enhancer=inputs.background_enhancer, preprocess=pipeline.preprocess, img_size=pipeline.size, + max_frames=inputs.max_frames, ) with open(result_path, "rb") as f: @@ -315,6 +317,7 @@ def animate_from_coeff_generate( background_enhancer, preprocess, img_size, + max_frames, ): source_image = x["source_image"].type(torch.FloatTensor) @@ -392,7 +395,7 @@ def animate_from_coeff_generate( roll_c_seq, ): for out_image in batch: - if out_meta.num_frames >= frame_num: + if max_frames and out_meta.num_frames > max_frames: break out_image = img_as_ubyte( out_image.data.cpu().numpy().transpose([1, 2, 0]).astype(np.float32) diff --git a/retro/wav2lip.py b/retro/wav2lip.py index 60afa49..eda45b5 100644 --- a/retro/wav2lip.py +++ b/retro/wav2lip.py @@ -82,6 +82,7 @@ class Wav2LipInputs(BaseModel): default=480, ) batch_size: int = 256 + max_frames: typing.Optional[int] = None @app.task(name="wav2lip") @@ -173,6 +174,8 @@ def main(model, detector, outfile: str, inputs: Wav2LipInputs): mel_chunks = get_mel_chunks(inputs.audio, fps) for idx in tqdm(range(0, len(mel_chunks), inputs.batch_size)): + if inputs.max_frames and idx >= inputs.max_frames: + break if is_static: frame_batch = [frame.copy()] * inputs.batch_size else: