From 5a175e15ba641a1843380e9a56ebc82e0cd1a3f2 Mon Sep 17 00:00:00 2001
From: Justin John <34035011+justinjohn0306@users.noreply.github.com>
Date: Thu, 7 Sep 2023 20:30:19 +0530
Subject: [PATCH] Add files via upload
---
README.md | 18 +-
face_detection/detection/sfd/sfd_detector.py | 13 +-
inference.py | 441 ++++++++++---------
3 files changed, 258 insertions(+), 214 deletions(-)
diff --git a/README.md b/README.md
index 76e66198e..35f69daff 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to
--------
**Disclaimer**
--------
-All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should only be used for research/academic/personal purposes only. As the models are trained on the LRS2 dataset, any form of commercial use is strictly prohibited. For commercial requests please contact us directly!
+All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should be used for research/academic/personal purposes only. As the models are trained on the LRS2 dataset, any form of commercial use is strictly prohibited. For commercial requests, please contact us directly!
Prerequisites
-------------
@@ -55,8 +55,8 @@ The result is saved (by default) in `results/result_voice.mp4`. You can specify
##### Tips for better results:
- Experiment with the `--pads` argument to adjust the detected face bounding box. Often leads to improved results. You might need to increase the bottom padding to include the chin region. E.g. `--pads 0 20 0 0`.
-- If you see the mouth position dislocated or some weird artifacts such as two mouths, then it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try.
-- Experiment with the `--resize_factor` argument, to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
+- If you see the mouth position dislocated or some weird artifacts such as two mouths, it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try (see the example command after these tips).
+- Experiment with the `--resize_factor` argument to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
- The Wav2Lip model without GAN usually needs more experimenting with the above two to get the most ideal results, and sometimes, can give you a better result as well.
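+
+For example, a typical invocation combining these flags might look like the following (the checkpoint name and input paths are placeholders; adjust them to your own files):
+```bash
+python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face input_video.mp4 --audio input_audio.wav --pads 0 20 0 0 --nosmooth
+```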
Preparing LRS2 for training
@@ -78,7 +78,7 @@ Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` fol
```bash
python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/
```
-Additional options like `batch_size` and the number of GPUs to use in parallel to use can also be set.
+Additional options like `batch_size` and the number of GPUs to use in parallel can also be set.
##### Preprocessed LRS2 folder structure
```
@@ -99,12 +99,12 @@ You can download [the pre-trained weights](#getting-the-weights) if you want to
python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir
```
##### Training the Wav2Lip models
-You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
+You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
```bash
python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir --syncnet_checkpoint_path
```
-To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file.
+To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file.
Training on datasets other than LRS2
------------------------------------
@@ -126,7 +126,7 @@ Please check the `evaluation/` folder for the instructions.
License and Citation
----------
-This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository:
+This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository:
```
@inproceedings{10.1145/3394171.3413532,
author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.},
@@ -147,6 +147,6 @@ series = {MM '20}
```
-Acknowledgments
+Acknowledgements
----------
-Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial collab notebook.
+Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial Colab notebook.
diff --git a/face_detection/detection/sfd/sfd_detector.py b/face_detection/detection/sfd/sfd_detector.py
index 8fbce1525..d1776e4bf 100644
--- a/face_detection/detection/sfd/sfd_detector.py
+++ b/face_detection/detection/sfd/sfd_detector.py
@@ -14,8 +14,9 @@
class SFDDetector(FaceDetector):
- def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
- super(SFDDetector, self).__init__(device, verbose)
+ @classmethod
+ def load_model(cls, device):
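+        # Load the S3FD weights once and keep the detector on the class;
+        # intended to be called as SFDDetector.load_model(device) before running detection.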
+ path_to_detector = os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth')
# Initialise the face detector
if not os.path.isfile(path_to_detector):
@@ -23,10 +24,10 @@ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path
else:
model_weights = torch.load(path_to_detector)
- self.face_detector = s3fd()
- self.face_detector.load_state_dict(model_weights)
- self.face_detector.to(device)
- self.face_detector.eval()
+ cls.face_detector = s3fd()
+ cls.face_detector.load_state_dict(model_weights)
+ cls.face_detector.to(device)
+ cls.face_detector.eval()
def detect_from_image(self, tensor_or_path):
image = self.tensor_or_path_to_ndarray(tensor_or_path)
diff --git a/inference.py b/inference.py
index 90692521e..bb1a61cd4 100644
--- a/inference.py
+++ b/inference.py
@@ -1,280 +1,323 @@
-from os import listdir, path
+import argparse
+import math
+import os
+import platform
+import subprocess
+
+import cv2
import numpy as np
-import scipy, cv2, os, sys, argparse, audio
-import json, subprocess, random, string
+import torch
from tqdm import tqdm
-from glob import glob
-import torch, face_detection
+
+import audio
+# from face_detect import face_rect
from models import Wav2Lip
-import platform
+
+from batch_face import RetinaFace
+from time import time
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
parser.add_argument('--checkpoint_path', type=str,
- help='Name of saved checkpoint to load weights from', required=True)
+ help='Name of saved checkpoint to load weights from', required=True)
parser.add_argument('--face', type=str,
- help='Filepath of video/image that contains faces to use', required=True)
+ help='Filepath of video/image that contains faces to use', required=True)
parser.add_argument('--audio', type=str,
- help='Filepath of video/audio file to use as raw audio source', required=True)
+ help='Filepath of video/audio file to use as raw audio source', required=True)
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
- default='results/result_voice.mp4')
+ default='results/result_voice.mp4')
parser.add_argument('--static', type=bool,
- help='If True, then use only first video frame for inference', default=False)
+ help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
- default=25., required=False)
+ default=25., required=False)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
- help='Padding (top, bottom, left, right). Please adjust to include chin at least')
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
-parser.add_argument('--face_det_batch_size', type=int,
- help='Batch size for face detection', default=16)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
-parser.add_argument('--resize_factor', default=1, type=int,
- help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
+# parser.add_argument('--resize_factor', default=1, type=int,
+# help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
+
+parser.add_argument('--out_height', default=480, type=int,
+ help='Output video height. Best results are obtained at 480 or 720')
-parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
- help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
- 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
+parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
+                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resizing and the rotate arg. '
+                         'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')
parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
- help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
- 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
+                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
+                    'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
parser.add_argument('--rotate', default=False, action='store_true',
- help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
- 'Use if you get a flipped result, despite feeding a normal looking video')
+                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. '
+                    'Use if you get a flipped result, despite feeding a normal looking video')
parser.add_argument('--nosmooth', default=False, action='store_true',
- help='Prevent smoothing face detections over a short temporal window')
+ help='Prevent smoothing face detections over a short temporal window')
-args = parser.parse_args()
-args.img_size = 96
-
-if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
- args.static = True
def get_smoothened_boxes(boxes, T):
- for i in range(len(boxes)):
- if i + T > len(boxes):
- window = boxes[len(boxes) - T:]
- else:
- window = boxes[i : i + T]
- boxes[i] = np.mean(window, axis=0)
- return boxes
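+    # Smooth the detections with a running mean over a window of up to T consecutive frames.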
+ for i in range(len(boxes)):
+ if i + T > len(boxes):
+ window = boxes[len(boxes) - T:]
+ else:
+ window = boxes[i : i + T]
+ boxes[i] = np.mean(window, axis=0)
+ return boxes
def face_detect(images):
- detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
- flip_input=False, device=device)
-
- batch_size = args.face_det_batch_size
-
- while 1:
- predictions = []
- try:
- for i in tqdm(range(0, len(images), batch_size)):
- predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
- except RuntimeError:
- if batch_size == 1:
- raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
- batch_size //= 2
- print('Recovering from OOM error; New batch size: {}'.format(batch_size))
- continue
- break
-
- results = []
- pady1, pady2, padx1, padx2 = args.pads
- for rect, image in zip(predictions, images):
- if rect is None:
- cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
- raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
-
- y1 = max(0, rect[1] - pady1)
- y2 = min(image.shape[0], rect[3] + pady2)
- x1 = max(0, rect[0] - padx1)
- x2 = min(image.shape[1], rect[2] + padx2)
-
- results.append([x1, y1, x2, y2])
-
- boxes = np.array(results)
- if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
- results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
-
- del detector
- return results
+ results = []
+ pady1, pady2, padx1, padx2 = args.pads
+
+ s = time()
+
+ for image, rect in zip(images, face_rect(images)):
+ if rect is None:
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+ y1 = max(0, rect[1] - pady1)
+ y2 = min(image.shape[0], rect[3] + pady2)
+ x1 = max(0, rect[0] - padx1)
+ x2 = min(image.shape[1], rect[2] + padx2)
+
+ results.append([x1, y1, x2, y2])
+
+ print('face detect time:', time() - s)
+
+ boxes = np.array(results)
+ if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+ return results
+
def datagen(frames, mels):
- img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
- if args.box[0] == -1:
- if not args.static:
- face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
- else:
- face_det_results = face_detect([frames[0]])
- else:
- print('Using the specified bounding box instead of face detection...')
- y1, y2, x1, x2 = args.box
- face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+ if args.box[0] == -1:
+ if not args.static:
+ face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
+ else:
+ face_det_results = face_detect([frames[0]])
+ else:
+ print('Using the specified bounding box instead of face detection...')
+ y1, y2, x1, x2 = args.box
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
- for i, m in enumerate(mels):
- idx = 0 if args.static else i%len(frames)
- frame_to_save = frames[idx].copy()
- face, coords = face_det_results[idx].copy()
+ for i, m in enumerate(mels):
+ idx = 0 if args.static else i%len(frames)
+ frame_to_save = frames[idx].copy()
+ face, coords = face_det_results[idx].copy()
- face = cv2.resize(face, (args.img_size, args.img_size))
-
- img_batch.append(face)
- mel_batch.append(m)
- frame_batch.append(frame_to_save)
- coords_batch.append(coords)
+ face = cv2.resize(face, (args.img_size, args.img_size))
- if len(img_batch) >= args.wav2lip_batch_size:
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+ img_batch.append(face)
+ mel_batch.append(m)
+ frame_batch.append(frame_to_save)
+ coords_batch.append(coords)
- img_masked = img_batch.copy()
- img_masked[:, args.img_size//2:] = 0
+ if len(img_batch) >= args.wav2lip_batch_size:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
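+            # Mask the lower half of each face crop; it is fed together with the full crop,
+            # and the network regenerates the masked mouth region from the mel window.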
+ img_masked = img_batch.copy()
+ img_masked[:, args.img_size//2:] = 0
- yield img_batch, mel_batch, frame_batch, coords_batch
- img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
- if len(img_batch) > 0:
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+ yield img_batch, mel_batch, frame_batch, coords_batch
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
- img_masked = img_batch.copy()
- img_masked[:, args.img_size//2:] = 0
+ if len(img_batch) > 0:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+ img_masked = img_batch.copy()
+ img_masked[:, args.img_size//2:] = 0
- yield img_batch, mel_batch, frame_batch, coords_batch
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
- if device == 'cuda':
- checkpoint = torch.load(checkpoint_path)
- else:
- checkpoint = torch.load(checkpoint_path,
- map_location=lambda storage, loc: storage)
- return checkpoint
+ if device == 'cuda':
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ checkpoint = torch.load(checkpoint_path,
+ map_location=lambda storage, loc: storage)
+ return checkpoint
def load_model(path):
- model = Wav2Lip()
- print("Load checkpoint from: {}".format(path))
- checkpoint = _load(path)
- s = checkpoint["state_dict"]
- new_s = {}
- for k, v in s.items():
- new_s[k.replace('module.', '')] = v
- model.load_state_dict(new_s)
-
- model = model.to(device)
- return model.eval()
+ model = Wav2Lip()
+ print("Load checkpoint from: {}".format(path))
+ checkpoint = _load(path)
+ s = checkpoint["state_dict"]
+ new_s = {}
+ for k, v in s.items():
+ new_s[k.replace('module.', '')] = v
+ model.load_state_dict(new_s)
+
+ model = model.to(device)
+ return model.eval()
def main():
- if not os.path.isfile(args.face):
- raise ValueError('--face argument must be a valid path to video/image file')
+ args.img_size = 96
+
+    if os.path.isfile(args.face) and args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ args.static = True
+
+ if not os.path.isfile(args.face):
+ raise ValueError('--face argument must be a valid path to video/image file')
+
+    elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ full_frames = [cv2.imread(args.face)]
+ fps = args.fps
+
+ else:
+ video_stream = cv2.VideoCapture(args.face)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+ print('Reading video frames...')
+
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+
+ aspect_ratio = frame.shape[1] / frame.shape[0]
+ frame = cv2.resize(frame, (int(args.out_height * aspect_ratio), args.out_height))
+ # if args.resize_factor > 1:
+ # frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
+
+ if args.rotate:
+                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+
+ y1, y2, x1, x2 = args.crop
+ if x2 == -1: x2 = frame.shape[1]
+ if y2 == -1: y2 = frame.shape[0]
+
+ frame = frame[y1:y2, x1:x2]
+
+ full_frames.append(frame)
+
+    print("Number of frames available for inference: " + str(len(full_frames)))
+
+ if not args.audio.endswith('.wav'):
+ print('Extracting raw audio...')
+ # command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
+ # subprocess.call(command, shell=True)
+ subprocess.check_call([
+ "ffmpeg", "-y",
+ "-i", args.audio,
+ "temp/temp.wav",
+ ])
+ args.audio = 'temp/temp.wav'
+
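+    # Load the audio (resampled to 16 kHz) and compute its mel spectrogram.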
+ wav = audio.load_wav(args.audio, 16000)
+ mel = audio.melspectrogram(wav)
+ print(mel.shape)
- elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
- full_frames = [cv2.imread(args.face)]
- fps = args.fps
+ if np.isnan(mel.reshape(-1)).sum() > 0:
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
- else:
- video_stream = cv2.VideoCapture(args.face)
- fps = video_stream.get(cv2.CAP_PROP_FPS)
+ mel_chunks = []
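+    # The mel spectrogram has ~80 frames per second, so each video frame advances 80/fps mel steps;
+    # every chunk spans mel_step_size (16) mel frames, i.e. roughly 0.2 s of audio.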
+ mel_idx_multiplier = 80./fps
+ i = 0
+ while 1:
+ start_idx = int(i * mel_idx_multiplier)
+ if start_idx + mel_step_size > len(mel[0]):
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+ break
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+ i += 1
- print('Reading video frames...')
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
- full_frames = []
- while 1:
- still_reading, frame = video_stream.read()
- if not still_reading:
- video_stream.release()
- break
- if args.resize_factor > 1:
- frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
+ full_frames = full_frames[:len(mel_chunks)]
- if args.rotate:
- frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
+ batch_size = args.wav2lip_batch_size
+ gen = datagen(full_frames.copy(), mel_chunks)
- y1, y2, x1, x2 = args.crop
- if x2 == -1: x2 = frame.shape[1]
- if y2 == -1: y2 = frame.shape[0]
+ s = time()
- frame = frame[y1:y2, x1:x2]
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
+ total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
+ if i == 0:
+ frame_h, frame_w = full_frames[0].shape[:-1]
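+            # Frames are written without sound to temp/result.avi; ffmpeg muxes the audio in afterwards.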
+ out = cv2.VideoWriter('temp/result.avi',
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
- full_frames.append(frame)
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
- print ("Number of frames available for inference: "+str(len(full_frames)))
+ with torch.no_grad():
+ pred = model(mel_batch, img_batch)
- if not args.audio.endswith('.wav'):
- print('Extracting raw audio...')
- command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
- subprocess.call(command, shell=True)
- args.audio = 'temp/temp.wav'
+ for p, f, c in zip(pred, frames, coords):
+ y1, y2, x1, x2 = c
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
- wav = audio.load_wav(args.audio, 16000)
- mel = audio.melspectrogram(wav)
- print(mel.shape)
+ f[y1:y2, x1:x2] = p
+ out.write(f)
- if np.isnan(mel.reshape(-1)).sum() > 0:
- raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+ out.release()
- mel_chunks = []
- mel_idx_multiplier = 80./fps
- i = 0
- while 1:
- start_idx = int(i * mel_idx_multiplier)
- if start_idx + mel_step_size > len(mel[0]):
- mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
- break
- mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
- i += 1
+ print("wav2lip prediction time:", time() - s)
- print("Length of mel chunks: {}".format(len(mel_chunks)))
+ subprocess.check_call([
+ "ffmpeg", "-y",
+ # "-vsync", "0", "-hwaccel", "cuda", "-hwaccel_output_format", "cuda",
+ "-i", "temp/result.avi",
+ "-i", args.audio,
+ # "-c:v", "h264_nvenc",
+ args.outfile,
+ ])
- full_frames = full_frames[:len(mel_chunks)]
+model = detector = detector_model = None
- batch_size = args.wav2lip_batch_size
- gen = datagen(full_frames.copy(), mel_chunks)
+def do_load(checkpoint_path):
+ global model, detector, detector_model
- for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
- total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
- if i == 0:
- model = load_model(args.checkpoint_path)
- print ("Model loaded")
+ model = load_model(checkpoint_path)
- frame_h, frame_w = full_frames[0].shape[:-1]
- out = cv2.VideoWriter('temp/result.avi',
- cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
+ # SFDDetector.load_model(device)
+ detector = RetinaFace(gpu_id=0, model_path="checkpoints/mobilenet.pth", network="mobilenet")
+ # detector = RetinaFace(gpu_id=0, model_path="checkpoints/resnet50.pth", network="resnet50")
- img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
- mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+ detector_model = detector.model
- with torch.no_grad():
- pred = model(mel_batch, img_batch)
+ print("Models loaded")
- pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
-
- for p, f, c in zip(pred, frames, coords):
- y1, y2, x1, x2 = c
- p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
- f[y1:y2, x1:x2] = p
- out.write(f)
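+# Number of frames handed to the RetinaFace detector in a single call inside face_rect.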
+face_batch_size = 64 * 8
- out.release()
+def face_rect(images):
+ num_batches = math.ceil(len(images) / face_batch_size)
+ prev_ret = None
+ for i in range(num_batches):
+ batch = images[i * face_batch_size: (i + 1) * face_batch_size]
+ all_faces = detector(batch) # return faces list of all images
+ for faces in all_faces:
+ if faces:
+ box, landmarks, score = faces[0]
+ prev_ret = tuple(map(int, box))
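+            # Yield this frame's box, or reuse the previous frame's box if no face was found (None until the first detection).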
+ yield prev_ret
- command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
- subprocess.call(command, shell=platform.system() != 'Windows')
if __name__ == '__main__':
- main()
+ args = parser.parse_args()
+ do_load(args.checkpoint_path)
+ main()