-
Notifications
You must be signed in to change notification settings - Fork 8
/
generate_synthetic_data.py
80 lines (63 loc) · 3.24 KB
/
generate_synthetic_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
from synthetic_data import SyntheticCalciumDataGenerator
from utils import write_data
import argparse
import yaml
import os
def _build_parser():
    """Return the command-line parser used by ``main()``.

    Options are declared as ``(flags, type, default)`` rows and registered in
    the order listed. A default of ``None`` leaves the option unset unless it
    is given on the command line.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        (('-d', '--system'), str, 'lorenz'),      # 'lorenz' or 'chaotic-rnn'
        (('-o', '--output'), str, './'),          # parent directory of the dataset
        (('-s', '--seed'), int, 100),             # RNG seed; also part of the output name
        (('-p', '--parameters'), str, None),      # optional YAML file of overrides
        (('--trials',), int, 10),
        (('--inits',), int, 100),
        (('--cells',), int, 50),
        (('--steps',), int, 100),
        (('--rate_scale',), float, 5.0),
        (('--trainp',), float, 0.8),
        (('--dt_spike',), float, 0.01),
        (('--dt_sys',), float, 0.01),
        (('--burn_steps',), int, 0),
    )
    for flags, converter, fallback in option_table:
        cli.add_argument(*flags, type=converter, default=fallback)
    return cli


parser = _build_parser()
def _apply_parameter_overrides(args):
    """Overwrite attributes of *args* with the key/value pairs in the YAML
    file named by ``args.parameters``, echoing each applied value.

    Does nothing when no parameters file was given. Keys are not validated
    against the CLI options, so the file may also add new attributes. An
    empty YAML file is treated as "no overrides".
    """
    if not args.parameters:
        return
    # NOTE(review): yaml.FullLoader can construct some Python-specific types.
    # The parameters file is assumed to be trusted, user-supplied config;
    # consider yaml.safe_load if that assumption ever changes.
    with open(args.parameters) as handle:
        params_dict = yaml.load(handle, Loader=yaml.FullLoader) or {}
    for key, val in params_dict.items():
        setattr(args, key, val)
        print('%s : %s' % (key, val), flush=True)


def _build_system(args):
    """Construct the network object handed to SyntheticCalciumDataGenerator.

    ``args.system`` selects the model:
      * ``'lorenz'``: a LorenzSystem (``args.inits`` initial conditions,
        step ``args.dt_sys``) wrapped in an EmbeddedLowDNetwork of
        ``args.cells`` units with base rate ``args.rate_scale``.
      * ``'chaotic-rnn'``: a ChaoticNetwork of ``args.cells`` units
        (weight scale 2.5) driven by a RandomPerturbation input.

    Raises:
        ValueError: if ``args.system`` is not one of the values above.
    """
    if args.system == 'lorenz':
        from synthetic_data import LorenzSystem, EmbeddedLowDNetwork
        lorenz = LorenzSystem(num_inits=args.inits,
                              dt=args.dt_sys)
        return EmbeddedLowDNetwork(low_d_system=lorenz,
                                   net_size=args.cells,
                                   base_rate=args.rate_scale,
                                   dt=args.dt_sys)
    if args.system == 'chaotic-rnn':
        from synthetic_data import ChaoticNetwork, RandomPerturbation
        inputs = RandomPerturbation(t_span=[0.25, 0.75], scale=10)
        return ChaoticNetwork(num_inits=args.inits,
                              base_rate=args.rate_scale,
                              net_size=args.cells,
                              weight_scale=2.5,
                              dt=args.dt_sys,
                              inputs=inputs)
    # Previously an unknown name fell through and failed later with an
    # UnboundLocalError on `net`; fail here with an actionable message.
    raise ValueError("unknown system %r; expected 'lorenz' or 'chaotic-rnn'"
                     % (args.system,))


def main():
    """Generate one synthetic calcium dataset and write it to disk.

    Options come from the command line and may be overridden by a YAML file
    (``--parameters``). Output goes to ``<output>/<system>_<seed:03d>``.
    If that path already exists, generation is skipped.
    """
    args = parser.parse_args()
    # Apply YAML overrides *before* deriving the output path. The file may
    # change `output`, `system` or `seed`, and the existence check must look
    # at the same path the data is written to.
    _apply_parameter_overrides(args)
    save_path = '%s/%s_%03d' % (args.output, args.system, args.seed)
    if os.path.exists(save_path):
        # Never overwrite an existing dataset, but make the skip visible.
        print('%s already exists; skipping generation' % save_path, flush=True)
        return

    net = _build_system(args)

    # generate data
    generator = SyntheticCalciumDataGenerator(system=net,
                                              seed=args.seed,
                                              trainp=args.trainp,
                                              burn_steps=args.burn_steps,
                                              num_steps=args.steps,
                                              num_trials=args.trials,
                                              tau_cal=0.3,
                                              dt_cal=args.dt_spike,
                                              sigma=0.2)
    data_dict = generator.generate_dataset()

    # save
    print('Saving to %s' % save_path, flush=True)
    write_data(save_path, data_dict)


if __name__ == '__main__':
    main()