-
Notifications
You must be signed in to change notification settings - Fork 41
/
main.py
109 lines (78 loc) · 3.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import numpy as np
import cv2
import argparse
import yaml
import logging
from utils.tools import plot_keypoints
from DataLoader import create_dataloader
from Detectors import create_detector
from Matchers import create_matcher
from VO.VisualOdometry import VisualOdometry, AbosluteScaleComputer
def keypoints_plot(img, vo):
    """Overlay the current frame's keypoints on ``img``.

    Args:
        img: Current frame as a NumPy image. Single-channel input, either a
            plain ``H x W`` array or ``H x W x 1``, is converted to 3 channels
            first. Anything else is passed through unchanged.
        vo: Visual-odometry object whose ``kptdescs["cur"]`` mapping supplies
            the ``"keypoints"`` and ``"scores"`` handed to ``plot_keypoints``.

    Returns:
        The result of ``plot_keypoints`` for the (3-channel) image.
    """
    # A plain 2-D grayscale array has no shape[2]. Check ndim first so it is
    # converted instead of raising IndexError.
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    cur = vo.kptdescs["cur"]
    return plot_keypoints(img, cur["keypoints"], cur["scores"])
class TrajPlotter(object):
    """Accumulate a top-down (x-z plane) plot of estimated vs. ground-truth positions.

    Each :meth:`update` call draws the estimated position in green and the
    ground-truth position in red (OpenCV BGR ordering) onto a persistent
    canvas. It then redraws a text banner with the running mean of the planar
    position error.
    """

    def __init__(self, size=600, origin=(290, 90)):
        """Create an empty black canvas.

        Args:
            size: Side length in pixels of the square canvas (default 600).
            origin: Pixel ``(col, row)`` where world ``(x, z) = (0, 0)`` is
                drawn (default ``(290, 90)``).
        """
        self.errors = []  # per-frame Euclidean error in the x-z plane
        self.size = size
        self.origin = origin
        self.traj = np.zeros((size, size, 3), dtype=np.uint8)

    def update(self, est_xyz, gt_xyz):
        """Plot one estimated / ground-truth position pair.

        Args:
            est_xyz: Estimated position, any array-like with at least 3
                elements (e.g. a 3x1 translation vector). Only x (index 0)
                and z (index 2) are used.
            gt_xyz: Ground-truth position, same convention as ``est_xyz``.

        Returns:
            The ``size x size x 3`` uint8 canvas, which is modified in place.
        """
        # Flatten so that (3,), (3, 1) and (4,) inputs all yield plain scalars.
        # int() on a size-1 array (as a column vector gives) is deprecated in NumPy.
        est = np.asarray(est_xyz, dtype=float).reshape(-1)
        gt = np.asarray(gt_xyz, dtype=float).reshape(-1)
        est_xz = est[[0, 2]]
        gt_xz = gt[[0, 2]]

        self.errors.append(np.linalg.norm(est_xz - gt_xz))
        avg_error = np.mean(self.errors)

        # === drawer ==================================
        ox, oy = self.origin
        draw_x, draw_y = int(est_xz[0]) + ox, int(est_xz[1]) + oy
        true_x, true_y = int(gt_xz[0]) + ox, int(gt_xz[1]) + oy
        cv2.circle(self.traj, (draw_x, draw_y), 1, (0, 255, 0), 1)
        cv2.circle(self.traj, (true_x, true_y), 1, (0, 0, 255), 2)

        # Blank the banner area so the new average does not overprint the old one.
        cv2.rectangle(self.traj, (10, 20), (self.size, 80), (0, 0, 0), -1)
        # The "m" suffix assumes poses are expressed in meters.
        text = "[AvgError] %2.4fm" % (avg_error)
        cv2.putText(self.traj, text, (20, 40),
                    cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, 8)
        return self.traj
def run(args):
    """Run visual odometry over the dataset described by a YAML config.

    The function proceeds as follows:

    1. Builds the dataloader, detector and matcher from the config's
       ``"dataset"``, ``"detector"`` and ``"matcher"`` sections.
    2. Feeds every frame through ``VisualOdometry``, showing the keypoints
       and trajectory windows.
    3. Appends one line per frame to ``results/<config-name>.txt``: the
       frame index, the estimated translation, then the ground-truth
       translation.
    4. Saves the final trajectory image to ``results/<config-name>.png``.

    Pressing ESC stops early.

    Args:
        args: Parsed CLI namespace; only ``args.config`` (path to the YAML
            config file) is read.
    """
    with open(args.config, 'r') as f:
        # safe_load: yaml.load(f) without a Loader is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    loader = create_dataloader(config["dataset"])
    detector = create_detector(config["detector"])
    matcher = create_matcher(config["matcher"])

    absscale = AbosluteScaleComputer()
    traj_plotter = TrajPlotter()

    # Output files are named after the config file (no directory, no extension).
    fname = os.path.splitext(os.path.basename(args.config))[0]
    os.makedirs("results", exist_ok=True)

    vo = VisualOdometry(detector, matcher, loader.cam)
    img2 = None  # stays None if the loader yields no frames
    with open(os.path.join("results", fname + ".txt"), mode='a') as log_fopen:
        for i, img in enumerate(loader):
            gt_pose = loader.get_cur_pose()
            _, t = vo.update(img, absscale.update(gt_pose))

            # === log writer ==============================
            print(i, t[0, 0], t[1, 0], t[2, 0],
                  gt_pose[0, 3], gt_pose[1, 3], gt_pose[2, 3],
                  file=log_fopen)

            # === drawer ==================================
            img1 = keypoints_plot(img, vo)
            img2 = traj_plotter.update(t, gt_pose[:, 3])

            cv2.imshow("keypoints", img1)
            cv2.imshow("trajectory", img2)
            if cv2.waitKey(10) == 27:  # ESC key
                break

    if img2 is not None:
        cv2.imwrite(os.path.join("results", fname + '.png'), img2)
    cv2.destroyAllWindows()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='python_vo')
    parser.add_argument('--config', type=str, default='params/kitti_superpoint_supergluematch.yaml',
                        help='config file')
    parser.add_argument('--logging', type=str, default='INFO',
                        help='logging level: NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL')
    args = parser.parse_args()

    # basicConfig accepts a level name directly. Upper-casing lets
    # "--logging debug" work. An unknown name raises ValueError, which is
    # reported as a normal CLI error instead of a KeyError from the private
    # logging._nameToLevel table.
    try:
        logging.basicConfig(level=args.logging.upper())
    except ValueError as err:
        parser.error(str(err))

    run(args)