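"""Launcher for a diffusion-model training sweep over trajectory lengths.

For each trajectory length, this script overrides the loaded config, creates a
timestamped experiment directory with snapshots of the source files, and then
hands off to ``train_amend.main``.

Usage:
    python main_amend6.py --seed 42
"""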
import os
import torch
import datetime
import shutil
from pathlib import Path
import argparse
from types import SimpleNamespace
from train_amend import main as train_main
from utils.logger import Logger, log_info
from conf.config import load_config
from utils.utils import set_seed


def setup_environment(config, timestamp, exp_dir):
    """Create the experiment directory layout and snapshot the source files."""
    os.makedirs(exp_dir / 'results', exist_ok=True)
    os.makedirs(exp_dir / 'models', exist_ok=True)
    os.makedirs(exp_dir / 'logs', exist_ok=True)
    os.makedirs(exp_dir / 'Files', exist_ok=True)
    files_save = exp_dir / 'Files' / timestamp
    os.makedirs(files_save, exist_ok=True)
    # Copy the scripts used for this run (paths are relative to the working directory).
    shutil.copy(__file__, files_save)
    shutil.copy('utils/utils.py', files_save)
    shutil.copy('train.py', files_save)
    shutil.copy('main.py', files_save)
    # shutil.copy('data.py', files_save)
    return files_save
if __name__ == "__main__":
    # Pin this launcher to GPU 2.
    torch.cuda.set_device(2)

    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    args_global = parser.parse_args()
    world_size = torch.cuda.device_count()  # currently unused

    # Set the random seed for reproducibility.
    set_seed(args_global.seed)

    base_config = load_config()
    for i in range(6, 7):
        # Per-run overrides of the loaded config (trajectory-length sweep).
        base_config['data']['traj_length'] = i
        base_config['training']['n_epochs'] = 200
        base_config['training']['batch_size'] = 2560
        base_config['diffusion']['num_diffusion_timesteps'] = 500
        base_config['model']['in_channels'] = 3
        base_config['model']['out_ch'] = 3
        # if i < 4:
        #     base_config['training']['batch_size'] = 5120
        # else:
        #     base_config['training']['batch_size'] = 10240

        # Wrap the dict in SimpleNamespaces for attribute-style access; the
        # original dict is kept under its own name so later sweep iterations
        # can keep mutating it.
        temp = {k: SimpleNamespace(**v) for k, v in base_config.items()}
        config = SimpleNamespace(**temp)
        root_dir = Path(__file__).resolve().parent
        # The experiment directory name encodes the key hyperparameters of this run.
        result_name = '{}_steps={}_len={}_{}_bs={}'.format(
            config.data.dataset, config.diffusion.num_diffusion_timesteps,
            config.data.traj_length, config.diffusion.beta_end,
            config.training.batch_size)
        exp_dir = root_dir / "Backups" / result_name

        timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M-%S")
        files_save = setup_environment(config, timestamp, exp_dir)

        # One log file per run, named by the run's timestamp.
        logger = Logger(
            __name__,
            log_path=exp_dir / "logs" / (timestamp + '.log'),
            colorize=True,
        )
        log_info(config, logger)
        train_main(config, logger, exp_dir)