Skip to content

Commit

Permalink
CCN comparison in RNN example
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Oct 24, 2023
1 parent 357e6ca commit dfe483f
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 290 deletions.
2 changes: 1 addition & 1 deletion MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _compute_geometric_objects(
n_geodesic_nb=2.0,
var_explained=0.9,
return_spectrum=True,
local_gauges=True,
local_gauges=False,
compute_laplacian=False,
compute_connection_laplacian=False,
dim_man=None,
Expand Down
545 changes: 322 additions & 223 deletions examples/RNN/RNN.ipynb

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions examples/RNN/RNN_scripts/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@ def load_network(f):
return z, net


def sample_network(net, f):
def sample_network(net, f, seed=0):

if os.path.exists(f):
print('Network found with same name. Loading...')
return torch.load(open(f, "rb"))

n_pops = 2
seed = 0
z, _ = clustering.gmm_fit(net, n_pops, algo="bayes", random_state=seed)
net_sampled = clustering.to_support_net(net, z)

Expand All @@ -68,7 +67,7 @@ def sample_network(net, f):


def generate_trajectories(
net, input=None, epochs=None, n_traj=None, fname="./outputs/RNN_trajectories.pkl"
net, input=None, epochs=None, n_traj=None, fname="./data/RNN_trajectories.pkl"
):

if fname is not None:
Expand Down Expand Up @@ -455,7 +454,7 @@ def plot_experiment(net, input, traj, epochs, rect=(-8, 8, -6, 6), traj_to_show=
fig.subplots_adjust(hspace=0.1, wspace=0.1)


def aggregate_data(traj, epochs, transient=10, only_stim=False):
def aggregate_data(traj, epochs, transient=10, only_stim=False, pca=True):

n_conds = len(traj)
n_epochs = len(epochs) - 1
Expand All @@ -468,9 +467,10 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for j in range(n_traj): # trajectories
pos.append(traj[i][j][k][transient:])

pca = PCA(n_components=3)
pca.fit(np.vstack(pos))
print("Explained variance: ", pca.explained_variance_ratio_)
if pca:
pca = PCA(n_components=3)
pca.fit(np.vstack(pos))
print("Explained variance: ", pca.explained_variance_ratio_)

# aggregate data under baseline condition (no input)
pos, vel = [], []
Expand All @@ -480,7 +480,8 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for k in [0, 2, 4]:
for j in range(n_traj): # trajectories
pos_proj = traj[i][j][k][transient:]
pos_proj = pca.transform(pos_proj)
if pca:
pos_proj = pca.transform(pos_proj)
pos_.append(pos_proj[:-1]) # stack trajectories
vel_.append(np.diff(pos_proj, axis=0)) # compute differences

Expand All @@ -494,7 +495,8 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for k in [1, 3]:
for j in range(n_traj): # trajectories
pos_proj = traj[i][j][k][transient:]
pos_proj = pca.transform(pos_proj)
if pca:
pos_proj = pca.transform(pos_proj)
pos_.append(pos_proj[:-1])
vel_.append(np.diff(pos_proj, axis=0))

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
113 changes: 56 additions & 57 deletions examples/vanderpol.ipynb

Large diffs are not rendered by default.

0 comments on commit dfe483f

Please sign in to comment.