-
Notifications
You must be signed in to change notification settings - Fork 9
/
inference.py
229 lines (205 loc) · 11.2 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from glob import glob
import shutil
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser
import platform
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from datetime import datetime
import json
def main(args):
    """End-to-end talking-head video generation pipeline.

    Steps:
      1. Crop the source image and extract its 3DMM coefficients.
      2. Optionally extract coefficients from reference videos that drive
         eye blinking and/or head pose.
      3. Map the driving audio to motion coefficients (Audio2Coeff).
      4. Render the video with the selected face renderer and optionally
         enhance it. In multi-process mode (``args.p_num > 1``) rank 0
         prepares the render batch and publishes it to the other ranks
         through a shared ``workspace/`` directory.

    Args:
        args: parsed CLI namespace (see the ArgumentParser in ``__main__``);
            ``args.device`` must already be set to 'cuda', 'mps' or 'cpu'.

    Side effects:
        Creates a timestamped directory under ``args.result_dir``, a shared
        ``workspace/`` scratch directory, and moves the final video to
        ``<save_dir>.mp4``. Unless ``args.verbose`` is set, the intermediate
        directory is removed at the end.
    """
    import random
    # Fixed seed so repeated runs on identical inputs are reproducible.
    random.seed(319)

    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    input_yaw_list = args.input_yaw
    input_pitch_list = args.input_pitch
    input_roll_list = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose

    current_root_path = os.path.split(sys.argv[0])[0]
    sadtalker_paths = init_path(args.checkpoint_dir,
                                os.path.join(current_root_path, 'src/config'),
                                args.size, args.old_version, args.preprocess)

    # ---- init models (each stage is timed and printed) ----
    timestamp = datetime.timestamp(datetime.now())
    print("start to generate video...", timestamp)

    preprocess_model = CropAndExtract(sadtalker_paths, device)

    start_time = time.time()
    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
    end_time = time.time()
    print("0000: Audio2Coeff")
    print(end_time - start_time)
    start_time = end_time

    # Select the face renderer backend.
    if args.facerender == 'facevid2vid':
        animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device, args.bf16)
    elif args.facerender == 'pirender':
        animate_from_coeff = AnimateFromCoeff_PIRender(sadtalker_paths, device, args.bf16)
    else:
        raise RuntimeError('Unknown model: {}'.format(args.facerender))
    end_time = time.time()
    print("0001: AnimateFromCoeff")
    print(end_time - start_time)
    start_time = end_time

    # ---- crop the source image and extract its 3DMM coefficients ----
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
        pic_path, first_frame_dir, args.preprocess,
        source_image_flag=True, pic_size=args.size)
    end_time = time.time()
    print("0002: preprocess_model generate")
    print(end_time - start_time)
    start_time = end_time

    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    print("eyeblick? pose?")
    print(ref_eyeblink)
    print(ref_pose)

    # Optional reference video that drives eye blinking.
    if ref_eyeblink is not None:
        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
        print('3DMM Extraction for the reference video providing eye blinking')
        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(
            ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
    else:
        ref_eyeblink_coeff_path = None

    # Optional reference video that drives head pose; reuse the eyeblink
    # coefficients when both references point at the same file.
    if ref_pose is not None:
        if ref_pose == ref_eyeblink:
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
            os.makedirs(ref_pose_frame_dir, exist_ok=True)
            print('3DMM Extraction for the reference video providing pose')
            ref_pose_coeff_path, _, _ = preprocess_model.generate(
                ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
    else:
        ref_pose_coeff_path = None

    # ---- audio -> motion coefficients ----
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
    end_time = time.time()
    print("0003: audio_to_coeff generate...")
    print(end_time - start_time)
    start_time = end_time

    # Optional 3D face + landmark visualization of the predicted coefficients.
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path,
                           os.path.join(save_dir, '3dface.mp4'))

    # ---- coefficients -> video ----
    if args.rank == 0:
        # Rank 0 builds the render batch and publishes it via ``workspace/``.
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                                   batch_size, input_yaw_list, input_pitch_list, input_roll_list,
                                   expression_scale=args.expression_scale, still_mode=args.still,
                                   preprocess=args.preprocess, size=args.size,
                                   facemodel=args.facerender)
        shutil.rmtree("workspace", ignore_errors=True)
        os.mkdir("workspace")
        # data keys: source_image, source_semantics, frame_num,
        # target_semantics_list, video_name, audio_path
        torch.save(data['source_image'], 'workspace/source_image.pt')
        torch.save(data['source_semantics'], 'workspace/source_semantics.pt')
        torch.save(data['target_semantics_list'], 'workspace/target_semantics_list.pt')
        meta = {
            'frame_num': data['frame_num'],
            'video_name': data['video_name'],
            'audio_path': data['audio_path'],
        }
        with open("workspace/meta.json", "w") as outfile:
            json.dump(meta, outfile)
    else:
        # Non-zero ranks poll until rank 0 has published the tensors.
        data = {}
        for pt_path in ['workspace/source_image.pt',
                        'workspace/source_semantics.pt',
                        'workspace/target_semantics_list.pt']:
            while not os.path.exists(pt_path):
                time.sleep(0.2)
            pkey = pt_path.split("/")[1].split(".")[0]
            try:
                data[pkey] = torch.load(pt_path)
            except Exception:
                # The file may still be mid-write by rank 0; wait and retry once.
                print("reload...")
                time.sleep(1)
                data[pkey] = torch.load(pt_path)
        while not os.path.exists("workspace/meta.json"):
            time.sleep(0.2)
        with open("workspace/meta.json", "r") as read_content:
            meta = json.load(read_content)
        data['frame_num'] = meta['frame_num']
        data['video_name'] = meta['video_name']
        data['audio_path'] = meta['audio_path']

    result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info,
                                         enhancer=args.enhancer,
                                         background_enhancer=args.background_enhancer,
                                         preprocess=args.preprocess, img_size=args.size,
                                         rank=args.rank, p_num=args.p_num, bf16=args.bf16)
    # Clean up scratch directories shared between ranks.
    shutil.rmtree("logs", ignore_errors=True)
    shutil.rmtree("enhancer_logs", ignore_errors=True)
    shutil.rmtree("workspace", ignore_errors=True)

    timestamp = datetime.timestamp(datetime.now())
    end_time = time.time()
    print("0004: render+enhance...")
    print(end_time - start_time)
    start_time = end_time
    print("generate video done...", timestamp)

    shutil.move(result, save_dir + '.mp4')
    print('The generated video is named:', save_dir + '.mp4')

    # Drop intermediate outputs unless the user asked to keep them.
    if not args.verbose:
        shutil.rmtree(save_dir)
