From c3f85056be2e8afc3e5ce6b8bceaf0838cde7d2c Mon Sep 17 00:00:00 2001
From: ChairC <974833488@qq.com>
Date: Sun, 5 May 2024 22:10:29 +0800
Subject: [PATCH] Update: Refactor train.py format.

---
 tools/train.py | 50 ++++++++++++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/tools/train.py b/tools/train.py
index 90458cc..cea8e69 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -41,15 +41,31 @@ def train(rank=None, args=None):
     :param args: Input parameters
     :return: None
     """
+    # =================================Before training=================================
+    # Output params to console
     logger.info(msg=f"[{rank}]: Input params: {args}")
+    # Step1: Set path and create log
+    # Saving path
+    result_path = args.result_path
+    # Run name
+    run_name = args.run_name
+    # Create data logging path
+    results_logging = setup_logging(save_path=result_path, run_name=run_name)
+    results_dir = results_logging[1]
+    results_vis_dir = results_logging[2]
+    results_tb_dir = results_logging[3]
+    # Tensorboard
+    tb_logger = SummaryWriter(log_dir=results_tb_dir)
+    # Train log
+    save_train_logging(arg=args, save_path=results_dir)
+
+    # Step2: Get the parameters of the initializer and args
     # Initialize the seed
     seed_initializer(seed_id=args.seed)
     # Sample type
     sample = args.sample
     # Network
     network = args.network
-    # Run name
-    run_name = args.run_name
     # Input image size
     image_size = args.image_size
     # Select optimizer
@@ -116,20 +132,13 @@ def train(rank=None, args=None):
     image_format = args.image_format
     # Noise schedule
     noise_schedule = args.noise_schedule
-    # Saving path
-    result_path = args.result_path
-    # Create data logging path
-    results_logging = setup_logging(save_path=result_path, run_name=run_name)
-    results_dir = results_logging[1]
-    results_vis_dir = results_logging[2]
-    results_tb_dir = results_logging[3]
-    # Dataloader
-    dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
-                             num_workers=num_workers, distributed=distributed)
     # Resume training
     resume = args.resume
     # Pretrain
     pretrain = args.pretrain
+
+    # =================================About model initializer=================================
+    # Step3: Init model
     # Network
     Network = network_initializer(network=network, device=device)
     # Model
@@ -170,17 +179,21 @@ def train(rank=None, args=None):
     mse = nn.MSELoss()
     # Initialize the diffusion model
     diffusion = sample_initializer(sample=sample, image_size=image_size, device=device, schedule_name=noise_schedule)
-    # Tensorboard
-    tb_logger = SummaryWriter(log_dir=results_tb_dir)
-    # Train log
-    save_train_logging(args, results_dir)
-    # Number of dataset batches in the dataloader
-    len_dataloader = len(dataloader)
     # Exponential Moving Average (EMA) may not be as dominant for single class as for multi class
     ema = EMA(beta=0.995)
     # EMA model
     ema_model = copy.deepcopy(model).eval().requires_grad_(False)

+    # =================================About data=================================
+    # Step4: Set data
+    # Dataloader
+    dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
+                             num_workers=num_workers, distributed=distributed)
+    # Number of dataset batches in the dataloader
+    len_dataloader = len(dataloader)
+
+    # =================================Training=================================
+    # Step5: Training
     logger.info(msg=f"[{device}]: Start training.")
     # Start iterating
     for epoch in range(start_epoch, args.epochs):
@@ -281,6 +294,7 @@ def train(rank=None, args=None):
             dist.barrier()
     logger.info(msg=f"[{device}]: Finish training.")
+    logger.info(msg="[Note]: If you want to evaluate model quality, use 'FID_calculator.py' to evaluate.")
     # Clean up the distributed environment
     if distributed:
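
The reordering above follows the data dependencies in train(): the run directories returned by setup_logging must exist before the SummaryWriter and the train log are created (Step1), the model must be fully initialized before its EMA copy is taken (Step3), and the dataloader is now built only after model setup (Step4). The sketch below mirrors the resulting five-step layout; the step comments and helper names (setup_logging, get_dataset, EMA) come from the diff, but every body here is a placeholder assumption, not the repository's actual implementation.

# Minimal, self-contained sketch of the five-step train() layout from this
# patch. Helper names mirror tools/train.py; all bodies are placeholder
# assumptions, not the project's actual code.
import copy
from types import SimpleNamespace


def setup_logging(save_path, run_name):
    # Stub: the real helper creates the run directories and returns their paths.
    base = f"{save_path}/{run_name}"
    return base, f"{base}/model", f"{base}/vis", f"{base}/tensorboard"


def get_dataset(image_size, batch_size):
    # Stub: the real helper returns a torch DataLoader over the image dataset.
    return [f"batch-{i}" for i in range(4)]


def train(args):
    # Step1: Set path and create log (paths must exist before any writer uses them)
    results_logging = setup_logging(save_path=args.result_path, run_name=args.run_name)
    results_dir, results_vis_dir, results_tb_dir = results_logging[1:4]

    # Step2: Get the parameters of the initializer and args
    image_size = args.image_size
    batch_size = args.batch_size

    # Step3: Init model (a dict stands in for the network; EMA copies the fresh model)
    model = {"weights": [0.0]}
    ema_model = copy.deepcopy(model)

    # Step4: Set data (the dataloader is now built after model init)
    dataloader = get_dataset(image_size=image_size, batch_size=batch_size)
    len_dataloader = len(dataloader)

    # Step5: Training
    for epoch in range(args.epochs):
        for batch in dataloader:
            pass  # forward/backward/EMA update elided
    print(f"Trained {args.epochs} epoch(s) x {len_dataloader} batches; logs in {results_tb_dir}")


if __name__ == "__main__":
    train(SimpleNamespace(result_path="./results", run_name="df",
                          image_size=64, batch_size=2, epochs=1))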