Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update: Refactor train.py format. #66

Merged
merged 1 commit into from
May 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,31 @@ def train(rank=None, args=None):
:param args: Input parameters
:return: None
"""
# =================================Before training=================================
# Output params to console
logger.info(msg=f"[{rank}]: Input params: {args}")
# Step1: Set path and create log
# Saving path
result_path = args.result_path
# Run name
run_name = args.run_name
# Create data logging path
results_logging = setup_logging(save_path=result_path, run_name=run_name)
results_dir = results_logging[1]
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Tensorboard
tb_logger = SummaryWriter(log_dir=results_tb_dir)
# Train log
save_train_logging(arg=args, save_path=results_dir)

# Step2: Get the parameters of the initializer and args
# Initialize the seed
seed_initializer(seed_id=args.seed)
# Sample type
sample = args.sample
# Network
network = args.network
# Run name
run_name = args.run_name
# Input image size
image_size = args.image_size
# Select optimizer
Expand Down Expand Up @@ -116,20 +132,13 @@ def train(rank=None, args=None):
image_format = args.image_format
# Noise schedule
noise_schedule = args.noise_schedule
# Saving path
result_path = args.result_path
# Create data logging path
results_logging = setup_logging(save_path=result_path, run_name=run_name)
results_dir = results_logging[1]
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Dataloader
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
# Resume training
resume = args.resume
# Pretrain
pretrain = args.pretrain

# =================================About model initializer=================================
# Step3: Init model
# Network
Network = network_initializer(network=network, device=device)
# Model
Expand Down Expand Up @@ -170,17 +179,21 @@ def train(rank=None, args=None):
mse = nn.MSELoss()
# Initialize the diffusion model
diffusion = sample_initializer(sample=sample, image_size=image_size, device=device, schedule_name=noise_schedule)
# Tensorboard
tb_logger = SummaryWriter(log_dir=results_tb_dir)
# Train log
save_train_logging(args, results_dir)
# Number of dataset batches in the dataloader
len_dataloader = len(dataloader)
# Exponential Moving Average (EMA) may not be as dominant for single-class training as for multi-class training
ema = EMA(beta=0.995)
# EMA model
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

# =================================About data=================================
# Step4: Set data
# Dataloader
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
num_workers=num_workers, distributed=distributed)
# Number of dataset batches in the dataloader
len_dataloader = len(dataloader)

# =================================Training=================================
# Step5: Training
logger.info(msg=f"[{device}]: Start training.")
# Start iterating
for epoch in range(start_epoch, args.epochs):
Expand Down Expand Up @@ -281,6 +294,7 @@ def train(rank=None, args=None):
dist.barrier()

logger.info(msg=f"[{device}]: Finish training.")
logger.info(msg="[Note]: If you want to evaluate model quality, use 'FID_calculator.py' to evaluate.")

# Clean up the distributed environment
if distributed:
Expand Down