-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo_2.py
130 lines (102 loc) · 3.26 KB
/
demo_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import copy
import json
import os
import sys
# apex is treated as optional here: a failed import is only reported, not fatal.
try:
    import apex
except Exception:
    # Catch Exception rather than using a bare ``except`` so that
    # KeyboardInterrupt / SystemExit raised during the import still propagate,
    # while any ordinary import-time failure remains non-fatal as before.
    print("No APEX!")
import numpy as np
import torch
import yaml
from det3d import __version__, torchie
from det3d.datasets import build_dataloader, build_dataset
from det3d.models import build_detector
from det3d.torchie import Config
from det3d.torchie.apis import (
batch_processor,
build_optimizer,
get_root_logger,
init_dist,
set_random_seed,
train_detector,
)
from det3d.torchie.trainer import load_checkpoint
import pickle
import time
from matplotlib import pyplot as plt
from det3d.torchie.parallel import collate, collate_kitti
from torch.utils.data import DataLoader
import matplotlib.cm as cm
import subprocess
import cv2
from tools.demo_utils import visual
from collections import defaultdict
def convert_box(info):
    """Repackage the ground-truth boxes of *info* as a detection-style dict.

    Reads ``info["gt_boxes"]`` (cast to ``float32``) and ``info["gt_names"]``,
    which must have the same length, and returns a dict with the keys
    ``box3d_lidar`` (the cast boxes), ``label_preds`` (all zeros) and
    ``scores`` (all ones), the latter two having one entry per box.
    """
    gt_boxes = info["gt_boxes"].astype(np.float32)
    gt_names = info["gt_names"]

    num_boxes = len(gt_boxes)
    assert num_boxes == len(gt_names)

    return {
        'box3d_lidar': gt_boxes,
        # Labels and scores are dummy placeholders: ground truth carries
        # neither a predicted class index nor a confidence.
        'label_preds': np.zeros(num_boxes),
        'scores': np.ones(num_boxes),
    }
def _run_inference(model, dataset, data_loader):
    """Run *model* over every batch of *data_loader* and collect render inputs.

    Returns ``(points_list, gt_annos, detections)``: the transposed point
    arrays, the converted ground-truth annotations and the model outputs
    (tensors moved to the CPU), each appended in loader order.
    """
    cpu_device = torch.device("cpu")
    points_list = []
    gt_annos = []
    detections = []

    for i, data_batch in enumerate(data_loader):
        # Batch i is paired with dataset info i; this assumes the loader
        # yields one sample per batch in dataset order (batch_size=1,
        # shuffle=False, as configured in main()).
        gt_annos.append(convert_box(dataset._nusc_infos[i]))

        # Keep columns 1:4 of the collated point tensor (presumably x, y, z
        # behind a leading batch-index column - TODO confirm in collate_kitti).
        points = data_batch['points'][:, 1:4].cpu().numpy()

        with torch.no_grad():
            outputs = batch_processor(
                model, data_batch, train_mode=False, local_rank=0,
            )

        for output in outputs:
            # Move every entry except the metadata to the CPU. Only existing
            # keys are reassigned, so mutating during iteration is safe.
            for k, v in output.items():
                if k != "metadata":
                    output[k] = v.to(cpu_device)
            detections.append(output)

        points_list.append(points.T)

    return points_list, gt_annos, detections


def _read_frame(path):
    """Read the image at *path* with OpenCV, raising IOError if it fails."""
    frame = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if frame is None:
        raise IOError("Could not read image '{}'".format(path))
    return frame


def _write_video(image_folder, video_name, fps=1):
    """Write every ``.png`` in *image_folder*, sorted by name, to *video_name*.

    The frame size is taken from the first image. Frames are streamed to the
    writer one at a time rather than being loaded into memory all at once.

    Raises:
        FileNotFoundError: if *image_folder* contains no ``.png`` files.
        IOError: if an image cannot be read.
    """
    image_names = sorted(
        name for name in os.listdir(image_folder) if name.endswith(".png")
    )
    if not image_names:
        raise FileNotFoundError(
            "No .png frames found in '{}'; cannot write {}".format(
                image_folder, video_name
            )
        )

    first_frame = _read_frame(os.path.join(image_folder, image_names[0]))
    height, width = first_frame.shape[:2]

    # Same writer arguments as before: fourcc code 0 and the given frame rate.
    video = cv2.VideoWriter(video_name, 0, fps, (width, height))
    try:
        video.write(first_frame)
        for name in image_names[1:]:
            video.write(_read_frame(os.path.join(image_folder, name)))
    finally:
        # Always release the writer, even if a frame fails to load.
        video.release()


def main():
    """Run the CenterPoint demo: inference, per-frame rendering, video export.

    Builds the detector and validation dataset from the demo config, loads
    the checkpoint, runs inference on the GPU over the whole dataset, calls
    ``visual`` once per frame, and finally stitches the ``.png`` files found
    in ``demo/`` into ``video.avi`` at one frame per second.
    """
    cfg = Config.fromfile('configs/centerpoint/nusc_centerpoint_pp_02voxel_circle_nms_demo_2.py')

    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)

    dataset = build_dataset(cfg.data.val)

    data_loader = DataLoader(
        dataset,
        batch_size=1,
        sampler=None,
        shuffle=False,
        num_workers=8,
        collate_fn=collate_kitti,
        pin_memory=False,
    )

    # Weights are loaded into the model in place; the return value is unused.
    load_checkpoint(model, 'work_dirs/centerpoint_pillar_512_demo/last.pth', map_location="cpu")
    model.eval()

    model = model.cuda()

    points_list, gt_annos, detections = _run_inference(model, dataset, data_loader)

    print('Done model inference. Please wait a minute, the matplotlib is a little slow...')

    # visual() presumably saves one PNG per frame into the demo folder -
    # confirm in tools.demo_utils.
    for i, (points, gt_anno, detection) in enumerate(
        zip(points_list, gt_annos, detections)
    ):
        visual(points, gt_anno, detection, i)
        print("Rendered Image {}".format(i))

    _write_video('demo', 'video.avi')

    # Called after the writer is released so the video file is finalized
    # even if this window cleanup fails.
    cv2.destroyAllWindows()

    print("Successfully save video in the main folder")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()