-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathhubconf.py
81 lines (69 loc) · 2.51 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
dependencies = ['torch', 'torchvision']
import torch
from src.backbones import ResNet, DinoV2
from src.boq import BoQ
class VPRModel(torch.nn.Module):
    """Compose a backbone and an aggregator into a single module.

    The input is passed through ``backbone``; the result is handed to
    ``aggregator``, which must return exactly two values.
    """

    def __init__(self, backbone, aggregator):
        """Store the two sub-modules.

        Args:
            backbone: Callable applied first to the input.
            aggregator: Callable applied to the backbone output; must
                return a pair.
        """
        super().__init__()
        # Attribute names are part of the checkpoint layout (state_dict keys).
        self.backbone = backbone
        self.aggregator = aggregator

    def forward(self, x):
        """Run ``x`` through the backbone, then the aggregator.

        Returns:
            The two values produced by the aggregator, in order
            (the second is presumably the attention maps -- confirm
            against the aggregator implementation).
        """
        features = self.backbone(x)
        descriptor, attns = self.aggregator(features)
        return descriptor, attns
# Backbone name -> list of descriptor sizes accepted for that backbone.
AVAILABLE_BACKBONES = {
    # this list will be extended
    # "resnet18": [8192, 4096],
    "resnet50": [16384],
    "dinov2": [12288],
}

# Common prefix of every checkpoint URL below.
_RELEASE_BASE_URL = "https://github.com/amaralibey/Bag-of-Queries/releases/download/v1.0"

# "<backbone_name>_<output_dim>" -> checkpoint URL.
MODEL_URLS = {
    "resnet50_16384": f"{_RELEASE_BASE_URL}/resnet50_16384.pth",
    "dinov2_12288": f"{_RELEASE_BASE_URL}/dinov2_12288.pth",
    # "resnet50_4096": "",
}
def get_trained_boq(backbone_name="resnet50", output_dim=16384):
    """Build a ``VPRModel`` (backbone + BoQ aggregator) and load its weights.

    Args:
        backbone_name: One of the keys of ``AVAILABLE_BACKBONES``.
        output_dim: Descriptor size. Any value ``int()`` accepts; the
            converted integer must be listed for ``backbone_name`` in
            ``AVAILABLE_BACKBONES``.

    Returns:
        A ``VPRModel`` whose state dict was loaded from the URL registered
        in ``MODEL_URLS`` under ``"<backbone_name>_<output_dim>"``.

    Raises:
        ValueError: If ``backbone_name`` is unknown, ``output_dim`` cannot be
            converted to an integer, or the combination is not available.
    """
    if backbone_name not in AVAILABLE_BACKBONES:
        raise ValueError(f"backbone_name should be one of {list(AVAILABLE_BACKBONES.keys())}")

    try:
        output_dim = int(output_dim)
    except (TypeError, ValueError, OverflowError) as err:
        # Translate conversion failures only; a bare ``except`` would also
        # intercept KeyboardInterrupt and SystemExit.
        raise ValueError(f"output_dim should be an integer, not a {type(output_dim)}") from err

    if output_dim not in AVAILABLE_BACKBONES[backbone_name]:
        raise ValueError(f"output_dim should be one of {AVAILABLE_BACKBONES[backbone_name]}")

    if "dinov2" in backbone_name:
        # load the backbone
        backbone = DinoV2()
        # load the aggregator
        aggregator = BoQ(
            in_channels=backbone.out_channels,  # make sure the backbone has out_channels attribute
            proj_channels=384,
            num_queries=64,
            num_layers=2,
            row_dim=output_dim // 384,  # 32 for dinov2
        )
    elif "resnet" in backbone_name:
        backbone = ResNet(
            backbone_name=backbone_name,
            crop_last_block=True,
        )
        aggregator = BoQ(
            in_channels=backbone.out_channels,  # make sure the backbone has out_channels attribute
            proj_channels=512,
            num_queries=64,
            num_layers=2,
            row_dim=output_dim // 512,  # 32 for resnet
        )
    else:
        # Guards against a backbone being added to AVAILABLE_BACKBONES
        # without a matching construction branch here.
        raise ValueError(f"no model configuration defined for backbone {backbone_name!r}")

    vpr_model = VPRModel(
        backbone=backbone,
        aggregator=aggregator,
    )

    # Weights are mapped to CPU on load; callers move the model afterwards.
    vpr_model.load_state_dict(
        torch.hub.load_state_dict_from_url(
            MODEL_URLS[f"{backbone_name}_{output_dim}"],
            map_location=torch.device('cpu'),
        )
    )
    return vpr_model