# Source: GitHub web view of eval_davis.py (forked from hkchengrex/STCN; Fork 0).
# File size: 108 lines (87 loc) · 3.67 KB.
# Page chrome captured with the file: "Notifications - You must be signed in to change notification settings".
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from os import path
import time
from argparse import ArgumentParser
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from model.eval_network import STCN
from dataset.davis_test_dataset import DAVISTestDataset
from util.tensor_util import unpad
from inference_core import InferenceCore
from progressbar import progressbar
"""
Arguments loading
"""
# Command-line options for the DAVIS evaluation run.
parser = ArgumentParser()
parser.add_argument('--model', default='saves/stcn.pth',
                    help='checkpoint file handed to torch.load')
parser.add_argument('--davis_path', default='../DAVIS/2017',
                    help='DAVIS root; the script reads <davis_path>/trainval or <davis_path>/test-dev')
# The output directory is created unconditionally right after parsing, so a
# missing --output used to surface as a TypeError from os.makedirs(None).
# Marking it required lets argparse report a proper usage error instead.
parser.add_argument('--output', required=True,
                    help='directory the predicted masks are written to (created if missing)')
parser.add_argument('--split', help='val/testdev', default='val')
parser.add_argument('--top', type=int, default=20,
                    help='forwarded to InferenceCore as top_k')
parser.add_argument('--amp', action='store_true',
                    help='run inference under torch.cuda.amp.autocast')
parser.add_argument('--mem_every', default=5, type=int,
                    help='forwarded to InferenceCore as mem_every')
parser.add_argument('--include_last', help='include last frame as temporary memory?', action='store_true')
args = parser.parse_args()

davis_path = args.davis_path
out_path = args.output

# Simple setup
os.makedirs(out_path, exist_ok=True)
# This script only evaluates, so gradient tracking is switched off globally.
torch.autograd.set_grad_enabled(False)
# Setup Dataset
# Per-split layout: (palette reference image, dataset root suffix, image-set file).
# The palette is taken from the first annotation of one known sequence of the split.
split_layout = {
    'val': ('/trainval/Annotations/480p/blackswan/00000.png', '/trainval', '2017/val.txt'),
    'testdev': ('/test-dev/Annotations/480p/salsa/00000.png', '/test-dev', '2017/test-dev.txt'),
}
if args.split not in split_layout:
    raise NotImplementedError

palette_suffix, root_suffix, imset_file = split_layout[args.split]
palette = Image.open(path.expanduser(davis_path + palette_suffix)).getpalette()
test_dataset = DAVISTestDataset(davis_path + root_suffix, imset=imset_file)
# One sequence per batch, in dataset order.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
# Load our checkpoint
top_k = args.top
prop_saved = torch.load(args.model)

# Build the network on the GPU in eval mode, then restore the saved weights.
prop_model = STCN().cuda()
prop_model.eval()
prop_model.load_state_dict(prop_saved)

# Running totals accumulated by the evaluation loop.
total_process_time = total_frames = 0
# Start eval
for data in progressbar(test_loader, max_value=len(test_loader), redirect_stdout=True):

    with torch.cuda.amp.autocast(enabled=args.amp):
        # Unpack one batch (batch_size is 1, hence the [0] indexing).
        rgb = data['rgb'].cuda()
        msk = data['gt'][0].cuda()
        info = data['info']
        seq_name = info['name'][0]
        num_objects = len(info['labels'][0])
        out_size = info['size_480p']

        # Synchronise first so the timer only covers work issued below.
        torch.cuda.synchronize()
        process_begin = time.time()

        processor = InferenceCore(
            prop_model, rgb, num_objects,
            top_k=top_k,
            mem_every=args.mem_every,
            include_last=args.include_last,
        )
        processor.interact(msk[:, 0], 0, rgb.shape[1])

        # Undo the padding, resize each frame's probabilities to
        # info['size_480p'], and keep the arg-max index along dim 0
        # (presumably the object/label axis - confirm against InferenceCore).
        num_frames = processor.t
        out_masks = torch.zeros((num_frames, 1, *out_size), dtype=torch.uint8, device='cuda')
        for frame_idx in range(num_frames):
            frame_prob = unpad(processor.prob[:, frame_idx], processor.pad)
            frame_prob = F.interpolate(frame_prob, out_size, mode='bilinear', align_corners=False)
            out_masks[frame_idx] = torch.argmax(frame_prob, dim=0)

        # Drop the singleton channel axis and move the masks to host memory.
        out_masks = out_masks.detach().cpu().numpy()[:, 0].astype(np.uint8)

        # Synchronise again so pending GPU work is included in the timing.
        torch.cuda.synchronize()
        total_process_time += time.time() - process_begin
        total_frames += out_masks.shape[0]

        # Save the results: one palettised PNG per frame, named by frame index.
        this_out_path = path.join(out_path, seq_name)
        os.makedirs(this_out_path, exist_ok=True)
        for frame_idx, frame_mask in enumerate(out_masks):
            mask_img = Image.fromarray(frame_mask)
            mask_img.putpalette(palette)
            mask_img.save(path.join(this_out_path, '{:05d}.png'.format(frame_idx)))

        # Drop references to the per-sequence tensors before the next iteration.
        del rgb, msk, processor
# Final report. total_process_time is still 0 when the loader yielded no
# sequences, so the division is guarded rather than letting the script end
# with a ZeroDivisionError.
print('Total processing time: ', total_process_time)
print('Total processed frames: ', total_frames)
if total_process_time > 0:
    print('FPS: ', total_frames / total_process_time)
else:
    print('FPS: ', 'n/a (no processing time recorded)')