Skip to content

Commit

Permalink
docs: update training script
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 28, 2023
1 parent 7346087 commit f3461d2
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def parse_args():
default="samples/",
)
parser.add_argument(
"--save-samples-every-n-epoch",
"--save-samples-every-n-epochs",
type=int,
default=10,
)
Expand All @@ -115,7 +115,7 @@ def parse_args():
type=int,
default=5,
)
parser.add_argument("--sample-ema", action="store_true")
parser.add_argument("--use-ema", action="store_true")
parser.add_argument(
"--sample-seed",
type=int,
Expand All @@ -127,6 +127,12 @@ def parse_args():

parser.add_argument("--resume-ckpt-path", type=str)
parser.add_argument("--resume-wandb-id", type=str)

parser.add_argument("--push-to-hub", type="store_true")
parser.add_argument("--model-id", type=int)
parser.add_argument("--token", type=str)
parser.add_argument("--push-every-n-steps", type=int)

args = parser.parse_args()
return args

Expand Down Expand Up @@ -175,12 +181,8 @@ def __getitem__(self, index: int) -> torch.Tensor:
)

if args.unet_config_path:
import json

with open(args.unet_config_path) as f:
config = json.load(f)

UNet2DModel.from_config(config)
unet_config = UNet2DModel.load_config(args.unet_config_path)
unet = UNet2DModel.from_config(unet_config)
else:
# Simplified NCSN++ Architecture
# See https://huggingface.co/google/ncsnpp-ffhq-1024/blob/main/config.json
Expand Down Expand Up @@ -208,28 +210,42 @@ def __getitem__(self, index: int) -> torch.Tensor:
"SkipUpBlock2D",
),
)

# Use both VGG and SqueezeNet as loss
loss_fn = PerceptualLoss(net_type=("vgg", "squeeze"))

configs = {
"model": unet,
"loss_fn": loss_fn,
"learning_rate": args.learning_rate,
"data_std": args.data_std,
"time_min": args.time_min,
"time_max": args.time_max,
"bins_min": args.bins_min,
"bins_max": args.bins_max,
"bins_rho": args.bins_rho,
"initial_ema_decay": args.initial_ema_decay,
"samples_path": args.sample_path,
"save_samples_every_n_epochs": args.save_samples_every_n_epochs,
"num_samples": args.num_samples,
"sample_steps": args.sample_steps,
"use_ema": args.use_ema,
"sample_seed": args.sample_seed,
}

if args.push_to_hub:
configs.update(
{
"model_id": args.model_id
or f"cm-{args.dataset_name}-{unet.sample_size}",
"token": args.token,
"push_every_n_steps": args.push_every_n_steps,
}
)

if args.resume_ckpt_path:
consistency = Consistency.load_from_checkpoint(
checkpoint_path=args.resume_ckpt_path,
model=unet,
loss_fn=loss_fn,
learning_rate=args.learning_rate,
data_std=args.data_std,
time_min=args.time_min,
time_max=args.time_max,
bins_min=args.bins_min,
bins_max=args.bins_max,
bins_rho=args.bins_rho,
initial_ema_decay=args.initial_ema_decay,
samples_path=args.sample_path,
save_samples_every_n_epoch=args.save_samples_every_n_epoch,
num_samples=args.num_samples,
sample_steps=args.sample_steps,
sample_ema=args.sample_ema,
sample_seed=args.sample_seed,
checkpoint_path=args.resume_ckpt_path, **configs
)

trainer = Trainer(
Expand Down Expand Up @@ -261,24 +277,7 @@ def __getitem__(self, index: int) -> torch.Tensor:
trainer.fit(consistency, dataloader, ckpt_path=args.ckpt_path)

else:
consistency = Consistency(
model=unet,
loss_fn=loss_fn,
learning_rate=args.learning_rate,
data_std=args.data_std,
time_min=args.time_min,
time_max=args.time_max,
bins_min=args.bins_min,
bins_max=args.bins_max,
bins_rho=args.bins_rho,
initial_ema_decay=args.initial_ema_decay,
samples_path=args.sample_path,
save_samples_every_n_epoch=args.save_samples_every_n_epoch,
num_samples=args.num_samples,
sample_steps=args.sample_steps,
sample_ema=args.sample_ema,
sample_seed=args.sample_seed,
)
consistency = Consistency(**configs)

trainer = Trainer(
accelerator="auto",
Expand Down

0 comments on commit f3461d2

Please sign in to comment.