-
Notifications
You must be signed in to change notification settings - Fork 1
/
superpixel_segmenter.py
99 lines (92 loc) · 3.45 KB
/
superpixel_segmenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from multiprocessing import Pool
from skimage.segmentation import slic
from nuscenes.nuscenes import NuScenes
def compute_slic(
    cam_token,
    n_segments=150,
    out_dir="./superpixels/nuscenes/superpixels_slic/",
):
    """Compute SLIC superpixels for one camera image and save them as a PNG.

    The image referenced by the nuScenes ``sample_data`` record *cam_token*
    is segmented with ``skimage.segmentation.slic`` (labels start at 0) and
    the label map is written as an 8-bit PNG named ``<token>.png`` in
    *out_dir*.

    Args:
        cam_token: token of a camera ``sample_data`` record.
        n_segments: approximate number of superpixels requested from SLIC
            (defaults to the previously hard-coded 150).
        out_dir: directory for the output PNG; created if it does not exist.

    Relies on the module-level ``nusc`` NuScenes handle created in the
    ``__main__`` block.
    """
    cam = nusc.get("sample_data", cam_token)
    # The context manager closes the file handle PIL keeps open after its
    # lazy open; np.asarray forces the pixel data to be read first.
    with Image.open(os.path.join(nusc.dataroot, cam["filename"])) as image:
        pixels = np.asarray(image)
    # NOTE(review): uint8 assumes SLIC returns fewer than 256 segments;
    # larger label values would silently wrap around.
    segments_slic = slic(
        pixels, n_segments=n_segments, compactness=6, sigma=3.0, start_label=0
    ).astype(np.uint8)
    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(segments_slic).save(
        os.path.join(out_dir, cam["token"] + ".png")
    )
def compute_slic_30(
    cam_token,
    n_segments=30,
    out_dir="./superpixels/nuscenes/superpixels_slic_30/",
):
    """Compute coarse SLIC superpixels for one camera image and save as PNG.

    Same as ``compute_slic`` but with a coarser default granularity and its
    own output directory. The label map (labels start at 0) is written as
    an 8-bit PNG named ``<token>.png`` in *out_dir*.

    Args:
        cam_token: token of a camera ``sample_data`` record.
        n_segments: approximate number of superpixels requested from SLIC
            (defaults to the previously hard-coded 30).
        out_dir: directory for the output PNG; created if it does not exist.
            Previously nothing created this directory, so every save failed.

    Relies on the module-level ``nusc`` NuScenes handle created in the
    ``__main__`` block.
    """
    cam = nusc.get("sample_data", cam_token)
    # Close the source file deterministically; np.asarray reads the pixels
    # before the handle is released.
    with Image.open(os.path.join(nusc.dataroot, cam["filename"])) as image:
        pixels = np.asarray(image)
    # NOTE(review): uint8 assumes SLIC returns fewer than 256 segments.
    segments_slic = slic(
        pixels, n_segments=n_segments, compactness=6, sigma=3.0, start_label=0
    ).astype(np.uint8)
    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(segments_slic).save(
        os.path.join(out_dir, cam["token"] + ".png")
    )
def pool_func(scene):
    """Compute superpixels for every camera image of every sample in *scene*.

    Walks the scene's samples starting at ``scene["first_sample_token"]``
    and following each sample's ``"next"`` token until it is empty. For each
    sample, every channel in ``camera_list`` is processed.

    Reads module-level globals set in the ``__main__`` block:
    ``nusc`` (NuScenes handle), ``args`` (``args.model`` selects
    ``compute_slic`` for "minkunet" or ``compute_slic_30`` for "voxelnet")
    and ``camera_list``. Worker processes see these only when the pool forks
    after they are set; under the ``spawn`` start method they are undefined.
    """
    # Select the per-image function once per scene rather than per sample.
    func = {"minkunet": compute_slic, "voxelnet": compute_slic_30}[args.model]
    current_sample_token = scene["first_sample_token"]
    while current_sample_token != "":
        current_sample = nusc.get("sample", current_sample_token)
        for camera_name in camera_list:
            func(current_sample["data"][camera_name])
        current_sample_token = current_sample["next"]
if __name__ == "__main__":
    # Entry point: parse CLI options, load nuScenes v1.0-trainval, and
    # compute superpixels for every scene in a process pool.
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument(
        "--model",
        type=str,
        default="minkunet",
        choices=["minkunet", "voxelnet"],
        help="specify the model targeted, either minkunet or voxelnet",
    )
    parser.add_argument(
        "--nuscenes_path",
        type=str,
        default="datasets/nuscenes",
        help="root directory of the nuScenes dataset",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=128,
        help="number of worker processes",
    )
    # Parse before touching the filesystem so `--help` works even when the
    # dataset is absent.
    args = parser.parse_args()
    nuscenes_path = args.nuscenes_path
    if not os.path.exists(nuscenes_path):
        parser.error(f"nuScenes not found in {nuscenes_path}")

    nusc = NuScenes(
        version="v1.0-trainval", dataroot=nuscenes_path, verbose=False
    )
    # Create the directory the selected model's per-image function writes
    # to (the voxelnet directory used to be missing, so every save failed).
    output_dir = {
        "minkunet": "superpixels/nuscenes/superpixels_slic/",
        "voxelnet": "superpixels/nuscenes/superpixels_slic_30/",
    }[args.model]
    os.makedirs(output_dir, exist_ok=True)
    camera_list = [
        "CAM_FRONT",
        "CAM_FRONT_RIGHT",
        "CAM_BACK_RIGHT",
        "CAM_BACK",
        "CAM_BACK_LEFT",
        "CAM_FRONT_LEFT",
    ]
    # `nusc`, `args` and `camera_list` are module globals read by pool_func;
    # the pool is created after they are set so forked workers inherit them.
    with Pool(args.num_workers) as pool:
        # imap_unordered re-raises worker exceptions here (discarded
        # apply_async results hid them) and lets tqdm advance as scenes
        # finish rather than as they are submitted.
        for _ in tqdm(
            pool.imap_unordered(pool_func, nusc.scene),
            total=len(nusc.scene),
        ):
            pass
    print('done')