diff --git a/config/__init__.py b/config/__init__.py index 3ae6e8f..a1afc4f 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -8,4 +8,5 @@ from .choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \ image_format_choices, noise_schedule_choices from .setting import MASTER_ADDR, MASTER_PORT, EMA_BETA, RANDOM_RESIZED_CROP_SCALE, MEAN, STD -from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest +from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest, \ + get_version_banner diff --git a/config/banner.txt b/config/banner.txt new file mode 100644 index 0000000..a664132 --- /dev/null +++ b/config/banner.txt @@ -0,0 +1,8 @@ + _____ _ +| __ \ (_) +| |__) | _ _ __ _ __ _ _ __ __ _ +| _ / | | | '_ \| '_ \| | '_ \ / _` | +| | \ \ |_| | | | | | | | | | | | (_| | _ _ _ +|_| \_\__,_|_| |_|_| |_|_|_| |_|\__, | (_) (_) (_) + __/ | + |___/ \ No newline at end of file diff --git a/config/version.py b/config/version.py index ecb3926..8182c87 100644 --- a/config/version.py +++ b/config/version.py @@ -57,7 +57,21 @@ def check_version_is_latest(current_version): return False +def get_version_banner(): + """ + Get version banner. + """ + with open(file="../config/banner.txt", mode="r", encoding="utf-8") as banner_file: + contents = banner_file.read() + print(contents) + print(f"===============IDDM version: {get_latest_version()}===============\n" + "Project Author : chairc\n" + "Project GitHub : https://github.com/chairc/Integrated-Design-Diffusion-Model") + banner_file.close() + + if __name__ == "__main__": get_versions() get_latest_version() get_old_versions() + get_version_banner() diff --git a/sr/demo.py b/sr/demo.py index 16f8f22..8411f0e 100644 --- a/sr/demo.py +++ b/sr/demo.py @@ -16,6 +16,7 @@ from PIL import Image sys.path.append(os.path.dirname(sys.path[0])) +from config.version import get_version_banner from sr.interface import inference, load_sr_model from utils.initializer import device_initializer from utils.utils import plot_images, save_images, check_and_create_dir @@ -81,4 +82,6 @@ def lr2hr(args): parser.add_argument("--result_path", type=str, default="/your/path/Diffusion-Model/result") args = parser.parse_args() + # Get version banner + get_version_banner() lr2hr(args) diff --git a/sr/train.py b/sr/train.py index 92fe37f..8869907 100644 --- a/sr/train.py +++ b/sr/train.py @@ -25,6 +25,7 @@ sys.path.append(os.path.dirname(sys.path[0])) from config.choices import loss_func_choices, sr_network_choices, optim_choices from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA +from config.version import get_version_banner from model.modules.ema import EMA from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \ lr_initializer, amp_initializer, loss_initializer @@ -379,5 +380,6 @@ def main(args): parser.add_argument("--world_size", type=int, default=2) args = parser.parse_args() - + # Get version banner + get_version_banner() main(args) diff --git a/tools/FID_calculator_plus.py b/tools/FID_calculator_plus.py index c01765d..581aade 100644 --- a/tools/FID_calculator_plus.py +++ b/tools/FID_calculator_plus.py @@ -16,6 +16,7 @@ from pytorch_fid.inception import InceptionV3 sys.path.append(os.path.dirname(sys.path[0])) +from config.version import get_version_banner from utils.initializer import device_initializer logger = logging.getLogger(__name__) @@ -68,4 +69,6 @@ def main(args): # Set the use GPU in normal training (required) parser.add_argument("--use_gpu", type=int, default=0) args = parser.parse_args() + # Get version banner + get_version_banner() main(args) diff --git a/tools/generate.py b/tools/generate.py index 58dac53..e68fc20 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(sys.path[0])) from config.choices import sample_choices, network_choices, act_choices, image_format_choices, parse_image_size_type +from config.version import get_version_banner from utils.check import check_image_size from utils.initializer import device_initializer, network_initializer, sample_initializer, generate_initializer from utils.utils import plot_images, save_images, save_one_image_in_images, check_and_create_dir @@ -165,4 +166,6 @@ def generate(args): parser.add_argument("--num_classes", type=int, default=10) args = parser.parse_args() + # Get version banner + get_version_banner() generate(args) diff --git a/tools/train.py b/tools/train.py index 5ce1fc0..46400cd 100644 --- a/tools/train.py +++ b/tools/train.py @@ -25,6 +25,7 @@ from config.choices import sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \ image_format_choices, noise_schedule_choices, parse_image_size_type, loss_func_choices from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA +from config.version import get_version_banner from model.modules.ema import EMA from utils.check import check_image_size from utils.dataset import get_dataset @@ -431,5 +432,6 @@ def main(args): parser.add_argument("--cfg_scale", type=int, default=3) args = parser.parse_args() - + # Get version banner + get_version_banner() main(args)