-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
73 lines (55 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from torch.utils.data import DataLoader
import wandb
from util import load_data, run
from config import RunConfig, OptimizeConfig
import argparse
def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments for a single run or an optimization study."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_config", type=str, required=True, help="Path to the YAML config file"
    )
    parser.add_argument(
        "--optimize_config", type=str, help="Path to the optimization YAML config file"
    )
    return parser.parse_args()


def _optimize(base_config, optimize_config_path, dl_train, dl_val):
    """Run a hyperparameter study and print a summary of its best trial.

    Args:
        base_config: The ``RunConfig`` loaded from ``--run_config``. Every trial
            starts from ``base_config.gen_config()`` and overlays the sampled
            parameters on top of it.
        optimize_config_path: Path to the optimization YAML file, loaded via
            ``OptimizeConfig.from_yaml``.
        dl_train: Training ``DataLoader``, shared by every trial.
        dl_val: Validation ``DataLoader``, shared by every trial.
    """
    optimize_config = OptimizeConfig.from_yaml(optimize_config_path)
    # Trials, the study itself, and the reported output path all use the
    # same "<project>_Opt" project name; build it once.
    opt_project = f"{base_config.project}_Opt"

    def objective(trial):
        """Build a RunConfig from sampled params, train it, and return its score."""
        # ``params`` maps a config section name to a dict of overrides for it.
        params = optimize_config.suggest_params(trial)
        # NOTE(review): assumes gen_config() returns a fresh dict on every call;
        # the nested section dicts are mutated in place below — confirm.
        config = base_config.gen_config()
        config["project"] = opt_project
        for category, category_params in params.items():
            config[category].update(category_params)
        run_config = RunConfig(**config)
        # Suffix the trial number so every trial gets a distinct group name.
        group_name = f"{run_config.gen_group_name()}[{trial.number}]"
        # Stored on the trial so the best trial's output path can be printed later.
        trial.set_user_attr("group_name", group_name)
        return run(run_config, dl_train, dl_val, group_name)

    study = optimize_config.create_study(project=opt_project)
    study.optimize(objective, n_trials=optimize_config.trials)

    print("Best trial:")
    best = study.best_trial
    print(f" Value: {best.value}")
    print(" Params: ")
    for key, value in best.params.items():
        print(f" {key}: {value}")
    print(
        f" Path: runs/{base_config.project}_Opt/{best.user_attrs['group_name']}"
    )


def main():
    """Train once with ``--run_config``, or run a hyperparameter study when
    ``--optimize_config`` is also supplied."""
    args = _parse_args()
    wandb.require("core")  # pyright: ignore

    base_config = RunConfig.from_yaml(args.run_config)

    # Data is loaded once and the same loaders are reused by every trial.
    # NOTE(review): loaders always use base_config.batch_size; if
    # suggest_params ever tunes batch_size, trials would silently ignore it.
    ds_train, ds_val = load_data()  # pyright: ignore
    dl_train = DataLoader(ds_train, batch_size=base_config.batch_size, shuffle=True)
    dl_val = DataLoader(ds_val, batch_size=base_config.batch_size)

    if args.optimize_config:
        _optimize(base_config, args.optimize_config, dl_train, dl_val)
    else:
        run(base_config, dl_train, dl_val)


if __name__ == "__main__":
    main()