# equivariance_detection.py
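# A short summary of what this script does, as read from the code below; the
# example paths are placeholders, not paths from the original repository.
"""Flag unstable keypoints via their equivariance error.

Runs a trained keypoint detector on source frames and on randomly transformed
copies of them (modules.model.Transform), accumulates per-keypoint
equivariance value/jacobian losses over up to 100 batches, and writes the
indices whose value loss exceeds --equi_threshold into the ignore_kp_list
fields of the hdam config.

Example invocation (placeholder paths):
    python equivariance_detection.py \
        --config config/dam.yaml \
        --config_hdam config/hdam.yaml \
        --checkpoint log/checkpoint.pth.tar \
        --equi_threshold 0.25
"""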
import sys
from argparse import ArgumentParser

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from frames_dataset import DatasetRepeater, FramesDataset
from modules.keypoint_detector import KPDetector
from modules.model import Transform

if __name__ == "__main__":
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    parser = ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the dam config")
    parser.add_argument("--config_hdam", required=True,
                        help="path to the hdam config whose ignore_kp_list will be updated")
    parser.add_argument("--equi_threshold", default=0.25, type=float,
                        help="keypoints whose equivariance value loss exceeds this are ignored")
    parser.add_argument("--checkpoint", required=True, help="path to the checkpoint to restore")
    opt = parser.parse_args()

    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
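
    # For reference, the dam-config fields this script reads (gathered from
    # the accesses below; nothing else about the config schema is assumed):
    #   model_params: kp_detector_params, common_params, generator_params.num_root_kp
    #   dataset_params
    #   train_params: batch_size, transform_params, and
    #                 loss_weights.{equivariance_value, equivariance_jacobian,
    #                               root_motion_kp_distance, root_motion_sub_root_distance}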
    # Build the keypoint detector and restore its weights from the checkpoint.
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.eval()
    if torch.cuda.is_available():
        kp_detector.cuda()

    checkpoint = torch.load(opt.checkpoint, map_location='cuda:0')
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    # Iterate over the training set; DatasetRepeater keeps the loader from
    # exhausting before the 100-batch budget below.
    dataset = FramesDataset(is_train=True, **config['dataset_params'])
    dataset = DatasetRepeater(dataset, 150)
    dataloader = DataLoader(dataset, batch_size=config['train_params']['batch_size'],
                            shuffle=False, num_workers=6, drop_last=True)
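
    # The two quantities accumulated below appear to be the equivariance terms
    # of the training objective (their weights come from train_params), kept
    # per keypoint instead of reduced to a scalar:
    #   value:    |kp(x) - T^{-1}(kp(T(x)))|              (keypoint positions)
    #   jacobian: |I - J_kp(x)^{-1} (dT * J_kp(T(x)))|    (local affine parts)
    # so that individually unstable keypoints can be singled out.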
    equivariance_list = []
    equivariance_jacobian_list = []

    # No gradients are needed for measurement, so run the whole loop under
    # torch.no_grad(); evaluation stops after 100 batches.
    with torch.no_grad():
        for i, x in enumerate(tqdm(dataloader)):
            if i >= 100:
                break
            x['source'] = x['source'].cuda()
            kp_source = kp_detector(x['source'])

            # Detect keypoints again on a randomly transformed copy of the frame.
            transform = Transform(x['source'].shape[0], **config['train_params']['transform_params'])
            transformed_frame = transform.transform_frame(x['source'][:, 0:3, :, :])
            transformed_kp = kp_detector(transformed_frame)

            # When the root-motion losses are active, the trailing num_root_kp
            # keypoints are root keypoints; exclude them from the statistics.
            if config['train_params']['loss_weights']['root_motion_kp_distance'] != 0 or \
                    config['train_params']['loss_weights']['root_motion_sub_root_distance'] != 0:
                num_kp = config['model_params']['common_params']['num_kp']
                num_root_kp = config['model_params']['generator_params']['num_root_kp']
                kp_source['value'] = kp_source['value'][:, 0:num_kp - num_root_kp, :]
                kp_source['jacobian'] = kp_source['jacobian'][:, 0:num_kp - num_root_kp, :, :]
                transformed_kp['value'] = transformed_kp['value'][:, 0:num_kp - num_root_kp, :]
                transformed_kp['jacobian'] = transformed_kp['jacobian'][:, 0:num_kp - num_root_kp, :, :]

            ## Value loss part
            value = torch.abs(kp_source['value'] - transform.warp_coordinates(transformed_kp['value']))
            value = value.mean(0).mean(-1)
            value = value.unsqueeze(0)
            equivariance_value = config['train_params']['loss_weights']['equivariance_value'] * value

            ## Jacobian loss part
            jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                transformed_kp['jacobian'])
            normed_source = torch.inverse(kp_source['jacobian'])
            value = torch.matmul(normed_source, jacobian_transformed)
            eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
            value = torch.abs(eye - value).mean(0).mean(-1).mean(-1)
            value = value.unsqueeze(0)
            equivariance_jacobian = config['train_params']['loss_weights']['equivariance_jacobian'] * value

            equivariance_list.append(equivariance_value.cpu().numpy())
            equivariance_jacobian_list.append(equivariance_jacobian.cpu().numpy())

    # Average over batches: one scalar per keypoint for each loss.
    equivariance_list = np.concatenate(equivariance_list, axis=0)
    equivariance_jacobian_list = np.concatenate(equivariance_jacobian_list, axis=0)
    equivariance = np.mean(equivariance_list, axis=0)
    equivariance_jacobian = np.mean(equivariance_jacobian_list, axis=0)
    print("equivariance loss: %s" % equivariance)
    print("equivariance_jacobian loss: %s" % equivariance_jacobian)

    # Mark keypoints whose value loss exceeds the threshold and write them
    # into both ignore lists of the hdam config.
    with open(opt.config_hdam) as f:
        config_hdam = yaml.load(f, Loader=yaml.FullLoader)

    ignore_kp_list = []
    for i in range(len(equivariance)):
        if equivariance[i] > opt.equi_threshold:
            ignore_kp_list.append(i)

    print("previous ignore_kp_list values: %s, %s"
          % (config_hdam['model_params']['kp_detector_params']['ignore_kp_list'],
             config_hdam['visualizer_params']['ignore_kp_list']))
    config_hdam['model_params']['kp_detector_params']['ignore_kp_list'] = ignore_kp_list
    config_hdam['visualizer_params']['ignore_kp_list'] = ignore_kp_list.copy()
    print("writing the ignore_kp_list to the hdam config: %s" % ignore_kp_list)
    with open(opt.config_hdam, "w") as f:
        yaml.dump(config_hdam, f)
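
# The hdam config is expected to expose the two ignore lists rewritten above.
# A minimal sketch (the keys are taken from the accesses in this script; any
# other structure in that file is unknown here):
#
#   model_params:
#     kp_detector_params:
#       ignore_kp_list: []
#   visualizer_params:
#     ignore_kp_list: []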