Skip to content

Commit

Permalink
Update: Remove the num classes parameter; replace manual num classes …
Browse files Browse the repository at this point in the history
…with the auto acquisition method; modify the get_dataset method.
  • Loading branch information
chairc committed Apr 15, 2024
1 parent cec82ec commit be5ffac
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
image_format_choices, noise_schedule_choices
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \
sample_initializer, lr_initializer, amp_initializer
sample_initializer, lr_initializer, amp_initializer, classes_initializer
from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging
from utils.checkpoint import load_ckpt, save_ckpt

Expand Down Expand Up @@ -60,8 +60,14 @@ def train(rank=None, args=None):
init_lr = args.lr
# Learning rate function
lr_func = args.lr_func
# Batch size
batch_size = args.batch_size
# Number of workers
num_workers = args.num_workers
# Dataset path
dataset_path = args.dataset_path
# Number of classes
num_classes = args.num_classes
num_classes = classes_initializer(dataset_path=dataset_path)
# classifier-free guidance interpolation weight, users can better generate model effect
cfg_scale = args.cfg_scale
# Whether to enable conditional training
Expand Down Expand Up @@ -118,7 +124,8 @@ def train(rank=None, args=None):
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Dataloader
dataloader = get_dataset(args=args, distributed=distributed)
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
# Resume training
resume = args.resume
# Pretrain
Expand Down Expand Up @@ -396,9 +403,6 @@ def main(args):
parser.add_argument("--world_size", type=int, default=2)

# =====================Enable the conditional training (if '--conditional' is set to 'True')=====================
# Number of classes (required)
# [Note] The classes settings are consistent with the loaded datasets settings.
parser.add_argument("--num_classes", type=int, default=10)
# classifier-free guidance interpolation weight, users can better generate model effect (recommend)
parser.add_argument("--cfg_scale", type=int, default=3)

Expand Down

0 comments on commit be5ffac

Please sign in to comment.