-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
34 lines (30 loc) · 1.04 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from pathlib import Path
def get_config():
    """Return the training/run configuration as a fresh dict.

    A new dict (and new inner lists) is built on every call, so callers may
    mutate the result without affecting later calls.

    Returns:
        dict: configuration values keyed by name. Notable keys:
            ``model_folder`` / ``model_basename`` are combined by the
            checkpoint-path helpers as ``<model_folder>/<model_basename><epoch>.pt``.
    """
    return {
        "batch_size": 64,
        # Stimulus indices 0..15 inclusive.
        "stimuli_list": list(range(16)),
        # NOTE(review): num_neurons is stored separately from neuron_ranges;
        # 140 matches the single (0, 140) range if ranges are half-open —
        # keep the two in sync when editing either.
        "neuron_ranges": [(0, 140)],
        "num_neurons": 140,
        "lr": 1e-2,
        "num_epochs": 50,
        "seed": 1233,
        "pre_train": False,
        # Hard-coded to the second CUDA device; change for other machines.
        "device": "cuda:1",
        "model_folder": "weights",
        "model_basename": 'train_16_',
        "loss_folder": 'loss',
    }
def get_weights_file_path(config, epoch: str):
    """Build the checkpoint file path for one epoch.

    The checkpoint is named ``<model_basename><epoch>.pt`` and lives inside
    ``config["model_folder"]``.

    Args:
        config: mapping providing ``model_basename`` and ``model_folder``.
        epoch: epoch tag inserted verbatim into the filename (any value is
            formatted with ``str``-style f-string formatting).

    Returns:
        str: the checkpoint path (relative unless ``model_folder`` is absolute).
    """
    checkpoint_name = f"{config['model_basename']}{epoch}.pt"
    # Path('.') / folder is equivalent to Path(folder), so the '.' is omitted.
    return str(Path(config["model_folder"]).joinpath(checkpoint_name))
def latest_weights_file_path(config, from_epoch: str):
    """Resolve the checkpoint path to resume from.

    Args:
        config: mapping providing ``model_folder`` and ``model_basename``.
        from_epoch: either the literal ``'latest'`` to pick the newest
            existing checkpoint in ``model_folder``, or an epoch tag used to
            build ``<model_basename><from_epoch>.pt`` directly (that file is
            not checked for existence).

    Returns:
        str | None: the checkpoint path, or ``None`` when ``'latest'`` was
        requested but no ``<model_basename>*.pt`` file exists.
    """
    folder = Path(config["model_folder"])
    basename = config["model_basename"]

    if from_epoch != 'latest':
        return str(folder / f"{basename}{from_epoch}.pt")

    # Only consider files that follow the checkpoint naming scheme
    # "<basename><epoch>.pt"; other files sharing the prefix are ignored.
    weights_files = list(folder.glob(f"{basename}*.pt"))
    if not weights_files:
        return None

    def _epoch_key(path):
        # Rank by the numeric epoch tag so that epoch 10 sorts after epoch 9
        # (a plain string sort would put "..._9.pt" after "..._10.pt").
        # Non-numeric tags rank below all numeric ones and fall back to
        # name order; the name also breaks ties such as "01" vs "1".
        tag = path.stem[len(basename):]
        if tag.isdecimal():
            return (1, int(tag), path.name)
        return (0, 0, path.name)

    return str(max(weights_files, key=_epoch_key))