-
Notifications
You must be signed in to change notification settings - Fork 17
/
CURBD_example.py
63 lines (53 loc) · 2.13 KB
/
CURBD_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python
"""
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% example script to generate simulated interacting brain regions and
% perform Current-Based Decomposition (CURBD). Ref:
%
% Perich MG et al. Inferring brain-wide interactions using data-constrained
% recurrent neural network models. bioRxiv. DOI: https://doi.org/10.1101/2020.12.18.423348
%
% Written by Matthew G. Perich and Eugene Carter. Updated December 2020.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
"""
import numpy as np
import pylab
import curbd
# --- Simulate three interacting regions -------------------------------------
sim = curbd.threeRegionSim(number_units=100)

# Stack the three regions' activity along axis 0 so one RNN is fit to all
# units jointly; the region table below indexes into this axis.
activity = np.concatenate((sim['Ra'], sim['Rb'], sim['Rc']), 0)

Na = sim['params']['Na']
Nb = sim['params']['Nb']
Nc = sim['params']['Nc']

# Region table: one row per region; column 0 is the display label, column 1
# the indices of that region's units along axis 0 of `activity`.
# The (n_regions, 2) object array is filled explicitly rather than via
# np.array(nested_lists, dtype=object), which has to infer a shape from a mix
# of strings and index arrays and can broadcast the arrays into extra
# dimensions.
region_specs = [
    ('Region A', np.arange(0, Na)),
    ('Region B', np.arange(Na, Na + Nb)),
    ('Region C', np.arange(Na + Nb, Na + Nb + Nc)),
]
regions = np.empty((len(region_specs), 2), dtype=object)
for iRegion, (region_label, region_idx) in enumerate(region_specs):
    regions[iRegion, 0] = region_label
    regions[iRegion, 1] = region_idx

# --- Fit the multi-region RNN -----------------------------------------------
model = curbd.trainMultiRegionRNN(activity,
                                  dtData=sim['params']['dtData'],
                                  dtFactor=5,
                                  regions=regions,
                                  # NOTE(review): previously written as
                                  # 2*tau/2, which is exactly tau; confirm
                                  # that the RNN time constant is meant to
                                  # equal the simulation's.
                                  tauRNN=sim['params']['tau'],
                                  nRunTrain=500,
                                  verbose=True,
                                  nRunFree=5)

# --- Current-based decomposition and plotting -------------------------------
curbd_arr, curbd_labels = curbd.computeCURBD(model)
n_regions = curbd_arr.shape[0]

fig = pylab.figure(figsize=[8, 8])
for iTarget in range(n_regions):
    for iSource in range(n_regions):
        # Row-major grid: targets down the rows, sources across the columns.
        axn = fig.add_subplot(n_regions, n_regions,
                              iTarget * n_regions + iSource + 1)
        current = curbd_arr[iTarget, iSource]
        # Use this panel's own row count. Region A's count only fits the
        # other targets when all regions happen to have the same size.
        # shading='auto' draws cells centred on tRNN / unit index when X, Y
        # and C share dimensions (the legacy 'flat' default would drop the
        # last row and column).
        axn.pcolormesh(model['tRNN'], range(current.shape[0]), current,
                       shading='auto')
        axn.set_xlabel('Time (s)')
        axn.set_ylabel('Neurons in {}'.format(regions[iTarget, 0]))
        axn.set_title(curbd_labels[iTarget, iSource])
        axn.title.set_fontsize(8)
        axn.xaxis.label.set_fontsize(8)
        axn.yaxis.label.set_fontsize(8)
fig.subplots_adjust(hspace=0.4, wspace=0.3)
# Blocking show so the window stays open when run as a plain script;
# Figure.show() returns immediately and the figure vanishes on exit.
pylab.show()