Skip to content

Commit

Permalink
revise PSTH plot
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jun 7, 2022
1 parent 5f1ef69 commit 08cc600
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 27 deletions.
47 changes: 32 additions & 15 deletions notebooks/07-downstream-analysis.ipynb

Large diffs are not rendered by default.

25 changes: 18 additions & 7 deletions notebooks/py_scripts/07-downstream-analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
# format_version: '1.5'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: venv-nwb
# display_name: Python 3.8.11 ('ele')
# language: python
# name: venv-nwb
# name: python3
# ---

# + [markdown] tags=[]
Expand Down Expand Up @@ -110,9 +110,10 @@
alignment_key = (event.AlignmentEvent & 'alignment_name = "center_button"'
).fetch1('KEY')
alignment_condition = {**clustering_key, **alignment_key,
'trial_condition': 'ctrl_center_button'}
'trial_condition': 'ctrl_center_button',
'bin_size':.2}
analysis.SpikesAlignmentCondition.insert1(alignment_condition, skip_duplicates=True)

alignment_condition.pop('bin_size')
analysis.SpikesAlignmentCondition.Trial.insert(
(analysis.SpikesAlignmentCondition * ctrl_trials & alignment_condition).proj(),
skip_duplicates=True)
Expand All @@ -125,8 +126,11 @@
# Now, let's create another set for the stimulus condition.
# + a set of trials of interest to perform the analysis on - `stim` trials
stim_trials = trial.Trial & clustering_key & 'trial_type = "stim"'
alignment_condition = {**clustering_key, **alignment_key, 'trial_condition': 'stim_center_button'}
alignment_condition = {**clustering_key, **alignment_key,
'trial_condition': 'stim_center_button',
'bin_size':.2}
analysis.SpikesAlignmentCondition.insert1(alignment_condition, skip_duplicates=True)
alignment_condition.pop('bin_size')
analysis.SpikesAlignmentCondition.Trial.insert(
(analysis.SpikesAlignmentCondition * stim_trials & alignment_condition).proj(),
skip_duplicates=True)
Expand All @@ -151,7 +155,15 @@
# + a set of trials of interest to perform the analysis on - `stim` trials [markdown]
# ## Visualize
#
# We can visualize the results with the `plot` function.
# We can visualize the results with the `plot` function with our keys.
# -

clustering_key = (ephys.CuratedClustering
& {'subject': 'subject6', 'session_datetime': '2021-01-15 11:16:38',
'insertion_number': 0}
).fetch1('KEY')
alignment_key = (event.AlignmentEvent & 'alignment_name = "center_button"'
).fetch1('KEY')

# + a set of trials of interest to perform the analysis on - `stim` trials
alignment_condition = {**clustering_key, **alignment_key, 'trial_condition': 'ctrl_center_button'}
Expand All @@ -162,4 +174,3 @@
analysis.SpikesAlignment().plot(alignment_condition, unit=2);
# -


4 changes: 3 additions & 1 deletion workflow_array_ephys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .pipeline import db_prefix, ephys, trial

__all__ = ["db_prefix", "ephys", "trial", "event"]

schema = dj.schema(db_prefix + 'analysis')

Expand Down Expand Up @@ -101,6 +102,7 @@ def plot(self, key, unit, axs=None):
if axs is None:
fig, axs = plt.subplots(2, 1, figsize=(12, 8))

bin_size = (SpikesAlignmentCondition & key).fetch1("bin_size")
trial_ids, aligned_spikes = (self.AlignedTrialSpikes
& key & {'unit': unit}
).fetch('trial_id', 'aligned_spike_times')
Expand All @@ -111,7 +113,7 @@ def plot(self, key, unit, axs=None):

plot_psth._plot_spike_raster(aligned_spikes, trial_ids=trial_ids, ax=axs[0],
title=f'{dict(**key, unit=unit)}', xlim=xlim)
plot_psth._plot_psth(psth, psth_edges, ax=axs[1],
plot_psth._plot_psth(psth, psth_edges, bin_size, ax=axs[1],
title='', xlim=xlim)

return fig
8 changes: 4 additions & 4 deletions workflow_array_ephys/plotting/plot_psth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ def _plot_spike_raster(aligned_spikes, trial_ids=None, vlines=[0], ax=None, titl

assert len(raster) == len(trial_ids)

ax.plot(raster, trial_ids, 'r.', markersize=1)
ax.plot(raster, trial_ids, 'ro', markersize=4)

for x in vlines:
ax.axvline(x=x, linestyle='--', color='k')

ax.set_ylabel('Trial (#)')
if xlim:
ax.set_xlim(xlim)
ax.set_axis_off()
# ax.set_axis_off()
ax.set_title(title)


def _plot_psth(psth, psth_edges, vlines=[0], ax=None, title='', xlim=None):
def _plot_psth(psth, psth_edges, bin_size, vlines=[0], ax=None, title='', xlim=None):
if not ax:
fig, ax = plt.subplots(1, 1)

ax.plot(psth_edges, psth, 'r')
ax.bar(psth_edges, psth, width=bin_size, edgecolor="black", align="edge")

for x in vlines:
ax.axvline(x=x, linestyle='--', color='k')
Expand Down

0 comments on commit 08cc600

Please sign in to comment.