diff --git a/eval_image_retrieval.py b/eval_image_retrieval.py index e60c71d2b..999f8c900 100644 --- a/eval_image_retrieval.py +++ b/eval_image_retrieval.py @@ -132,7 +132,7 @@ def config_qimname(cfg, i): model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") elif "xcit" in args.arch: - model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0) + model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) elif args.arch in torchvision_models.__dict__.keys(): model = torchvision_models.__dict__[args.arch](num_classes=0) else: diff --git a/eval_knn.py b/eval_knn.py index 2b9f0054f..fe99a2604 100644 --- a/eval_knn.py +++ b/eval_knn.py @@ -60,7 +60,7 @@ def extract_feature_pipeline(args): model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") elif "xcit" in args.arch: - model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0) + model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) elif args.arch in torchvision_models.__dict__.keys(): model = torchvision_models.__dict__[args.arch](num_classes=0) model.fc = nn.Identity() diff --git a/eval_linear.py b/eval_linear.py index 81eb94fd0..cdef16b47 100644 --- a/eval_linear.py +++ b/eval_linear.py @@ -41,7 +41,7 @@ def eval_linear(args): embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)) # if the network is a XCiT elif "xcit" in args.arch: - model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0) + model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) embed_dim = model.embed_dim # otherwise, we check if the architecture is in torchvision models elif args.arch in torchvision_models.__dict__.keys(): diff --git a/hubconf.py b/hubconf.py index ef1cdeaed..3709271ed 100644 --- a/hubconf.py +++ b/hubconf.py @@ -99,7 +99,7 @@ def dino_xcit_small_12_p16(pretrained=True, **kwargs): """ XCiT-Small-12/16 pre-trained with DINO. """ - model = torch.hub.load('facebookresearch/xcit', "xcit_small_12_p16", num_classes=0, **kwargs) + model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs) if pretrained: state_dict = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth", @@ -113,7 +113,7 @@ def dino_xcit_small_12_p8(pretrained=True, **kwargs): """ XCiT-Small-12/8 pre-trained with DINO. """ - model = torch.hub.load('facebookresearch/xcit', "xcit_small_12_p8", num_classes=0, **kwargs) + model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs) if pretrained: state_dict = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth", @@ -127,7 +127,7 @@ def dino_xcit_medium_24_p16(pretrained=True, **kwargs): """ XCiT-Medium-24/16 pre-trained with DINO. """ - model = torch.hub.load('facebookresearch/xcit', "xcit_medium_24_p16", num_classes=0, **kwargs) + model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs) if pretrained: state_dict = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth", @@ -141,7 +141,7 @@ def dino_xcit_medium_24_p8(pretrained=True, **kwargs): """ XCiT-Medium-24/8 pre-trained with DINO. """ - model = torch.hub.load('facebookresearch/xcit', "xcit_medium_24_p8", num_classes=0, **kwargs) + model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs) if pretrained: state_dict = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth", diff --git a/main_dino.py b/main_dino.py index 8e1756edf..cade9873d 100644 --- a/main_dino.py +++ b/main_dino.py @@ -44,7 +44,7 @@ def get_args_parser(): # Model parameters parser.add_argument('--arch', default='vit_small', type=str, choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \ - + torchvision_archs + torch.hub.list("facebookresearch/xcit"), + + torchvision_archs + torch.hub.list("facebookresearch/xcit:main"), help="""Name of architecture to train. For quick experiments with ViTs, we recommend using vit_tiny or vit_small.""") parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels @@ -166,10 +166,10 @@ def train_dino(args): teacher = vits.__dict__[args.arch](patch_size=args.patch_size) embed_dim = student.embed_dim # if the network is a XCiT - elif args.arch in torch.hub.list("facebookresearch/xcit"): - student = torch.hub.load('facebookresearch/xcit', args.arch, + elif args.arch in torch.hub.list("facebookresearch/xcit:main"): + student = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False, drop_path_rate=args.drop_path_rate) - teacher = torch.hub.load('facebookresearch/xcit', args.arch, pretrained=False) + teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False) embed_dim = student.embed_dim # otherwise, we check if the architecture is in torchvision models elif args.arch in torchvision_models.__dict__.keys():