Skip to content

Commit

Permalink
Merge pull request #62 from chairc/dev
Browse files Browse the repository at this point in the history
Modify the get_dataset; Add check.py; Add classes_initializer function
  • Loading branch information
chairc authored Apr 15, 2024
2 parents 9b55551 + a40f525 commit 87a2227
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 39 deletions.
21 changes: 7 additions & 14 deletions test/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,21 @@ def test_noising(self):
Test noising
:return: None
"""
# Parameter settings
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=2)
# Input image size
parser.add_argument("--image_size", type=int, default=640)
parser.add_argument("--dataset_path", type=str, default="./noising_test")

args = parser.parse_args()
logger.info(msg=f"Input params: {args}")

# Start test
logger.info(msg="Start noising noising_test.")
dataset_path = args.dataset_path
image_size = 64
batch_size = 1
num_workers = 2
dataset_path = "./noising_test"
save_path = os.path.join(dataset_path, "noise")
# You need to clear all files under the 'noise' folder first
delete_files(path=save_path)
dataloader = get_dataset(args=args)
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers)
# Recreate the folder
os.makedirs(name=save_path, exist_ok=True)
# Diffusion model initialization
diffusion = sample_initializer(sample="ddpm", image_size=args.image_size, device="cpu")
diffusion = sample_initializer(sample="ddpm", image_size=image_size, device="cpu")
# Get image and noise tensor
image = next(iter(dataloader))[0]
time = torch.Tensor([0, 50, 125, 225, 350, 500, 675, 999]).long()
Expand Down
10 changes: 0 additions & 10 deletions tools/__init__.py

This file was deleted.

16 changes: 10 additions & 6 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
image_format_choices, noise_schedule_choices
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
sample_initializer, lr_initializer, amp_initializer
sample_initializer, lr_initializer, amp_initializer, classes_initializer
from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging
from utils.checkpoint import load_ckpt, save_ckpt

Expand Down Expand Up @@ -60,8 +60,14 @@ def train(rank=None, args=None):
init_lr = args.lr
# Learning rate function
lr_func = args.lr_func
# Batch size
batch_size = args.batch_size
# Number of workers
num_workers = args.num_workers
# Dataset path
dataset_path = args.dataset_path
# Number of classes
num_classes = args.num_classes
num_classes = classes_initializer(dataset_path=dataset_path)
# classifier-free guidance interpolation weight, users can better generate model effect
cfg_scale = args.cfg_scale
# Whether to enable conditional training
Expand Down Expand Up @@ -118,7 +124,8 @@ def train(rank=None, args=None):
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Dataloader
dataloader = get_dataset(args=args, distributed=distributed)
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
# Resume training
resume = args.resume
# Pretrain
Expand Down Expand Up @@ -396,9 +403,6 @@ def main(args):
parser.add_argument("--world_size", type=int, default=2)

# =====================Enable the conditional training (if '--conditional' is set to 'True')=====================
# Number of classes (required)
# [Note] The classes settings are consistent with the loaded datasets settings.
parser.add_argument("--num_classes", type=int, default=10)
# classifier-free guidance interpolation weight, users can better generate model effect (recommend)
parser.add_argument("--cfg_scale", type=int, default=3)

Expand Down
35 changes: 35 additions & 0 deletions utils/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/4/14 17:31
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import logging
import coloredlogs

import torch

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def check_and_create_dir(path):
"""
Check and create not exist folder
:param path: Create path
:return: None
"""
logger.info(msg=f"Check and create folder '{path}'.")
os.makedirs(name=path, exist_ok=True)


def check_path_is_exist(path):
"""
Check the path is existed
:param path: Path
:return: None
"""
if not os.path.exists(path=path):
raise FileNotFoundError(f"The path '{path}' does not exist.")
21 changes: 21 additions & 0 deletions utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import random
import numpy as np
import torch
Expand All @@ -21,6 +22,7 @@
from model.samples.ddim import DDIMDiffusion
from model.samples.ddpm import DDPMDiffusion
from model.samples.plms import PLMSDiffusion
from utils.check import check_path_is_exist
from utils.lr_scheduler import set_cosine_lr

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -258,3 +260,22 @@ def check_param_in_dict(param, dict_params, args_param):
act = check_param_in_dict(param="act", dict_params=ckpt_state, args_param=args.act)
logger.info(msg=f"[{device}]: Successfully checked parameters.")
return conditional, network, image_size, num_classes, act


def classes_initializer(dataset_path):
"""
Initialize number of classes
:param dataset_path: Dataset path
:return: num_classes
"""
check_path_is_exist(path=dataset_path)
num_classes = 0
# Loop dataset path
for classes_dir in os.listdir(path=dataset_path):
# Check current dir
if os.path.isdir(s=os.path.join(dataset_path, classes_dir)):
num_classes += 1
logger.info(msg=f"Current number of classes is {num_classes}.")
if num_classes == 0:
raise Exception(f"No dataset folders found in '{dataset_path}'.")
return num_classes
23 changes: 14 additions & 9 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from config.choices import RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from sr.dataset import SRDataset
from utils.check import check_path_is_exist

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
Expand Down Expand Up @@ -93,7 +94,7 @@ def save_one_image_in_images(images, path, generate_name, image_size=None, image
count += 1


def get_dataset(args, distributed=False):
def get_dataset(image_size=64, dataset_path=None, batch_size=2, num_workers=0, distributed=False):
"""
Get dataset
Expand Down Expand Up @@ -134,16 +135,21 @@ def get_dataset(args, distributed=False):
| | | |
+------------------------+ +-----------+
:param args: Parameters
:param image_size: Image size
:param dataset_path: Dataset path
:param batch_size: Batch size
:param num_workers: Number of workers
:param distributed: Whether to distribute training
:return: dataloader
"""
check_path_is_exist(path=dataset_path)
# Data augmentation
transforms = torchvision.transforms.Compose([
# Resize input size
# Resize input size, input type is (height, width)
# torchvision.transforms.Resize(80), args.image_size + 1/4 * args.image_size
torchvision.transforms.Resize(size=int(args.image_size + args.image_size / 4)),
torchvision.transforms.Resize(size=int(image_size + image_size / 4)),
# Random adjustment cropping
torchvision.transforms.RandomResizedCrop(size=args.image_size, scale=RANDOM_RESIZED_CROP_SCALE),
torchvision.transforms.RandomResizedCrop(size=image_size, scale=RANDOM_RESIZED_CROP_SCALE),
# To Tensor Format
torchvision.transforms.ToTensor(),
# For standardization, the mean and standard deviation
Expand All @@ -152,14 +158,13 @@ def get_dataset(args, distributed=False):
])
# Load the folder data under the current path,
# and automatically divide the labels according to the dataset under each file name
dataset = torchvision.datasets.ImageFolder(root=args.dataset_path, transform=transforms)
dataset = torchvision.datasets.ImageFolder(root=dataset_path, transform=transforms)
if distributed:
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers,
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
pin_memory=True, sampler=sampler)
else:
dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
pin_memory=True)
return dataloader

Expand Down

0 comments on commit 87a2227

Please sign in to comment.