if __name__ == '__main__':
    # CLI entry point: parse options, pick the compute device, run the pipeline.
    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the model checkpoints")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
    parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
    parser.add_argument("--expression_scale", type=float, default=1., help="scale factor applied to the expression coefficients")
    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
    parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
    parser.add_argument("--cpu", dest="cpu", action="store_true")
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
    parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body animation")
    parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
    parser.add_argument("--verbose", action="store_true", help="saving the intermediate output or not" )
    parser.add_argument("--old_version", action="store_true", help="use the pth other than safetensor version" )

    # net structure and parameters
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
    parser.add_argument('--init_path', type=str, default=None, help='Useless')
    # NOTE(review): no type/action here, so any value passed on the CLI arrives
    # as a truthy string; kept as-is for interface compatibility — confirm intent.
    parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    # distributed infer
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--p_num', type=int, default=1)

    # bf16
    parser.add_argument('--bf16', dest="bf16", action="store_true", help="whether to use bf16")

    # facerender model: refer to https://github.com/OpenTalker/SadTalker/discussions/457
    parser.add_argument("--facerender", default='facevid2vid', choices=['pirender', 'facevid2vid'])

    args = parser.parse_args()

    # Device selection: CUDA when available (unless --cpu), Apple MPS only for
    # the pirender backend on macOS, otherwise fall back to CPU.
    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    elif platform.system() == 'Darwin' and args.facerender == 'pirender':  # macos
        args.device = "mps"
    else:
        args.device = "cpu"

    main(args)