diff --git a/models/vgg_.py b/models/vgg_.py index 130083d..fbcef24 100644 --- a/models/vgg_.py +++ b/models/vgg_.py @@ -2,8 +2,12 @@ """ Mostly copy-paste from torchvision references. """ +from posixpath import split import torch import torch.nn as nn +import os +from tqdm import tqdm +import urllib.request __all__ = [ @@ -23,14 +27,13 @@ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', } +checkpoints_dir = os.path.join(os.path.dirname(__file__), "../checkpoints/") model_paths = { - 'vgg16_bn': '/apdcephfs/private_changanwang/checkpoints/vgg16_bn-6c64b313.pth', - 'vgg16': '/apdcephfs/private_changanwang/checkpoints/vgg16-397923af.pth', - + 'vgg16_bn': os.path.join(checkpoints_dir,'vgg16_bn-6c64b313.pth'), + 'vgg16': os.path.join(checkpoints_dir,'vgg16-397923af.pth'), } - class VGG(nn.Module): def __init__(self, features, num_classes=1000, init_weights=True): @@ -98,11 +101,30 @@ def make_layers(cfg, batch_norm=False, sync=False): } +class DownloadProgressBar(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + +def download_arch_if_not_exists(arch): + if not os.path.exists(model_paths[arch]): + print(f"{arch} model not found.") + chunk_size = 1024 + os.makedirs(checkpoints_dir, exist_ok=True) + + url = model_urls[arch] + filename = os.path.basename(url) + output_path = os.path.join(checkpoints_dir, filename) + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t: + urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) + def _vgg(arch, cfg, batch_norm, pretrained, progress, sync=False, **kwargs): if pretrained: kwargs['init_weights'] = False model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm, sync=sync), **kwargs) if pretrained: + download_arch_if_not_exists(arch) state_dict = torch.load(model_paths[arch]) model.load_state_dict(state_dict) return model