forked from JDAI-CV/FaceX-Zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract_feature.py
51 lines (49 loc) · 2.41 KB
/
extract_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
@author: Jun Wang
@date: 20201012
@contact: [email protected]
"""
import sys
import yaml
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from utils.model_loader import ModelLoader
from utils.extractor.feature_extractor import CommonExtractor
sys.path.append('..')
from data_processor.test_dataset import CommonTestDataset
from backbone.backbone_def import BackboneFactory
def parse_args(argv=None):
    """Parse the command-line options for MegaFace feature extraction.

    Args:
        argv: Optional list of argument strings; ``None`` means argparse
            reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='extract features for megaface.')
    # Without this file the script cannot do anything (it used to crash later
    # with TypeError from open(None)), so let argparse report it up front.
    parser.add_argument("--data_conf_file", type=str, required=True,
                        help="The path of data_conf.yaml.")
    parser.add_argument("--backbone_type", type=str,
                        help="Resnet, Mobilefacenets.")
    parser.add_argument("--backbone_conf_file", type=str,
                        help="The path of backbone_conf.yaml.")
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--model_path', type=str, default='mv_epoch_8.pt',
                        help='The path of model')
    # NOTE(review): this default duplicates --model_path and looks like a
    # copy-paste slip; it is kept unchanged for backward compatibility, but
    # --feats_root should be passed explicitly.
    parser.add_argument('--feats_root', type=str, default='mv_epoch_8.pt',
                        help='The path for feature save.')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device string handed to CommonExtractor.')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of DataLoader worker processes.')
    return parser.parse_args(argv)


def _make_loader(face_folder, image_list_file, batch_size, num_workers):
    """Build a non-shuffling DataLoader over one cropped-face image list.

    The third positional argument ``False`` is forwarded to CommonTestDataset
    exactly as in the original script; its meaning is defined in
    data_processor.test_dataset.
    """
    dataset = CommonTestDataset(face_folder, image_list_file, False)
    # shuffle=False: iterate the dataset in index order.
    return DataLoader(dataset, batch_size=batch_size,
                      num_workers=num_workers, shuffle=False)


def main(argv=None):
    """Extract MegaFace features, plus masked-crop features when enabled.

    Reads the ``MegaFace`` section of ``--data_conf_file`` and loads the
    model from ``--model_path``. It then calls
    ``CommonExtractor.extract_offline`` with ``--feats_root`` for the cropped
    faces. When the config's ``megaface-mask`` entry equals 1, it does the
    same for the masked crops.

    Raises:
        KeyError: if a required key is missing from the ``MegaFace`` section.
            The masked-set keys are only required when ``megaface-mask == 1``.
    """
    args = parse_args(argv)
    with open(args.data_conf_file) as f:
        # safe_load: PyYAML >= 6 rejects yaml.load() without an explicit
        # Loader, and the keys read here are plain paths and flags.
        data_conf = yaml.safe_load(f)['MegaFace']

    # (image folder, image list) pairs to extract. Masked-set keys are only
    # looked up when that pass is enabled, so a config without them still
    # works when megaface-mask is off.
    jobs = [(data_conf['croped_face_folder'], data_conf['image_list_file'])]
    if data_conf['megaface-mask'] == 1:
        jobs.append((data_conf['masked_croped_face_folder'],
                     data_conf['masked_image_list_file']))

    # define model.
    backbone_factory = BackboneFactory(args.backbone_type, args.backbone_conf_file)
    model_loader = ModelLoader(backbone_factory)
    model = model_loader.load_model(args.model_path)

    # extract feature.
    feature_extractor = CommonExtractor(args.device)
    for face_folder, image_list_file in jobs:
        data_loader = _make_loader(face_folder, image_list_file,
                                   args.batch_size, args.num_workers)
        feature_extractor.extract_offline(args.feats_root, model, data_loader)


if __name__ == '__main__':
    main()