forked from Megvii-BaseDetection/YOLOX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyolox_voc_s.py
60 lines (49 loc) · 1.96 KB
/
yolox_voc_s.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# encoding: utf-8
import os
from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp
class Exp(MyExp):
    """YOLOX-S experiment config for a VOC-format custom dataset.

    Uses the small model scaling (depth 0.33, width 0.50) and reads data from
    ``<get_yolox_datadir()>/CUSTOMER`` through ``VOCDetection``: the ``train``
    image set for training and the ``test`` image set for evaluation.
    """

    def __init__(self):
        super().__init__()
        # TODO: the original note says KITTI has 6 classes; update this when
        # switching datasets. It should match the dataset's class list.
        self.num_classes = 8
        # Model scaling factors for the "s" variant.
        self.depth = 0.33
        self.width = 0.50
        self.warmup_epochs = 1

        # ---------- transform config ------------ #
        self.mosaic_prob = 1.0
        self.mixup_prob = 1.0
        self.flip_prob = 0.5
        self.hsv_prob = 1.0

        # Experiment name: this file's basename up to the first ".".
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def _dataset_dir(self) -> str:
        """Return the dataset root shared by training and evaluation."""
        # TODO: CUSTOMER to KITTI -- kept in one place so the switch is a
        # single edit instead of one per dataset builder.
        return os.path.join(get_yolox_datadir(), "CUSTOMER")

    def get_dataset(self, cache: bool, cache_type: str = "disk"):
        """Build the training dataset from the ``train`` image set.

        Args:
            cache: forwarded to ``VOCDetection`` (presumably enables image
                caching there). Previously this argument was dropped.
            cache_type: forwarded to ``VOCDetection``; defaults to "disk".

        Returns:
            A ``VOCDetection`` sized to ``self.input_size`` with
            ``TrainTransform`` preprocessing (at most 50 labels per image).
        """
        from yolox.data import VOCDetection, TrainTransform

        return VOCDetection(
            data_dir=self._dataset_dir(),
            # NOTE(review): the original ``[('train')]`` was already just the
            # string "train" (no tuple); confirm this fork's VOCDetection
            # expects bare split names rather than (year, split) tuples.
            image_sets=["train"],
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob,
            ),
            # Honour the caller's caching request instead of ignoring it.
            # Assumes VOCDetection accepts these keywords, as the previously
            # commented-out ``cache=`` / ``cache_type=`` lines did -- confirm.
            cache=cache,
            cache_type=cache_type,
        )

    def get_eval_dataset(self, **kwargs):
        """Build the evaluation dataset from the ``test`` image set.

        Keyword Args:
            legacy: passed to ``ValTransform``; defaults to False.

        Returns:
            A ``VOCDetection`` sized to ``self.test_size``.
        """
        from yolox.data import VOCDetection, ValTransform

        legacy = kwargs.get("legacy", False)

        return VOCDetection(
            data_dir=self._dataset_dir(),
            image_sets=["test"],
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
        )

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        """Return a ``VOCEvaluator`` over the eval loader.

        ``batch_size``, ``is_distributed``, ``testdev`` and ``legacy`` are
        forwarded to ``self.get_eval_loader``. The evaluator uses this
        experiment's test size, confidence/NMS thresholds and class count.
        """
        from yolox.evaluators import VOCEvaluator

        return VOCEvaluator(
            dataloader=self.get_eval_loader(
                batch_size, is_distributed, testdev=testdev, legacy=legacy
            ),
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )