From 5a175e15ba641a1843380e9a56ebc82e0cd1a3f2 Mon Sep 17 00:00:00 2001 From: Justin John <34035011+justinjohn0306@users.noreply.github.com> Date: Thu, 7 Sep 2023 20:30:19 +0530 Subject: [PATCH] Add files via upload --- README.md | 18 +- face_detection/detection/sfd/sfd_detector.py | 13 +- inference.py | 441 ++++++++++--------- 3 files changed, 258 insertions(+), 214 deletions(-) diff --git a/README.md b/README.md index 76e66198e..35f69daff 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to -------- **Disclaimer** -------- -All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should only be used for research/academic/personal purposes only. As the models are trained on the LRS2 dataset, any form of commercial use is strictly prohibited. For commercial requests please contact us directly! +All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should only be used for research/academic/personal purposes only. As the models are trained on the LRS2 dataset, any form of commercial use is strictly prohibhited. For commercial requests please contact us directly! Prerequisites ------------- @@ -55,8 +55,8 @@ The result is saved (by default) in `results/result_voice.mp4`. You can specify ##### Tips for better results: - Experiment with the `--pads` argument to adjust the detected face bounding box. Often leads to improved results. You might need to increase the bottom padding to include the chin region. E.g. `--pads 0 20 0 0`. -- If you see the mouth position dislocated or some weird artifacts such as two mouths, then it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try. -- Experiment with the `--resize_factor` argument, to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too). +- If you see the mouth position dislocated or some weird artifacts such as two mouths, then it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give another try. +- Experiment with the `--resize_factor` argument, to get a lower resolution video. Why? The models are trained on faces which were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too). - The Wav2Lip model without GAN usually needs more experimenting with the above two to get the most ideal results, and sometimes, can give you a better result as well. Preparing LRS2 for training @@ -78,7 +78,7 @@ Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` fol ```bash python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/ ``` -Additional options like `batch_size` and the number of GPUs to use in parallel to use can also be set. +Additional options like `batch_size` and number of GPUs to use in parallel to use can also be set. ##### Preprocessed LRS2 folder structure ``` @@ -99,12 +99,12 @@ You can download [the pre-trained weights](#getting-the-weights) if you want to python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir ``` ##### Training the Wav2Lip models -You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run: +You can either train the model without the additional visual quality disriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run: ```bash python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir --syncnet_checkpoint_path ``` -To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file. +To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both the files are similar. In both the cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file. Training on datasets other than LRS2 ------------------------------------ @@ -126,7 +126,7 @@ Please check the `evaluation/` folder for the instructions. License and Citation ---------- -This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository: +Theis repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository: ``` @inproceedings{10.1145/3394171.3413532, author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.}, @@ -147,6 +147,6 @@ series = {MM '20} ``` -Acknowledgments +Acknowledgements ---------- -Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial collab notebook. +Parts of the code structure is inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial collab notebook. diff --git a/face_detection/detection/sfd/sfd_detector.py b/face_detection/detection/sfd/sfd_detector.py index 8fbce1525..d1776e4bf 100644 --- a/face_detection/detection/sfd/sfd_detector.py +++ b/face_detection/detection/sfd/sfd_detector.py @@ -14,8 +14,9 @@ class SFDDetector(FaceDetector): - def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): - super(SFDDetector, self).__init__(device, verbose) + @classmethod + def load_model(cls, device): + path_to_detector = os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth') # Initialise the face detector if not os.path.isfile(path_to_detector): @@ -23,10 +24,10 @@ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path else: model_weights = torch.load(path_to_detector) - self.face_detector = s3fd() - self.face_detector.load_state_dict(model_weights) - self.face_detector.to(device) - self.face_detector.eval() + cls.face_detector = s3fd() + cls.face_detector.load_state_dict(model_weights) + cls.face_detector.to(device) + cls.face_detector.eval() def detect_from_image(self, tensor_or_path): image = self.tensor_or_path_to_ndarray(tensor_or_path) diff --git a/inference.py b/inference.py index 90692521e..bb1a61cd4 100644 --- a/inference.py +++ b/inference.py @@ -1,280 +1,323 @@ -from os import listdir, path +import argparse +import math +import os +import platform +import subprocess + +import cv2 import numpy as np -import scipy, cv2, os, sys, argparse, audio -import json, subprocess, random, string +import torch from tqdm import tqdm -from glob import glob -import torch, face_detection + +import audio +# from face_detect import face_rect from models import Wav2Lip -import platform + +from batch_face import RetinaFace +from time import time parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models') parser.add_argument('--checkpoint_path', type=str, - help='Name of saved checkpoint to load weights from', required=True) + help='Name of saved checkpoint to load weights from', required=True) parser.add_argument('--face', type=str, - help='Filepath of video/image that contains faces to use', required=True) + help='Filepath of video/image that contains faces to use', required=True) parser.add_argument('--audio', type=str, - help='Filepath of video/audio file to use as raw audio source', required=True) + help='Filepath of video/audio file to use as raw audio source', required=True) parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', - default='results/result_voice.mp4') + default='results/result_voice.mp4') parser.add_argument('--static', type=bool, - help='If True, then use only first video frame for inference', default=False) + help='If True, then use only first video frame for inference', default=False) parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', - default=25., required=False) + default=25., required=False) parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], - help='Padding (top, bottom, left, right). Please adjust to include chin at least') + help='Padding (top, bottom, left, right). Please adjust to include chin at least') -parser.add_argument('--face_det_batch_size', type=int, - help='Batch size for face detection', default=16) parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128) -parser.add_argument('--resize_factor', default=1, type=int, - help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') +# parser.add_argument('--resize_factor', default=1, type=int, +# help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') + +parser.add_argument('--out_height', default=480, type=int, + help='Output video height. Best results are obtained at 480 or 720') -parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], - help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' - 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') +parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], + help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' + 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], - help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' - 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') + help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' + 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') parser.add_argument('--rotate', default=False, action='store_true', - help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' - 'Use if you get a flipped result, despite feeding a normal looking video') + help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' + 'Use if you get a flipped result, despite feeding a normal looking video') parser.add_argument('--nosmooth', default=False, action='store_true', - help='Prevent smoothing face detections over a short temporal window') + help='Prevent smoothing face detections over a short temporal window') -args = parser.parse_args() -args.img_size = 96 - -if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: - args.static = True def get_smoothened_boxes(boxes, T): - for i in range(len(boxes)): - if i + T > len(boxes): - window = boxes[len(boxes) - T:] - else: - window = boxes[i : i + T] - boxes[i] = np.mean(window, axis=0) - return boxes + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes def face_detect(images): - detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, - flip_input=False, device=device) - - batch_size = args.face_det_batch_size - - while 1: - predictions = [] - try: - for i in tqdm(range(0, len(images), batch_size)): - predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) - except RuntimeError: - if batch_size == 1: - raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument') - batch_size //= 2 - print('Recovering from OOM error; New batch size: {}'.format(batch_size)) - continue - break - - results = [] - pady1, pady2, padx1, padx2 = args.pads - for rect, image in zip(predictions, images): - if rect is None: - cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. - raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') - - y1 = max(0, rect[1] - pady1) - y2 = min(image.shape[0], rect[3] + pady2) - x1 = max(0, rect[0] - padx1) - x2 = min(image.shape[1], rect[2] + padx2) - - results.append([x1, y1, x2, y2]) - - boxes = np.array(results) - if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5) - results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] - - del detector - return results + results = [] + pady1, pady2, padx1, padx2 = args.pads + + s = time() + + for image, rect in zip(images, face_rect(images)): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + print('face detect time:', time() - s) + + boxes = np.array(results) + if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + return results + def datagen(frames, mels): - img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] - if args.box[0] == -1: - if not args.static: - face_det_results = face_detect(frames) # BGR2RGB for CNN face detection - else: - face_det_results = face_detect([frames[0]]) - else: - print('Using the specified bounding box instead of face detection...') - y1, y2, x1, x2 = args.box - face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + if args.box[0] == -1: + if not args.static: + face_det_results = face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = args.box + face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] - for i, m in enumerate(mels): - idx = 0 if args.static else i%len(frames) - frame_to_save = frames[idx].copy() - face, coords = face_det_results[idx].copy() + for i, m in enumerate(mels): + idx = 0 if args.static else i%len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() - face = cv2.resize(face, (args.img_size, args.img_size)) - - img_batch.append(face) - mel_batch.append(m) - frame_batch.append(frame_to_save) - coords_batch.append(coords) + face = cv2.resize(face, (args.img_size, args.img_size)) - if len(img_batch) >= args.wav2lip_batch_size: - img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) - img_masked = img_batch.copy() - img_masked[:, args.img_size//2:] = 0 + if len(img_batch) >= args.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) - img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. - mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 - yield img_batch, mel_batch, frame_batch, coords_batch - img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) - if len(img_batch) > 0: - img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] - img_masked = img_batch.copy() - img_masked[:, args.img_size//2:] = 0 + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) - img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. - mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 - yield img_batch, mel_batch, frame_batch, coords_batch + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch mel_step_size = 16 device = 'cuda' if torch.cuda.is_available() else 'cpu' print('Using {} for inference.'.format(device)) def _load(checkpoint_path): - if device == 'cuda': - checkpoint = torch.load(checkpoint_path) - else: - checkpoint = torch.load(checkpoint_path, - map_location=lambda storage, loc: storage) - return checkpoint + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint def load_model(path): - model = Wav2Lip() - print("Load checkpoint from: {}".format(path)) - checkpoint = _load(path) - s = checkpoint["state_dict"] - new_s = {} - for k, v in s.items(): - new_s[k.replace('module.', '')] = v - model.load_state_dict(new_s) - - model = model.to(device) - return model.eval() + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() def main(): - if not os.path.isfile(args.face): - raise ValueError('--face argument must be a valid path to video/image file') + args.img_size = 96 + + if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: + args.static = True + + if not os.path.isfile(args.face): + raise ValueError('--face argument must be a valid path to video/image file') + + elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(args.face)] + fps = args.fps + + else: + video_stream = cv2.VideoCapture(args.face) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + + aspect_ratio = frame.shape[1] / frame.shape[0] + frame = cv2.resize(frame, (int(args.out_height * aspect_ratio), args.out_height)) + # if args.resize_factor > 1: + # frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor)) + + if args.rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = args.crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print ("Number of frames available for inference: "+str(len(full_frames))) + + if not args.audio.endswith('.wav'): + print('Extracting raw audio...') + # command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav') + # subprocess.call(command, shell=True) + subprocess.check_call([ + "ffmpeg", "-y", + "-i", args.audio, + "temp/temp.wav", + ]) + args.audio = 'temp/temp.wav' + + wav = audio.load_wav(args.audio, 16000) + mel = audio.melspectrogram(wav) + print(mel.shape) - elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: - full_frames = [cv2.imread(args.face)] - fps = args.fps + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') - else: - video_stream = cv2.VideoCapture(args.face) - fps = video_stream.get(cv2.CAP_PROP_FPS) + mel_chunks = [] + mel_idx_multiplier = 80./fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 - print('Reading video frames...') + print("Length of mel chunks: {}".format(len(mel_chunks))) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - if args.resize_factor > 1: - frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor)) + full_frames = full_frames[:len(mel_chunks)] - if args.rotate: - frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + batch_size = args.wav2lip_batch_size + gen = datagen(full_frames.copy(), mel_chunks) - y1, y2, x1, x2 = args.crop - if x2 == -1: x2 = frame.shape[1] - if y2 == -1: y2 = frame.shape[0] + s = time() - frame = frame[y1:y2, x1:x2] + for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, + total=int(np.ceil(float(len(mel_chunks))/batch_size)))): + if i == 0: + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter('temp/result.avi', + cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) - full_frames.append(frame) + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) - print ("Number of frames available for inference: "+str(len(full_frames))) + with torch.no_grad(): + pred = model(mel_batch, img_batch) - if not args.audio.endswith('.wav'): - print('Extracting raw audio...') - command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav') + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. - subprocess.call(command, shell=True) - args.audio = 'temp/temp.wav' + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) - wav = audio.load_wav(args.audio, 16000) - mel = audio.melspectrogram(wav) - print(mel.shape) + f[y1:y2, x1:x2] = p + out.write(f) - if np.isnan(mel.reshape(-1)).sum() > 0: - raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + out.release() - mel_chunks = [] - mel_idx_multiplier = 80./fps - i = 0 - while 1: - start_idx = int(i * mel_idx_multiplier) - if start_idx + mel_step_size > len(mel[0]): - mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) - break - mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) - i += 1 + print("wav2lip prediction time:", time() - s) - print("Length of mel chunks: {}".format(len(mel_chunks))) + subprocess.check_call([ + "ffmpeg", "-y", + # "-vsync", "0", "-hwaccel", "cuda", "-hwaccel_output_format", "cuda", + "-i", "temp/result.avi", + "-i", args.audio, + # "-c:v", "h264_nvenc", + args.outfile, + ]) - full_frames = full_frames[:len(mel_chunks)] +model = detector = detector_model = None - batch_size = args.wav2lip_batch_size - gen = datagen(full_frames.copy(), mel_chunks) +def do_load(checkpoint_path): + global model, detector, detector_model - for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, - total=int(np.ceil(float(len(mel_chunks))/batch_size)))): - if i == 0: - model = load_model(args.checkpoint_path) - print ("Model loaded") + model = load_model(checkpoint_path) - frame_h, frame_w = full_frames[0].shape[:-1] - out = cv2.VideoWriter('temp/result.avi', - cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + # SFDDetector.load_model(device) + detector = RetinaFace(gpu_id=0, model_path="checkpoints/mobilenet.pth", network="mobilenet") + # detector = RetinaFace(gpu_id=0, model_path="checkpoints/resnet50.pth", network="resnet50") - img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) - mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + detector_model = detector.model - with torch.no_grad(): - pred = model(mel_batch, img_batch) + print("Models loaded") - pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. - - for p, f, c in zip(pred, frames, coords): - y1, y2, x1, x2 = c - p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) - f[y1:y2, x1:x2] = p - out.write(f) +face_batch_size = 64 * 8 - out.release() +def face_rect(images): + num_batches = math.ceil(len(images) / face_batch_size) + prev_ret = None + for i in range(num_batches): + batch = images[i * face_batch_size: (i + 1) * face_batch_size] + all_faces = detector(batch) # return faces list of all images + for faces in all_faces: + if faces: + box, landmarks, score = faces[0] + prev_ret = tuple(map(int, box)) + yield prev_ret - command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile) - subprocess.call(command, shell=platform.system() != 'Windows') if __name__ == '__main__': - main() + args = parser.parse_args() + do_load(args.checkpoint_path) + main()