forked from mhamilton723/FeatUp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
67 lines (49 loc) · 2.19 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# hubconf.py
import torch
from featup.featurizers.util import get_featurizer
from featup.layers import ChannelNorm
from featup.upsamplers import get_upsampler
from torch.nn import Module
# Package names required to load the entry points in this file; torch.hub reads
# this list and checks the packages are available before calling an entry point.
dependencies = ['torch', 'torchvision', 'PIL', 'featup']
class UpsampledBackbone(Module):
    """A featurizer paired with a "jbu_stack" upsampler.

    Attributes:
        dim: feature dimension reported by ``get_featurizer``.
        model: the featurizer, optionally followed by ``ChannelNorm(dim)``.
        upsampler: the module returned by ``get_upsampler("jbu_stack", dim)``.
    """

    def __init__(self, model_name, use_norm):
        """Create the featurizer and its upsampler.

        Args:
            model_name: name handed to ``get_featurizer`` to select the featurizer.
            use_norm: when truthy, the featurizer is wrapped in a ``Sequential``
                together with ``ChannelNorm(dim)``; otherwise it is used as is.
        """
        super().__init__()
        # get_featurizer also returns a patch size, which this class does not use.
        featurizer, _, self.dim = get_featurizer(model_name, "token", num_classes=1000)
        self.model = (
            torch.nn.Sequential(featurizer, ChannelNorm(self.dim))
            if use_norm
            else featurizer
        )
        self.upsampler = get_upsampler("jbu_stack", self.dim)

    def forward(self, image):
        """Featurize ``image`` and pass the features plus ``image`` to the upsampler."""
        features = self.model(image)
        return self.upsampler(features, image)
def _load_backbone(pretrained, use_norm, model_name):
    """
    Build an ``UpsampledBackbone`` and optionally load pretrained weights into it.

    Shared implementation behind the Torch Hub entry points in this file.

    Args:
        pretrained (bool): If True, download the matching checkpoint and load it
            into the model; otherwise the model is returned as constructed.
        use_norm (bool): Passed to ``UpsampledBackbone``. Also selects the
            checkpoint directory: the root when True, ``no_norm/`` when False.
        model_name (str): Featurizer name passed to ``UpsampledBackbone``; also
            the prefix of the checkpoint file name.
    Returns:
        The ``UpsampledBackbone`` instance.
    """
    model = UpsampledBackbone(model_name, use_norm)
    if not pretrained:
        return model

    exp_dir = "" if use_norm else "no_norm/"
    checkpoint_url = f"https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/pretrained/{exp_dir}{model_name}_jbu_stack_cocostuff.ckpt"

    # map_location="cpu" lets a checkpoint whose tensors were saved from a CUDA
    # device be deserialized on a host without CUDA. load_state_dict below copies
    # the values into the model's own parameters, so the model's device is unchanged.
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")

    # Drop every entry whose key mentions "scale_net" or "downsampler"
    # (presumably training-only parts of the checkpoint -- TODO confirm).
    state_dict = {
        k: v
        for k, v in checkpoint["state_dict"].items()
        if "scale_net" not in k and "downsampler" not in k
    }

    # strict=False: missing or unexpected keys are tolerated rather than raised.
    model.load_state_dict(state_dict, strict=False)
    return model
def vit(pretrained=True, use_norm=True):
    """Torch Hub entry point: build the upsampled backbone named "vit".

    Args:
        pretrained (bool): If True, the pretrained checkpoint is loaded.
        use_norm (bool): Forwarded to ``_load_backbone`` to select the
            normalized or the ``no_norm`` variant.
    Returns:
        The model returned by ``_load_backbone``.
    """
    return _load_backbone(pretrained=pretrained, use_norm=use_norm, model_name="vit")
def dino16(pretrained=True, use_norm=True):
    """Torch Hub entry point: build the upsampled backbone named "dino16".

    Args:
        pretrained (bool): If True, the pretrained checkpoint is loaded.
        use_norm (bool): Forwarded to ``_load_backbone`` to select the
            normalized or the ``no_norm`` variant.
    Returns:
        The model returned by ``_load_backbone``.
    """
    return _load_backbone(pretrained=pretrained, use_norm=use_norm, model_name="dino16")
def clip(pretrained=True, use_norm=True):
    """Torch Hub entry point: build the upsampled backbone named "clip".

    Args:
        pretrained (bool): If True, the pretrained checkpoint is loaded.
        use_norm (bool): Forwarded to ``_load_backbone`` to select the
            normalized or the ``no_norm`` variant.
    Returns:
        The model returned by ``_load_backbone``.
    """
    return _load_backbone(pretrained=pretrained, use_norm=use_norm, model_name="clip")
def dinov2(pretrained=True, use_norm=True):
    """Torch Hub entry point: build the upsampled backbone named "dinov2".

    Args:
        pretrained (bool): If True, the pretrained checkpoint is loaded.
        use_norm (bool): Forwarded to ``_load_backbone`` to select the
            normalized or the ``no_norm`` variant.
    Returns:
        The model returned by ``_load_backbone``.
    """
    return _load_backbone(pretrained=pretrained, use_norm=use_norm, model_name="dinov2")
def resnet50(pretrained=True, use_norm=True):
    """Torch Hub entry point: build the upsampled backbone named "resnet50".

    Args:
        pretrained (bool): If True, the pretrained checkpoint is loaded.
        use_norm (bool): Forwarded to ``_load_backbone`` to select the
            normalized or the ``no_norm`` variant.
    Returns:
        The model returned by ``_load_backbone``.
    """
    return _load_backbone(pretrained=pretrained, use_norm=use_norm, model_name="resnet50")