-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
101 lines (91 loc) · 4.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import sys
import os, re
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(project_root)
import logging
import torch
import sys
import wandb
from torchvision.utils import make_grid
import torch.optim as optim
import gc,logging,os
from matplotlib import pyplot as plt
from utils.myparser import getYamlConfig
from utils.dataset import getDataset
from models.diffusion.forward import ForwardSampler
from models.unet import MacroprosDenoiser
from models.diffusion.ddpm import DDPM
from models.training import train_one_epoch
from torchsummary import summary
from functools import partial
def train(cfg, filenames, show_losses_plot=False):
    """Train the DDPM denoiser on crowd macroproperties data.

    Args:
        cfg: parsed YAML configuration (DATASET / MODEL / TRAIN / DIFFUSION sections).
        filenames: list of pickle file paths used to build the training dataset.
        show_losses_plot: currently unused; kept for interface compatibility.

    Side effects:
        Logs the run to Weights & Biases and writes the best checkpoint
        (lowest epoch loss seen so far) under ``cfg.MODEL.SAVE_DIR``.
    """
    wandb.init(
        project="macroprops-predict",
        config={
            "architecture": "DDPM",
            "dataset": cfg.DATASET.NAME,
            "learning_rate": cfg.TRAIN.SOLVER.LR,
            "epochs": cfg.TRAIN.EPOCHS,
            "batch_size": cfg.DATASET.BATCH_SIZE,
            "observation_len": cfg.DATASET.OBS_LEN,
            "prediction_len": cfg.DATASET.PRED_LEN,
            "weight_decay": cfg.TRAIN.SOLVER.WEIGHT_DECAY,
            "solver_betas": cfg.TRAIN.SOLVER.BETAS,
        }
    )
    torch.manual_seed(42)  # fixed seed for reproducibility
    # Setting the device to work with
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the training split is needed here
    batched_train_data, _, _ = getDataset(cfg, filenames, train_data_only=True)
    # Instantiate the UNet used for the reverse diffusion (noise prediction)
    denoiser = MacroprosDenoiser(num_res_blocks = cfg.MODEL.NUM_RES_BLOCKS,
                                 base_channels           = cfg.MODEL.BASE_CH,
                                 base_channels_multiples = cfg.MODEL.BASE_CH_MULT,
                                 apply_attention         = cfg.MODEL.APPLY_ATTENTION,
                                 dropout_rate            = cfg.MODEL.DROPOUT_RATE,
                                 time_multiple           = cfg.MODEL.TIME_EMB_MULT)
    denoiser.to(device)
    # The optimizer (Adam with weight decay)
    optimizer = optim.Adam(denoiser.parameters(), lr=cfg.TRAIN.SOLVER.LR,
                           betas=cfg.TRAIN.SOLVER.BETAS,
                           weight_decay=cfg.TRAIN.SOLVER.WEIGHT_DECAY)
    # Instantiate the diffusion process (noise schedule etc.)
    diffusionmodel = DDPM(timesteps=cfg.DIFFUSION.TIMESTEPS, scale=cfg.DIFFUSION.SCALE)
    diffusionmodel.to(device)
    # Training loop: keep only the checkpoint with the best (lowest) epoch loss
    best_loss = float("inf")
    for epoch in range(1, cfg.TRAIN.EPOCHS + 1):
        # Release cached GPU memory and collect garbage between epochs
        torch.cuda.empty_cache()
        gc.collect()
        # One epoch of training
        epoch_loss = train_one_epoch(denoiser, diffusionmodel, batched_train_data,
                                     optimizer, device, epoch=epoch,
                                     total_epochs=cfg.TRAIN.EPOCHS)
        wandb.log({"loss_2D": epoch_loss})
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            # NOTE(review): only denoiser/optimizer state is saved. If DDPM ever
            # gains learned parameters, its state_dict should be saved here too.
            checkpoint_dict = {
                "opt": optimizer.state_dict(),
                "model": denoiser.state_dict()
            }
            # exist_ok avoids the check-then-create race of os.path.exists
            os.makedirs(cfg.MODEL.SAVE_DIR, exist_ok=True)
            lr_str = "{:.0e}".format(cfg.TRAIN.SOLVER.LR)
            scale_str = "{:.0e}".format(cfg.DIFFUSION.SCALE)
            # os.path.join: correct even when SAVE_DIR has no trailing separator
            save_path = os.path.join(
                cfg.MODEL.SAVE_DIR,
                cfg.MODEL.MODEL_NAME.format(cfg.TRAIN.EPOCHS, lr_str, scale_str))
            torch.save(checkpoint_dict, save_path)
            del checkpoint_dict
def resolve_pickle_paths(pickle_dir, filenames):
    """Map raw dataset file names to full pickle paths.

    Each ``.csv`` occurrence in a name is rewritten to ``.pkl``; names that do
    not end with ``.pkl`` after the rewrite are silently dropped, and the rest
    are joined onto ``pickle_dir``.
    """
    pkl_names = (name.replace(".csv", ".pkl") for name in filenames)
    return [os.path.join(pickle_dir, name) for name in pkl_names if name.endswith(".pkl")]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="A script to train a diffusion model for crowd macroproperties.")
    parser.add_argument('--config-yml-file', type=str, default='config/ATC_ddpm_4test.yml', help='Configuration YML file for specific dataset.')
    parser.add_argument('--configList-yml-file', type=str, default='config/ATC_ddpm_DSlist4test.yml',help='Configuration YML macroprops list for specific dataset.')
    args = parser.parse_args()
    cfg = getYamlConfig(args.config_yml_file, args.configList_yml_file)
    filenames = resolve_pickle_paths(cfg.PICKLE.PICKLE_DIR, cfg.SUNDAY_DATA_LIST)
    train(cfg, filenames)