-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
37 lines (28 loc) · 1.39 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
测试大图 并显示出来
"""
import torch
import argparse
from utils import read_yml
from road_extractor import build_road_extractor
from utils.predict_large_img import predict
def parse_args(argv=None):
    """Parse command-line options for tiled large-image road-extraction inference.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests or other scripts reuse this parser.

    Returns:
        argparse.Namespace with attributes ``config``, ``checkpoints``,
        ``path_img``, ``patch_size``, ``stride``, ``path_save`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='Analysis road extractor model')
    parser.add_argument('--config', default='configs/LRDNet_RNBD.yml',
                        help='train config file path')
    parser.add_argument('--checkpoints', type=str,
                        default='./work_dir/2022-10-12-19_50_10/model99_resume.pth',
                        help='path of the checkpoint file to load')
    # Forward slashes work on both Windows and POSIX. The previous default used
    # backslashes in a non-raw string ("\R", "\8" are invalid escape sequences)
    # and only resolved as a directory path on Windows.
    parser.add_argument('--path_img', type=str, default='predict/RNBD/8.png',
                        help='the path of image')
    parser.add_argument('--patch_size', type=int, default=768,
                        help='the size of split patch')
    parser.add_argument('--stride', type=int, default=640,
                        help='the stride of split img')
    parser.add_argument('--path_save', type=str, default='predict/result',
                        help='the path of save result')
    parser.add_argument('--device', type=int, default=0,
                        help='index of the CUDA device to run on')
    return parser.parse_args(argv)
def main():
    """Build the road extractor, restore its weights and run tiled prediction.

    Reads the ``model`` section of the YAML config given by ``--config``,
    loads the ``state_dict`` entry of the ``--checkpoints`` file into the model,
    and passes the model to ``predict`` together with the image path, output
    path, patch size and stride from the command line.
    """
    args = parse_args()
    cfg = read_yml(args.config)

    # Use the selected GPU consistently. Previously the checkpoint was always
    # mapped to 'cuda:0' even when --device chose a different GPU, which forced
    # an allocation on GPU 0 before copying into the model.
    device = torch.device('cuda', args.device)
    torch.cuda.set_device(device)
    model = build_road_extractor(cfg_model=cfg['model']).to(device)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoints, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Inference mode: switch layers such as dropout/batch-norm to eval behavior.
    # This is harmless if predict() also calls eval() internally.
    model.eval()

    # No gradients are needed for prediction; this is redundant but safe if
    # predict() already disables autograd.
    with torch.no_grad():
        predict(args.path_img, model, args.path_save, args.patch_size, args.stride)
if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()