diff --git a/taming/modules/losses/lpips.py b/taming/modules/losses/lpips.py index a7280447..00a46d27 100644 --- a/taming/modules/losses/lpips.py +++ b/taming/modules/losses/lpips.py @@ -1,4 +1,5 @@ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" +import os import torch import torch.nn as nn @@ -10,8 +11,9 @@ class LPIPS(nn.Module): # Learned perceptual metric - def __init__(self, use_dropout=True): + def __init__(self, use_dropout=True, download_directory: str = "/tmp/"): super().__init__() + self.download_directory = download_directory self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features self.net = vgg16(pretrained=True, requires_grad=False) @@ -25,7 +27,9 @@ def __init__(self, use_dropout=True): param.requires_grad = False def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + root = os.path.join(self.download_directory, "taming/modules/autoencoder/lpips") + os.makedirs(root, exist_ok=True) + ckpt = get_ckpt_path(name, root) self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) print("loaded pretrained LPIPS loss from {}".format(ckpt))