diff --git a/examples/train.py b/examples/train.py
index 8abd68e..df38aec 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -100,7 +100,7 @@ def parse_args():
         default="samples/",
     )
     parser.add_argument(
-        "--save-samples-every-n-epoch",
+        "--save-samples-every-n-epochs",
         type=int,
         default=10,
     )
@@ -115,7 +115,7 @@ def parse_args():
         type=int,
         default=5,
     )
-    parser.add_argument("--sample-ema", action="store_true")
+    parser.add_argument("--use-ema", action="store_true")
     parser.add_argument(
         "--sample-seed",
         type=int,
@@ -127,6 +127,12 @@ def parse_args():
     parser.add_argument("--resume-ckpt-path", type=str)
     parser.add_argument("--resume-wandb-id", type=str)
 
+
+    parser.add_argument("--push-to-hub", action="store_true")
+    parser.add_argument("--model-id", type=str)
+    parser.add_argument("--token", type=str)
+    parser.add_argument("--push-every-n-steps", type=int)
+
     args = parser.parse_args()
 
     return args
@@ -175,12 +181,8 @@ def __getitem__(self, index: int) -> torch.Tensor:
 )
 
 if args.unet_config_path:
-    import json
-
-    with open(args.unet_config_path) as f:
-        config = json.load(f)
-
-    UNet2DModel.from_config(config)
+    unet_config = UNet2DModel.load_config(args.unet_config_path)
+    unet = UNet2DModel.from_config(unet_config)
 else:
     # Simplified NCSN++ Architecture
     # See https://huggingface.co/google/ncsnpp-ffhq-1024/blob/main/config.json
@@ -208,28 +210,42 @@ def __getitem__(self, index: int) -> torch.Tensor:
             "SkipUpBlock2D",
         ),
     )
 
+# Use both VGG and SqueezeNet as the perceptual loss
 loss_fn = PerceptualLoss(net_type=("vgg", "squeeze"))
 
+configs = {
+    "model": unet,
+    "loss_fn": loss_fn,
+    "learning_rate": args.learning_rate,
+    "data_std": args.data_std,
+    "time_min": args.time_min,
+    "time_max": args.time_max,
+    "bins_min": args.bins_min,
+    "bins_max": args.bins_max,
+    "bins_rho": args.bins_rho,
+    "initial_ema_decay": args.initial_ema_decay,
+    "samples_path": args.sample_path,
+    "save_samples_every_n_epochs": args.save_samples_every_n_epochs,
+    "num_samples": args.num_samples,
+    "sample_steps": args.sample_steps,
+    "use_ema": args.use_ema,
+    "sample_seed": args.sample_seed,
+}
+
+if args.push_to_hub:
+    configs.update(
+        {
+            "model_id": args.model_id
+            or f"cm-{args.dataset_name}-{unet.sample_size}",
+            "token": args.token,
+            "push_every_n_steps": args.push_every_n_steps,
+        }
+    )
+
 if args.resume_ckpt_path:
     consistency = Consistency.load_from_checkpoint(
-        checkpoint_path=args.resume_ckpt_path,
-        model=unet,
-        loss_fn=loss_fn,
-        learning_rate=args.learning_rate,
-        data_std=args.data_std,
-        time_min=args.time_min,
-        time_max=args.time_max,
-        bins_min=args.bins_min,
-        bins_max=args.bins_max,
-        bins_rho=args.bins_rho,
-        initial_ema_decay=args.initial_ema_decay,
-        samples_path=args.sample_path,
-        save_samples_every_n_epoch=args.save_samples_every_n_epoch,
-        num_samples=args.num_samples,
-        sample_steps=args.sample_steps,
-        sample_ema=args.sample_ema,
-        sample_seed=args.sample_seed,
+        checkpoint_path=args.resume_ckpt_path, **configs
     )
 
     trainer = Trainer(
@@ -261,24 +277,7 @@ def __getitem__(self, index: int) -> torch.Tensor:
     trainer.fit(consistency, dataloader, ckpt_path=args.ckpt_path)
 
 else:
-    consistency = Consistency(
-        model=unet,
-        loss_fn=loss_fn,
-        learning_rate=args.learning_rate,
-        data_std=args.data_std,
-        time_min=args.time_min,
-        time_max=args.time_max,
-        bins_min=args.bins_min,
-        bins_max=args.bins_max,
-        bins_rho=args.bins_rho,
-        initial_ema_decay=args.initial_ema_decay,
-        samples_path=args.sample_path,
-        save_samples_every_n_epoch=args.save_samples_every_n_epoch,
-        num_samples=args.num_samples,
-        sample_steps=args.sample_steps,
-        sample_ema=args.sample_ema,
-        sample_seed=args.sample_seed,
-    )
+    consistency = Consistency(**configs)
 
     trainer = Trainer(
         accelerator="auto",