Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify the get_dataset; Add check.py; Add classes_initializer function #62

Merged
merged 5 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions test/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,21 @@ def test_noising(self):
Test noising
:return: None
"""
# Parameter settings
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=2)
# Input image size
parser.add_argument("--image_size", type=int, default=640)
parser.add_argument("--dataset_path", type=str, default="./noising_test")

args = parser.parse_args()
logger.info(msg=f"Input params: {args}")

# Start test
logger.info(msg="Start noising noising_test.")
dataset_path = args.dataset_path
image_size = 64
batch_size = 1
num_workers = 2
dataset_path = "./noising_test"
save_path = os.path.join(dataset_path, "noise")
# You need to clear all files under the 'noise' folder first
delete_files(path=save_path)
dataloader = get_dataset(args=args)
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers)
# Recreate the folder
os.makedirs(name=save_path, exist_ok=True)
# Diffusion model initialization
diffusion = sample_initializer(sample="ddpm", image_size=args.image_size, device="cpu")
diffusion = sample_initializer(sample="ddpm", image_size=image_size, device="cpu")
# Get image and noise tensor
image = next(iter(dataloader))[0]
time = torch.Tensor([0, 50, 125, 225, 350, 500, 675, 999]).long()
Expand Down
10 changes: 0 additions & 10 deletions tools/__init__.py

This file was deleted.

16 changes: 10 additions & 6 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
image_format_choices, noise_schedule_choices
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
sample_initializer, lr_initializer, amp_initializer
sample_initializer, lr_initializer, amp_initializer, classes_initializer
from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging
from utils.checkpoint import load_ckpt, save_ckpt

Expand Down Expand Up @@ -60,8 +60,14 @@ def train(rank=None, args=None):
init_lr = args.lr
# Learning rate function
lr_func = args.lr_func
# Batch size
batch_size = args.batch_size
# Number of workers
num_workers = args.num_workers
# Dataset path
dataset_path = args.dataset_path
# Number of classes
num_classes = args.num_classes
num_classes = classes_initializer(dataset_path=dataset_path)
# classifier-free guidance interpolation weight, users can better generate model effect
cfg_scale = args.cfg_scale
# Whether to enable conditional training
Expand Down Expand Up @@ -118,7 +124,8 @@ def train(rank=None, args=None):
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Dataloader
dataloader = get_dataset(args=args, distributed=distributed)
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
# Resume training
resume = args.resume
# Pretrain
Expand Down Expand Up @@ -396,9 +403,6 @@ def main(args):
parser.add_argument("--world_size", type=int, default=2)

# =====================Enable the conditional training (if '--conditional' is set to 'True')=====================
# Number of classes (required)
# [Note] The classes settings are consistent with the loaded datasets settings.
parser.add_argument("--num_classes", type=int, default=10)
# classifier-free guidance interpolation weight, users can better generate model effect (recommend)
parser.add_argument("--cfg_scale", type=int, default=3)

Expand Down
35 changes: 35 additions & 0 deletions utils/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/4/14 17:31
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import logging
import coloredlogs

import torch

# Module-level logger named after this module
logger = logging.getLogger(__name__)
# Configure coloredlogs with the INFO level when this module is imported
coloredlogs.install(level="INFO")


def check_and_create_dir(path):
    """
    Make sure the folder exists, creating it when it is missing
    :param path: Folder path to check and create
    :return: None
    """
    message = f"Check and create folder '{path}'."
    logger.info(msg=message)
    # exist_ok=True turns this into a no-op when the folder is already there
    os.makedirs(path, exist_ok=True)


def check_path_is_exist(path):
    """
    Verify that the given path exists
    :param path: Path to verify
    :return: None
    :raises FileNotFoundError: If the path does not exist
    """
    # Guard clause: nothing to do when the path is present
    if os.path.exists(path):
        return
    raise FileNotFoundError(f"The path '{path}' does not exist.")
21 changes: 21 additions & 0 deletions utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import random
import numpy as np
import torch
Expand All @@ -21,6 +22,7 @@
from model.samples.ddim import DDIMDiffusion
from model.samples.ddpm import DDPMDiffusion
from model.samples.plms import PLMSDiffusion
from utils.check import check_path_is_exist
from utils.lr_scheduler import set_cosine_lr

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -258,3 +260,22 @@ def check_param_in_dict(param, dict_params, args_param):
act = check_param_in_dict(param="act", dict_params=ckpt_state, args_param=args.act)
logger.info(msg=f"[{device}]: Successfully checked parameters.")
return conditional, network, image_size, num_classes, act


def classes_initializer(dataset_path):
    """
    Initialize number of classes by counting the sub-folders of the dataset path
    :param dataset_path: Dataset path, each direct sub-folder is counted as one class
    :return: num_classes
    :raises FileNotFoundError: If the dataset path does not exist or contains no sub-folder
    """
    check_path_is_exist(path=dataset_path)
    # Every direct sub-folder counts as one class; regular files are ignored
    num_classes = sum(
        1 for classes_dir in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, classes_dir))
    )
    logger.info(msg=f"Current number of classes is {num_classes}.")
    if num_classes == 0:
        # Same exception type as a missing dataset path, so callers can handle
        # both "no dataset" situations together instead of catching Exception
        raise FileNotFoundError(f"No dataset folders found in '{dataset_path}'.")
    return num_classes
23 changes: 14 additions & 9 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from config.choices import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from sr.dataset import SRDataset
from utils.check import check_path_is_exist

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
Expand Down Expand Up @@ -93,7 +94,7 @@ def save_one_image_in_images(images, path, generate_name, image_size=None, image
count += 1


def get_dataset(args, distributed=False):
def get_dataset(image_size=64, dataset_path=None, batch_size=2, num_workers=0, distributed=False):
"""
Get dataset

Expand Down Expand Up @@ -134,16 +135,21 @@ def get_dataset(args, distributed=False):
| | | |
+------------------------+ +-----------+

:param args: Parameters
:param image_size: Image size
:param dataset_path: Dataset path
:param batch_size: Batch size
:param num_workers: Number of workers
:param distributed: Whether to distribute training
:return: dataloader
"""
check_path_is_exist(path=dataset_path)
# Data augmentation
transforms = torchvision.transforms.Compose([
# Resize input size
# Resize input size, input type is (height, width)
# torchvision.transforms.Resize(80), args.image_size + 1/4 * args.image_size
torchvision.transforms.Resize(size=int(args.image_size + args.image_size / 4)),
torchvision.transforms.Resize(size=int(image_size + image_size / 4)),
# Random adjustment cropping
torchvision.transforms.RandomResizedCrop(size=args.image_size, scale=RANDOM_RESIZED_CROP_SCALE),
torchvision.transforms.RandomResizedCrop(size=image_size, scale=RANDOM_RESIZED_CROP_SCALE),
# To Tensor Format
torchvision.transforms.ToTensor(),
# For standardization, the mean and standard deviation
Expand All @@ -152,14 +158,13 @@ def get_dataset(args, distributed=False):
])
# Load the folder data under the current path,
# and automatically divide the labels according to the dataset under each file name
dataset = torchvision.datasets.ImageFolder(root=args.dataset_path, transform=transforms)
dataset = torchvision.datasets.ImageFolder(root=dataset_path, transform=transforms)
if distributed:
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers,
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
pin_memory=True, sampler=sampler)
else:
dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
pin_memory=True)
return dataloader

Expand Down