forked from jagorn/NMA2020_group_project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path analysis_test.py
63 lines (47 loc) · 1.63 KB
/
analysis_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from smz_load import *
from smz_plot import *
from get_clean_trials import *
from matplotlib import pyplot as plt
from hmm_map_states import *
import ssm
import matplotlib.cm as cm
# Fit a Poisson HMM to the spike counts of one (or more) selected trials and
# plot the per-state posterior probability over time.

# Uncomment to make the EM initialisation, and hence the fitted states,
# reproducible between runs.
# np.random.seed(0)

# Experimental parameters
recording_name = 'Cori_2016-12-14'
brain_region = 'MOs'
neuron_min_score = 2

# Model parameters
bin_dt = 0.01        # spike-count bin width, seconds
pre_stim_dt = 0.5    # time kept before the visual stimulus, seconds
post_resp_dt = 0.5   # time kept after the feedback, seconds
N_states = 3         # number of latent HMM states

# Choose trials: positions of the clean trials whose 'choice' field equals 1.
trials = extract_clean_trials(recording_name)
conditioned_trials = np.where(trials['choice'] == 1)[0]

# Select which conditioned trials to fit: positions
# first_trial .. first_trial + n_trials - 1 within `conditioned_trials`.
# (Previously `n_trials` was itself used as that start position; with these
# defaults the same single trial, position 1, is fitted.)
first_trial = 1
n_trials = 1

# Run the fit: one independent HMM per selected trial.
for i in range(first_trial, first_trial + n_trials):
    # Load time pointers for the given trial
    trial = conditioned_trials[i]
    visual_time = trials['visStim_times'][trial]
    cue_time = trials['cue_times'][trial]  # loaded but not used below
    feedback_time = trials['feedback_times'][trial]

    # Generate the spike-count histograms for the window
    # [stimulus - pre_stim_dt, feedback + post_resp_dt].
    t0 = visual_time - pre_stim_dt
    tf = feedback_time + post_resp_dt
    dataset, time_bins = generate_spike_counts(
        recording_name, brain_region, neuron_min_score, bin_dt, t0, tf)
    n_neurons, n_bins = dataset.shape

    # Create the HMM. The count matrix is (neurons, bins); it is transposed so
    # each row is one time bin, as ssm's observation arrays are (T, D).
    train_data = dataset.astype(int).T
    model = ssm.HMM(N_states, n_neurons, observations="poisson")
    # Log-likelihood trace returned by EM; kept for convergence checks.
    hmm_lls = model.fit(train_data, method="em", num_iters=1000)
    # NOTE: `filter` gives forward-pass (filtered) state probabilities; use
    # model.expected_states(train_data)[0] for smoothed posteriors, or
    # model.most_likely_states(train_data) for the Viterbi path.
    posterior = model.filter(train_data)

    # Number the figure by the loop index so several trials do not draw into
    # the same window.
    plt.figure(i, figsize=[9, 5])
    for s in range(N_states):
        plt.plot(posterior[:, s], label="State %d" % s)
    plt.suptitle('Posterior probability of latent states')
    # round(), not int(): int() truncates float error (0.029 * 1000 -> 28).
    plt.xlabel(f'time bin ({round(bin_dt * 1000)} ms)')
    plt.ylabel('probability')
    plt.legend()
    plt.show()