From dc441d2c0faf55ed9a5ac41b141d41edcf31dcd2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 30 May 2022 11:00:47 -0700 Subject: [PATCH] fix version problem if researcher did not install from pip --- dalle_pytorch/__init__.py | 2 +- generate.py | 1 + setup.py | 4 +++- train_dalle.py | 7 ++----- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dalle_pytorch/__init__.py b/dalle_pytorch/__init__.py index 09ba97bf..8ca869b1 100644 --- a/dalle_pytorch/__init__.py +++ b/dalle_pytorch/__init__.py @@ -2,4 +2,4 @@ from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE from pkg_resources import get_distribution -__version__ = get_distribution('dalle_pytorch').version +from dalle_pytorch.version import __version__ diff --git a/generate.py b/generate.py index b58ef6eb..8c0b7638 100644 --- a/generate.py +++ b/generate.py @@ -15,6 +15,7 @@ # dalle related classes and utils +from dalle_pytorch import __version__ from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer diff --git a/setup.py b/setup.py index 68484cd2..777912a3 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,16 @@ from setuptools import setup, find_packages +exec(open('dalle_pytorch/version.py').read()) setup( name = 'dalle-pytorch', packages = find_packages(), include_package_data = True, - version = '1.6.2', + version = __version__, license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang', author_email = 'lucidrains@gmail.com', + long_description_content_type = 'text/markdown', url = 'https://github.com/lucidrains/dalle-pytorch', keywords = [ 'artificial intelligence', diff --git a/train_dalle.py b/train_dalle.py index 2da2e540..6e3a4e09 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader +from dalle_pytorch import __version__ from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE from dalle_pytorch import distributed_utils from dalle_pytorch.loader import TextImageDataset @@ -147,10 +148,6 @@ def exists(val): def get_trainable_params(model): return [params for params in model.parameters() if params.requires_grad] -def get_pkg_version(): - from pkg_resources import get_distribution - return get_distribution('dalle_pytorch').version - def cp_path_to_dir(cp_path, tag): """Convert a checkpoint path to a directory with `tag` inserted. If `cp_path` is already a directory, return it unchanged. @@ -540,7 +537,7 @@ def save_model(path, epoch=0): 'hparams': dalle_params, 'vae_params': vae_params, 'epoch': epoch, - 'version': get_pkg_version(), + 'version': __version__, 'vae_class_name': vae.__class__.__name__ }