forked from cvg/Hierarchical-Localization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
r2d2.py
56 lines (48 loc) · 1.74 KB
/
r2d2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import sys
from pathlib import Path
import torchvision.transforms as tvf
from ..utils.base_model import BaseModel
# Directory of the third-party R2D2 sources, expressed relative to this file.
r2d2_path = Path(__file__).parent / "../../third_party/r2d2"
# The path must be appended *before* the import below: `extract` is a
# top-level module that is resolved through this sys.path entry
# (presumably it lives inside the R2D2 directory -- confirm against the
# third-party checkout).
sys.path.append(str(r2d2_path))
from extract import load_network, NonMaxSuppression, extract_multiscale
class R2D2(BaseModel):
    """R2D2 keypoint detector/descriptor exposed through ``BaseModel``.

    The actual work is delegated to the third-party R2D2 helpers
    ``load_network``, ``NonMaxSuppression`` and ``extract_multiscale``.
    This class only forwards configuration values to them and keeps the
    best-scoring keypoints.
    """

    # Default configuration values.
    default_conf = {
        # Weights file name, looked up under ``r2d2_path / "models"``.
        'model_name': 'r2d2_WASF_N16.pt',
        # Keep at most this many top-scoring keypoints. ``0``, a negative
        # value or ``None`` keeps every detection.
        'max_keypoints': 5000,
        # The five entries below are forwarded unchanged to
        # ``extract_multiscale`` (as scale_f, min_size, max_size,
        # min_scale and max_scale respectively).
        'scale_factor': 2**0.25,
        'min_size': 256,
        'max_size': 1024,
        'min_scale': 0,
        'max_scale': 1,
        # Passed to ``NonMaxSuppression`` as ``rel_thr`` / ``rep_thr``.
        'reliability_threshold': 0.7,
        # NOTE: "repetability" is misspelled, but it is the public config key;
        # renaming it would break existing configurations.
        'repetability_threshold': 0.7,
    }
    required_inputs = ['image']

    def _init(self, conf):
        """Build the normalizer, the network and the keypoint detector.

        Args:
            conf: mapping providing at least ``model_name``,
                ``reliability_threshold`` and ``repetability_threshold``.
        """
        weights_path = r2d2_path / "models" / conf['model_name']
        # Per-channel normalization with the standard ImageNet statistics.
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.net = load_network(weights_path)
        self.detector = NonMaxSuppression(
            rel_thr=conf['reliability_threshold'],
            rep_thr=conf['repetability_threshold'],
        )

    def _forward(self, data):
        """Extract keypoints, descriptors and scores from ``data['image']``.

        Args:
            data: mapping with an ``'image'`` entry; it is normalized with
                ``self.norm_rgb`` and handed to ``extract_multiscale``
                (presumably a batched RGB tensor -- confirm against caller).

        Returns:
            dict with ``'keypoints'``, ``'descriptors'`` and ``'scores'``,
            each carrying an added leading dimension. Entries are ordered by
            ascending score; at most ``max_keypoints`` of the highest-scoring
            detections are kept.
        """
        # NOTE(review): ``self.conf`` is not assigned in this class; it is
        # presumably populated by ``BaseModel`` -- confirm.
        image = self.norm_rgb(data['image'])

        xys, desc, scores = extract_multiscale(
            self.net,
            image,
            self.detector,
            scale_f=self.conf['scale_factor'],
            min_size=self.conf['min_size'],
            max_size=self.conf['max_size'],
            min_scale=self.conf['min_scale'],
            max_scale=self.conf['max_scale'],
        )

        # argsort is ascending, so the best-scoring detections sit at the tail.
        order = scores.argsort()
        max_keypoints = self.conf['max_keypoints']
        # Only truncate for a real positive limit. The previous
        # ``[-max_keypoints or None:]`` slice raised TypeError for ``None``
        # and, for a negative value, silently dropped the lowest-scoring
        # detections instead of keeping everything.
        if max_keypoints is not None and max_keypoints > 0:
            order = order[-max_keypoints:]

        # Only the first two columns of ``xys`` are kept (presumably x, y;
        # any further column is discarded).
        keypoints = xys[order, :2]
        # Transpose so the selected-keypoint axis comes last.
        descriptors = desc[order].t()
        kept_scores = scores[order]

        # ``[None]`` prepends a singleton dimension to every output.
        return {
            'keypoints': keypoints[None],
            'descriptors': descriptors[None],
            'scores': kept_scores[None],
        }