From cae695e5ada44eefe6a27bf6abfd548e5306a48e Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 16 Apr 2024 00:19:22 +0800 Subject: [PATCH 1/5] Update: Modify the get_dataset function, remove args parameter. --- utils/utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index d3488de..f18a3fd 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -20,6 +20,7 @@ from config.choices import RANDOM_RESIZED_CROP_SCALE, MEAN, STD from sr.dataset import SRDataset +from utils.check import check_path_is_exist logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -93,7 +94,7 @@ def save_one_image_in_images(images, path, generate_name, image_size=None, image count += 1 -def get_dataset(args, distributed=False): +def get_dataset(image_size=64, dataset_path=None, batch_size=2, num_workers=0, distributed=False): """ Get dataset @@ -134,16 +135,21 @@ def get_dataset(args, distributed=False): | | | | +------------------------+ +-----------+ - :param args: Parameters + :param image_size: Image size + :param dataset_path: Dataset path + :param batch_size: Batch size + :param num_workers: Number of workers :param distributed: Whether to distribute training :return: dataloader """ + check_path_is_exist(path=dataset_path) + # Data augmentation transforms = torchvision.transforms.Compose([ - # Resize input size + # Resize input size, input type is (height, width) # torchvision.transforms.Resize(80), args.image_size + 1/4 * args.image_size - torchvision.transforms.Resize(size=int(args.image_size + args.image_size / 4)), + torchvision.transforms.Resize(size=int(image_size + image_size / 4)), # Random adjustment cropping - torchvision.transforms.RandomResizedCrop(size=args.image_size, scale=RANDOM_RESIZED_CROP_SCALE), + torchvision.transforms.RandomResizedCrop(size=image_size, scale=RANDOM_RESIZED_CROP_SCALE), # To Tensor Format torchvision.transforms.ToTensor(), # For standardization, the mean and standard deviation @@ -152,14 +158,13 @@ def get_dataset(args, distributed=False): ]) # Load the folder data under the current path, # and automatically divide the labels according to the dataset under each file name - dataset = torchvision.datasets.ImageFolder(root=args.dataset_path, transform=transforms) + dataset = torchvision.datasets.ImageFolder(root=dataset_path, transform=transforms) if distributed: sampler = DistributedSampler(dataset) - dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.num_workers, + dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, sampler=sampler) else: - dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, + dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) return dataloader From 4a353336f9ad5d5d554d32908522bc7e348afe81 Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 16 Apr 2024 00:21:13 +0800 Subject: [PATCH 2/5] Add: Add check.py. --- tools/__init__.py | 10 ---------- utils/check.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) delete mode 100644 tools/__init__.py create mode 100644 utils/check.py diff --git a/tools/__init__.py b/tools/__init__.py deleted file mode 100644 index 42a52c5..0000000 --- a/tools/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" - @Date : 2023/6/20 17:43 - @Author : chairc - @Site : https://github.com/chairc -""" -import os -import sys -sys.path.append(os.path.dirname(sys.path[0])) \ No newline at end of file diff --git a/utils/check.py b/utils/check.py new file mode 100644 index 0000000..c402747 --- /dev/null +++ b/utils/check.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2024/4/14 17:31 + @Author : chairc + @Site : https://github.com/chairc +""" +import os +import logging +import coloredlogs + +import torch + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") + + +def check_and_create_dir(path): + """ + Check and create not exist folder + :param path: Create path + :return: None + """ + logger.info(msg=f"Check and create folder '{path}'.") + os.makedirs(name=path, exist_ok=True) + + +def check_path_is_exist(path): + """ + Check the path is existed + :param path: Path + :return: None + """ + if not os.path.exists(path=path): + raise FileNotFoundError(f"The path '{path}' does not exist.") From cec82ec5d5e592fdb3ba5ce8db5e2d1b4f151097 Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 16 Apr 2024 00:22:34 +0800 Subject: [PATCH 3/5] Add: Add classes_initializer function, auto check the number of classes. --- utils/initializer.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/utils/initializer.py b/utils/initializer.py index 2bb727c..7d771ed 100644 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -5,6 +5,7 @@ @Author : chairc @Site : https://github.com/chairc """ +import os import random import numpy as np import torch @@ -21,6 +22,7 @@ from model.samples.ddim import DDIMDiffusion from model.samples.ddpm import DDPMDiffusion from model.samples.plms import PLMSDiffusion +from utils.check import check_path_is_exist from utils.lr_scheduler import set_cosine_lr logger = logging.getLogger(__name__) @@ -258,3 +260,22 @@ def check_param_in_dict(param, dict_params, args_param): act = check_param_in_dict(param="act", dict_params=ckpt_state, args_param=args.act) logger.info(msg=f"[{device}]: Successfully checked parameters.") return conditional, network, image_size, num_classes, act + + +def classes_initializer(dataset_path): + """ + Initialize number of classes + :param dataset_path: Dataset path + :return: num_classes + """ + check_path_is_exist(path=dataset_path) + num_classes = 0 + # Loop dataset path + for classes_dir in os.listdir(path=dataset_path): + # Check current dir + if os.path.isdir(s=os.path.join(dataset_path, classes_dir)): + num_classes += 1 + logger.info(msg=f"Current number of classes is {num_classes}.") + if num_classes == 0: + raise Exception(f"No dataset folders found in '{dataset_path}'.") + return num_classes From be5ffacb41308abd537607c5d755448107f8b058 Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 16 Apr 2024 00:25:07 +0800 Subject: [PATCH 4/5] Update: Remove the num classes parameter; replace manual num classes with the auto acquisition method; modify the get_dataset method. --- tools/train.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tools/train.py b/tools/train.py index c7c6c32..90458cc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -26,7 +26,7 @@ image_format_choices, noise_schedule_choices from model.modules.ema import EMA from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \ - sample_initializer, lr_initializer, amp_initializer + sample_initializer, lr_initializer, amp_initializer, classes_initializer from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging from utils.checkpoint import load_ckpt, save_ckpt @@ -60,8 +60,14 @@ def train(rank=None, args=None): init_lr = args.lr # Learning rate function lr_func = args.lr_func + # Batch size + batch_size = args.batch_size + # Number of workers + num_workers = args.num_workers + # Dataset path + dataset_path = args.dataset_path # Number of classes - num_classes = args.num_classes + num_classes = classes_initializer(dataset_path=dataset_path) # classifier-free guidance interpolation weight, users can better generate model effect cfg_scale = args.cfg_scale # Whether to enable conditional training @@ -118,7 +124,8 @@ def train(rank=None, args=None): results_vis_dir = results_logging[2] results_tb_dir = results_logging[3] # Dataloader - dataloader = get_dataset(args=args, distributed=distributed) + dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size, + num_workers=num_workers, distributed=distributed) # Resume training resume = args.resume # Pretrain @@ -396,9 +403,6 @@ def main(args): parser.add_argument("--world_size", type=int, default=2) # =====================Enable the conditional training (if '--conditional' is set to 'True')===================== - # Number of classes (required) - # [Note] The classes settings are consistent with the loaded datasets settings. - parser.add_argument("--num_classes", type=int, default=10) # classifier-free guidance interpolation weight, users can better generate model effect (recommend) parser.add_argument("--cfg_scale", type=int, default=3) From a40f525227b824d3f47674df3f139e8eac78b773 Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 16 Apr 2024 00:28:16 +0800 Subject: [PATCH 5/5] Update: Modify the get_dataset in test file --- test/test_module.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/test/test_module.py b/test/test_module.py index 29c3add..766fe13 100644 --- a/test/test_module.py +++ b/test/test_module.py @@ -60,28 +60,21 @@ def test_noising(self): Test noising :return: None """ - # Parameter settings - parser = argparse.ArgumentParser() - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_workers", type=int, default=2) - # Input image size - parser.add_argument("--image_size", type=int, default=640) - parser.add_argument("--dataset_path", type=str, default="./noising_test") - - args = parser.parse_args() - logger.info(msg=f"Input params: {args}") - # Start test logger.info(msg="Start noising noising_test.") - dataset_path = args.dataset_path + image_size = 64 + batch_size = 1 + num_workers = 2 + dataset_path = "./noising_test" save_path = os.path.join(dataset_path, "noise") # You need to clear all files under the 'noise' folder first delete_files(path=save_path) - dataloader = get_dataset(args=args) + dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size, + num_workers=num_workers) # Recreate the folder os.makedirs(name=save_path, exist_ok=True) # Diffusion model initialization - diffusion = sample_initializer(sample="ddpm", image_size=args.image_size, device="cpu") + diffusion = sample_initializer(sample="ddpm", image_size=image_size, device="cpu") # Get image and noise tensor image = next(iter(dataloader))[0] time = torch.Tensor([0, 50, 125, 225, 350, 500, 675, 999]).long()