forked from yzhou359/MakeItTalk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_train_image_translation.py
90 lines (70 loc) · 3.23 KB
/
main_train_image_translation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
# Copyright 2020 Adobe
# All Rights Reserved.
# NOTICE: Adobe permits you to use, modify, and distribute this file in
# accordance with the terms of the Adobe license agreement accompanying
# it.
"""
import sys
sys.path.append('thirdparty/AdaptiveWingLoss')
import os, glob
import numpy as np
import cv2
import argparse
from src.dataset.image_translation import landmark_extraction, landmark_image_to_data
from approaches.train_image_translation import Image_translation_block
import platform
import torch
# Pick the default data / output locations for the current machine.
# The machine is recognised purely by its kernel release string.
if platform.release() != '4.4.0-83-generic':
    # Any other host (original note: 3.10.0-957.21.2.el7.x86_64).
    # Previously used root:
    #   r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_imagetranslation'
    root = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation'
    # Previously used mp4_dir:
    #   r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
    mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
    # Everything else lives in a sub-folder of `root`.
    src_dir, jpg_dir, ckpt_dir, log_dir = (
        os.path.join(root, sub_dir)
        for sub_dir in ('raw_fl3d', 'tmp_v', 'ckpt', 'log')
    )
else:
    # Host whose kernel release is exactly '4.4.0-83-generic'.
    src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
    mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
    # Images, checkpoints and logs all go to the same relative folder here.
    jpg_dir = ckpt_dir = log_dir = r'img_output'
''' Step 1. Data preparation '''
# landmark extraction
# landmark_extraction(int(sys.argv[1]), int(sys.argv[2]))
# saving the image data ahead of time makes the saved file too large, so the data is created online instead
# landmark_image_to_data(0, 0, show=False)
''' Step 2. Train the network '''
# Command-line options for the image translation script.
# The parsed namespace (`opt_parser`) is handed to the model unchanged.
parser = argparse.ArgumentParser()

# Optimisation / data loading.
parser.add_argument('--nepoch', type=int, default=150, help='number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
# Fix: this help text was previously attached to --num_workers by mistake.
parser.add_argument('--num_frames', type=int, default=1, help='number of frames extracted from each video')
parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
parser.add_argument('--lr', type=float, default=0.0001, help='')

# Run-mode switches (all off unless the flag is given).
parser.add_argument('--write', default=False, action='store_true')
parser.add_argument('--train', default=False, action='store_true')
parser.add_argument('--name', type=str, default='tmp')
parser.add_argument('--test_speed', default=False, action='store_true')

# Output locations; defaults are the module-level jpg_dir / ckpt_dir / log_dir
# values set earlier in this script.
parser.add_argument('--jpg_dir', type=str, default=jpg_dir)
parser.add_argument('--ckpt_dir', type=str, default=ckpt_dir)
parser.add_argument('--log_dir', type=str, default=log_dir)

# Frequencies for image dumps and checkpoints.
parser.add_argument('--jpg_freq', type=int, default=50, help='')
parser.add_argument('--ckpt_last_freq', type=int, default=1000, help='')
parser.add_argument('--ckpt_epoch_freq', type=int, default=1, help='')

# Model / dataset variants.
parser.add_argument('--load_G_name', type=str, default='')
parser.add_argument('--use_vox_dataset', type=str, default='raw')
parser.add_argument('--add_audio_in', default=False, action='store_true')
parser.add_argument('--comb_fan_awing', default=False, action='store_true')
parser.add_argument('--fan_2or3D', type=str, default='3D')

# An empty string (the default) means "no single test requested".
parser.add_argument('--single_test', type=str, default='')

opt_parser = parser.parse_args()
# Build the model from the parsed options, then run the requested mode(s).
model = Image_translation_block(opt_parser)

# Optional single test, run without gradient tracking.
if opt_parser.single_test != '':
    with torch.no_grad():
        model.single_test()

# NOTE(review): this check is independent of the single-test branch above, so
# a single test without --train is followed by model.test() -- confirm that
# this fall-through is intended.
if not opt_parser.train:
    # Testing does not need gradients.
    with torch.no_grad():
        model.test()
else:
    model.train()