forked from mbzuai-metaverse/XMem2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteractive_demo.py
100 lines (77 loc) · 4.21 KB
/
interactive_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
A simple user interface for XMem
"""
import os
# Fix for Windows: guarantee QT_QPA_PLATFORM_PLUGIN_PATH is defined before PyQt5
# is imported further down. A value already present in the environment is kept
# as-is; otherwise the variable is created with an empty string.
os.environ.setdefault('QT_QPA_PLATFORM_PLUGIN_PATH', '')
import sys
from argparse import ArgumentParser
import torch
from model.network import XMem
from inference.interact.s2m_controller import S2MController
from inference.interact.fbrs_controller import FBRSController
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
from PyQt5.QtWidgets import QApplication
from inference.interact.gui import App
from inference.interact.resource_manager import ResourceManager
torch.set_grad_enabled(False)
if __name__ == '__main__':
# Arguments parsing
parser = ArgumentParser()
parser.add_argument('--model', default='./saves/XMem.pth')
parser.add_argument('--s2m_model', default='saves/s2m.pth')
parser.add_argument('--fbrs_model', default='saves/fbrs.pth')
"""
Priority 1: If a "images" folder exists in the workspace, we will read from that directory
Priority 2: If --images is specified, we will copy/resize those images to the workspace
Priority 3: If --video is specified, we will extract the frames to the workspace (in an "images" folder) and read from there
In any case, if a "masks" folder exists in the workspace, we will use that to initialize the mask
That way, you can continue annotation from an interrupted run as long as the same workspace is used.
"""
parser.add_argument('--images', help='Folders containing input images.', default=None)
parser.add_argument('--video', help='Video file readable by OpenCV.', default=None)
parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None)
parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100)
parser.add_argument('--num_objects', type=int, default=None)
# Long-memory options
# Defaults. Some can be changed in the GUI.
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
type=int, default=10000)
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
parser.add_argument('--top_k', type=int, default=30)
parser.add_argument('--mem_every', type=int, default=10)
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
parser.add_argument('--no_amp', help='Turn off AMP', action='store_true')
parser.add_argument('--size', default=480, type=int,
help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()
config = vars(args)
config['enable_long_term'] = True
config['enable_long_term_count_usage'] = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.cuda.amp.autocast(enabled=not args.no_amp):
# Load our checkpoint
network = XMem(config, args.model, pretrained_key_encoder=False, pretrained_value_encoder=False).to(device).eval()
# Loads the S2M model
if args.s2m_model is not None:
s2m_saved = torch.load(args.s2m_model)
s2m_model = S2M().to(device).eval()
s2m_model.load_state_dict(s2m_saved)
else:
s2m_model = None
# Manages most IO
config['num_objects_default_value'] = 1
resource_manager = ResourceManager(config)
num_objects = resource_manager.num_objects
config['num_objects'] = num_objects
s2m_controller = S2MController(s2m_model, num_objects, ignore_class=255)
if args.fbrs_model is not None:
fbrs_controller = FBRSController(args.fbrs_model)
else:
fbrs_controller = None
app = QApplication(sys.argv)
ex = App(network, resource_manager, s2m_controller, fbrs_controller, config)
sys.exit(app.exec_())