diff --git a/tools/train.py b/tools/train.py index c7c6c32..90458cc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -26,7 +26,7 @@ image_format_choices, noise_schedule_choices from model.modules.ema import EMA from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \ - sample_initializer, lr_initializer, amp_initializer + sample_initializer, lr_initializer, amp_initializer, classes_initializer from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging from utils.checkpoint import load_ckpt, save_ckpt @@ -60,8 +60,14 @@ def train(rank=None, args=None): init_lr = args.lr # Learning rate function lr_func = args.lr_func + # Batch size + batch_size = args.batch_size + # Number of workers + num_workers = args.num_workers + # Dataset path + dataset_path = args.dataset_path # Number of classes - num_classes = args.num_classes + num_classes = classes_initializer(dataset_path=dataset_path) # classifier-free guidance interpolation weight, users can better generate model effect cfg_scale = args.cfg_scale # Whether to enable conditional training @@ -118,7 +124,8 @@ def train(rank=None, args=None): results_vis_dir = results_logging[2] results_tb_dir = results_logging[3] # Dataloader - dataloader = get_dataset(args=args, distributed=distributed) + dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size, + num_workers=num_workers, distributed=distributed) # Resume training resume = args.resume # Pretrain @@ -396,9 +403,6 @@ def main(args): parser.add_argument("--world_size", type=int, default=2) # =====================Enable the conditional training (if '--conditional' is set to 'True')===================== - # Number of classes (required) - # [Note] The classes settings are consistent with the loaded datasets settings. - parser.add_argument("--num_classes", type=int, default=10) # classifier-free guidance interpolation weight, users can better generate model effect (recommend) parser.add_argument("--cfg_scale", type=int, default=3)