forked from shinianzhihou/ChangeDetection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_net.py
72 lines (57 loc) · 1.92 KB
/
eval_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import argparse
import pandas as pd
from configs import cfg
from utils.checkpoints import lcwo
from utils.eval import eval_model
from build import (
build_dataloader,
build_model,
)
def run_eval(cfg, save_max_imgs=False):
    """Evaluate every matching checkpoint and record its metrics.

    For each path returned by ``get_cp_paths(cfg)`` that contains the model's
    class name (``model._get_name()``), the weights are loaded with ``lcwo``
    and scored with ``eval_model``. One row per checkpoint is appended to a
    results table. After every checkpoint the table is rewritten to
    ``cfg.EVAL.SAVE_PATH/cfg.EVAL.SAVE_NAME``, so partial results survive an
    interrupted run.

    If ``cfg.EVAL.SAVE_IMAGES`` (or ``save_max_imgs``) is set, the checkpoint
    with the highest ``cfg.EVAL.SAVE_BY_METRIC`` score is reloaded and
    evaluated once more with ``save_imgs=True``. When no checkpoint was
    evaluated, a warning is emitted and this step is skipped.

    Args:
        cfg: Config node. Reads ``MODEL.DEVICE``, ``EVAL.METRIC`` (metric
            column names), ``EVAL.SAVE_PATH``, ``EVAL.SAVE_NAME``,
            ``EVAL.SAVE_IMAGES``, ``EVAL.SAVE_BY_METRIC`` and
            ``EVAL.CHECKPOINTS_PATH``.
        save_max_imgs: Also save images for the best checkpoint even when
            ``cfg.EVAL.SAVE_IMAGES`` is off. The default ``False`` leaves the
            config in sole control.

    Returns:
        pandas.DataFrame: One row per evaluated checkpoint, with a
        ``"checkpoint"`` column followed by the ``cfg.EVAL.METRIC`` columns.
    """
    import warnings

    test_loader = build_dataloader(cfg, test=True)
    model = build_model(cfg).to(cfg.MODEL.DEVICE)
    model_name = model._get_name()
    save_file = os.path.join(cfg.EVAL.SAVE_PATH, cfg.EVAL.SAVE_NAME)

    test_res = pd.DataFrame(columns=["checkpoint"] + cfg.EVAL.METRIC)
    for cp_path in get_cp_paths(cfg):
        # Only evaluate checkpoints whose path names the architecture built above.
        if model_name not in cp_path:
            continue
        model = lcwo(cp_path, model)
        metric = eval_model(model, test_loader, cfg)
        # NOTE(review): assumes eval_model returns metrics in the same order as
        # cfg.EVAL.METRIC and that each value exposes .item() (presumably 0-d
        # tensors) — confirm against utils.eval.
        save_value = [cp_path] + [v.item() for v in metric.values()]
        test_res.loc[len(test_res)] = save_value
        test_res.to_csv(save_file, index=False)

    if cfg.EVAL.SAVE_IMAGES or save_max_imgs:
        if test_res.empty:
            # idxmax() on an empty column raises ValueError, and there is no
            # best checkpoint to reload anyway.
            warnings.warn(
                f"No checkpoint in {cfg.EVAL.CHECKPOINTS_PATH!r} matched model "
                f"{model_name!r}; skipping image saving."
            )
        else:
            # Columns of a DataFrame created empty start out as object dtype;
            # cast so the best score is picked by numeric comparison.
            best_row = test_res[cfg.EVAL.SAVE_BY_METRIC].astype(float).idxmax()
            model = lcwo(test_res.at[best_row, "checkpoint"], model)
            eval_model(model, test_loader, cfg, save_imgs=True)
    return test_res
def get_cp_paths(cfg):
    """Return the paths of the checkpoint files in ``cfg.EVAL.CHECKPOINTS_PATH``.

    Entries are returned in sorted (lexicographic) order. ``os.listdir`` makes
    no ordering guarantee, and sorting keeps the evaluation order, and thus the
    row order of any results built from it, reproducible. Sub-directories are
    skipped because each returned path is meant to be loaded as a single
    checkpoint file.

    Args:
        cfg: Config node; only ``cfg.EVAL.CHECKPOINTS_PATH`` is read.

    Returns:
        list[str]: The checkpoint directory joined with each file name.
    """
    cp_root = cfg.EVAL.CHECKPOINTS_PATH
    candidates = (os.path.join(cp_root, name) for name in sorted(os.listdir(cp_root)))
    return [path for path in candidates if os.path.isfile(path)]
def main():
    """Command-line entry point: load the config, apply overrides, and evaluate.

    ``-cfg/--config_file`` names a config file that is merged into the
    module-level ``cfg``. Any remaining command-line tokens are collected into
    ``opts`` and passed to ``cfg.merge_from_list`` (presumably yacs-style
    KEY VALUE pairs — confirm). The config is then frozen and handed to
    ``run_eval``.
    """
    arg_parser = argparse.ArgumentParser(description="eval models from checkpoints.")
    arg_parser.add_argument(
        "-cfg", "--config_file",
        type=str,
        metavar="FILE",
        default="configs/homo/default.yaml",
        help="Path to config file",
    )
    arg_parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    cli = arg_parser.parse_args()

    # Order matters: file values first, then CLI overrides, then lock the config.
    cfg.merge_from_file(cli.config_file)
    cfg.merge_from_list(cli.opts)
    cfg.freeze()
    run_eval(cfg)


if __name__ == "__main__":
    main()