Merge pull request #440 from int-brain-lab/develop
Release 2.9.1
oliche authored Jan 26, 2022
2 parents ce861a2 + 41f1df3 commit 11cd0ed
Showing 15 changed files with 520 additions and 120 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,6 +1,7 @@
include ibllib/atlas/allen_structure_tree.csv
include ibllib/atlas/beryl.npy
include ibllib/atlas/cosmos.npy
include ibllib/atlas/mappings.pqt
include ibllib/io/extractors/extractor_types.json
include brainbox/tests/wheel_test.p
recursive-include brainbox/tests/fixtures *
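Note on the new include: mappings.pqt is the atlas mappings table, and the directive above bundles it with the source distribution so an installed package can read it at runtime. A minimal sketch (assuming an installed ibllib; the path layout follows the include directive):

from pathlib import Path
import ibllib.atlas

# the include directive places mappings.pqt inside the ibllib/atlas package
mappings_path = Path(ibllib.atlas.__file__).parent / 'mappings.pqt'
print(mappings_path.exists())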
26 changes: 13 additions & 13 deletions brainbox/behavior/training.py
@@ -426,7 +426,7 @@ def compute_performance(trials, signed_contrast=None, block=None):
block_idx = trials.probabilityLeft == block

if not np.any(block_idx):
- return np.nan * np.zeros(2)
+ return np.nan * np.zeros(3)

contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
rightward = trials.choice == -1
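Why three NaNs: compute_performance returns three values, as the callers' unpacking below shows (prob_right, contrasts, _ = compute_performance(...)), so the early return for an empty block must also have three elements. A minimal sketch (not from the commit) of what the old two-element return broke:

import numpy as np

empty = np.nan * np.zeros(3)
prob_right, contrasts, n_contrasts = empty  # unpacks cleanly into three NaNs
# with np.nan * np.zeros(2) the same unpacking raises
# ValueError: not enough values to unpack (expected 3, got 2)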
@@ -584,15 +584,15 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
signed_contrast = get_signed_contrast(trials)
contrasts_fit = np.arange(-100, 100)

- prob_right_50, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
+ prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)

- prob_right_20, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
+ prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)

- prob_right_80, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
+ prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)

@@ -606,11 +606,11 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
# TODO error bars

fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1])
- data_50 = ax.scatter(contrasts, prob_right_50, color=cmap[1])
+ data_50 = ax.scatter(contrasts_50, prob_right_50, color=cmap[1])
fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0])
- data_20 = ax.scatter(contrasts, prob_right_20, color=cmap[0])
+ data_20 = ax.scatter(contrasts_20, prob_right_20, color=cmap[0])
fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2])
- data_80 = ax.scatter(contrasts, prob_right_80, color=cmap[2])
+ data_80 = ax.scatter(contrasts_80, prob_right_80, color=cmap[2])
ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80],
['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'],
loc='upper left')
@@ -631,9 +631,9 @@ def plot_reaction_time(trials, ax=None, title=None, **kwargs):
"""

signed_contrast = get_signed_contrast(trials)
- reaction_50, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
- reaction_20, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
- reaction_80, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)
+ reaction_50, contrasts_50, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
+ reaction_20, contrasts_20, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
+ reaction_80, contrasts_80, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)

cmap = sns.diverging_palette(20, 220, n=3, center="dark")

@@ -642,9 +642,9 @@ def plot_reaction_time(trials, ax=None, title=None, **kwargs):
else:
fig = plt.gcf()

- data_50 = ax.plot(contrasts, reaction_50, '-o', color=cmap[1])
- data_20 = ax.plot(contrasts, reaction_20, '-o', color=cmap[0])
- data_80 = ax.plot(contrasts, reaction_80, '-o', color=cmap[2])
+ data_50 = ax.plot(contrasts_50, reaction_50, '-o', color=cmap[1])
+ data_20 = ax.plot(contrasts_20, reaction_20, '-o', color=cmap[0])
+ data_80 = ax.plot(contrasts_80, reaction_80, '-o', color=cmap[2])

# TODO error bars

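Note on the renames in this file: each probabilityLeft block can produce a different set of contrasts, and the old code reused one contrasts variable for all three blocks, so every scatter or plot call was drawn against the contrasts of the last block computed (0.8). A hypothetical illustration with assumed values:

import numpy as np

contrasts_20 = np.array([-100., -25., 0., 25., 100.])    # assumed example values
contrasts_80 = np.array([-100., -12.5, 0., 12.5, 100.])  # assumed example values
# before the fix, a single shared contrasts variable ended up holding the
# block=0.8 array, so block=0.2 data were plotted against the wrong x positions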
62 changes: 28 additions & 34 deletions brainbox/examples/docs_load_spike_sorting.py
@@ -1,43 +1,37 @@
"""
Get spikes, clusters and channels data
========================================
- Downloads and loads in spikes, clusters and channels data for a given session. Data is returned
+ Downloads and loads in spikes, clusters and channels data for a given probe insertion.
+ There may be several spike sorting collections; by default the loader will get the pykilosort collection.
+ The channel locations can come from several sources; the loader will use the most advanced version of the
+ histology available, regardless of the spike sorting version loaded. The steps are (from most advanced to fresh
+ out of imaging):
+ - alf: the final version of channel locations, same as resolved, except the data has been written out to files
+ - resolved: channel location alignments have been agreed upon
+ - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
+ - traced: the histology track has been recovered from microscopy, however the depths may not match; inaccurate data
"""
- import brainbox.io.one as bbone

from one.api import ONE
+ from ibllib.atlas import AllenAtlas
+ from brainbox.io.one import SpikeSortingLoader


+ one = ONE(base_url='https://openalyx.internationalbrainlab.org')
+ ba = AllenAtlas()
+
+ insertions = one.alyx.rest('insertions', 'list')
+ pid = insertions[0]['id']
+ sl = SpikeSortingLoader(pid=pid, one=one, atlas=ba)
+ spikes, clusters, channels = sl.load_spike_sorting()
+ clusters_labeled = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
+
+ # the histology property holds the provenance of the current channel locations
+ print(sl.histology)

- one = ONE(base_url='https://openalyx.internationalbrainlab.org', silent=True)
-
- # Find eid of interest
- eid = one.search(subject='CSH_ZAD_029', date='2020-09-19')[0]
-
- ##################################################################################################
- # Example 1:
- # Download spikes, clusters and channels data for all available probes for this session.
- # The data for each probe is returned as a dict
- spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one)
- print(spikes.keys())
- print(spikes['probe01'].keys())
-
- ##################################################################################################
- # Example 2:
- # Download spikes, clusters and channels data for a single probe
- spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one, probe='probe01')
- print(spikes.keys())
-
- ##################################################################################################
- # Example 3:
- # The default spikes and clusters datasets that are downloaded are '
- # ['clusters.channels',
- # 'clusters.depths',
- # 'clusters.metrics',
- # 'spikes.clusters',
- # 'spikes.times']
- # If we also want to load for example, 'clusters.peakToTrough we can add a dataset_types argument
-
- spikes, clusters, channels = bbone.load_spike_sorting_with_channel(eid, one=one, probe='probe01',
-                                                                    dataset_types=['clusters.peakToTrough'])
- print(clusters['probe01'].keys())
+ # available spike sorting collections for this probe insertion
+ print(sl.collections)
+
+ # the collection that has been loaded
+ print(sl.collection)
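The rewritten example takes the first insertion returned by the public server. A sketch for targeting a specific session instead (the subject and date come from the removed example; the session filter on the insertions endpoint is an assumption):

# pick an insertion for a known session rather than insertions[0]
eid = one.search(subject='CSH_ZAD_029', date='2020-09-19')[0]
insertions = one.alyx.rest('insertions', 'list', session=eid)  # assumed filter name
pid = insertions[0]['id']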
63 changes: 54 additions & 9 deletions brainbox/io/one.py
@@ -270,6 +270,8 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio

def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
brain_atlas=None, return_source=False):
+ if not hasattr(one, 'alyx'):
+     return {}, None
_logger.debug(f"trying to load from traj {probe}")
channels = Bunch()
brain_atlas = brain_atlas or AllenAtlas()
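The new guard makes the function a no-op for ONE instances without an Alyx client, i.e. offline mode. A minimal sketch of the case it protects against (assuming the offline One class from one.api and a hypothetical cache directory):

from one.api import One

one_offline = One(cache_dir='/tmp/one_cache')  # hypothetical local cache
print(hasattr(one_offline, 'alyx'))  # False, so the loader returns ({}, None)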
@@ -416,6 +418,8 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
:param return_collection: (False) if True, will return the collection used to load
:return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
"""
+ _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. '
+                 'Use brainbox.io.one.SpikeSortingLoader instead')
if collection is None:
collection = _collection_filter_from_args(probe, spike_sorter)
_logger.debug(f"load spike sorting with collection filter {collection}")
@@ -455,6 +459,8 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sort
:param return_collection:(bool - False) if True, returns the collection for loading the data
:return: spikes, clusters (dict of bunch, 1 bunch per probe)
"""
+ _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
+                 'Use brainbox.io.one.SpikeSortingLoader instead')
collection = _collection_filter_from_args(probe, spike_sorter)
_logger.debug(f"load spike sorting with collection filter {collection}")
spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
@@ -506,6 +512,8 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
"""
# --- Get spikes and clusters data
+ _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future '
+                 'versions. Use brainbox.io.one.SpikeSortingLoader instead')
one = one or ONE()
brain_atlas = brain_atlas or AllenAtlas()
spikes, clusters, collection = load_spike_sorting(
@@ -862,12 +870,17 @@ def load_channels_from_insertion(ins, depths=None, one=None, ba=None):

@dataclass
class SpikeSortingLoader:
"""Class for loading spike sorting"""
pid: str
"""
Object that will load spike sorting data for a given probe insertion.
"""
one: ONE
atlas: None
# the following properties are the outcome of the post init funciton
atlas: None = None
pid: str = None
eid: str = ''
pname: str = ''
# the following properties are the outcome of the post init funciton
session_path: Path = ''
collections: list = None
datasets: list = None  # list of all datasets belonging to the session
@@ -878,7 +891,10 @@ class SpikeSortingLoader:
spike_sorting_path: Path = None

def __post_init__(self):
- self.eid, self.pname = self.one.pid2eid(self.pid)
+ if self.pid is not None:
+     self.eid, self.pname = self.one.pid2eid(self.pid)
+ if self.atlas is None:
+     self.atlas = AllenAtlas()
self.session_path = self.one.eid2path(self.eid)
self.collections = self.one.list_collections(
self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
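With the reshuffled fields, only the one argument has no default, which allows several construction paths. A sketch (the pid, eid and one values are placeholders):

# from a probe insertion id; __post_init__ resolves eid/pname and builds the atlas
sl = SpikeSortingLoader(pid='<pid>', one=one)
# or from a session id and probe name, skipping the pid2eid lookup
sl = SpikeSortingLoader(eid='<eid>', pname='probe00', one=one)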
@@ -909,32 +925,61 @@ def _get_spike_sorting_collection(self, spike_sorter='pykilosort', revision=None
return collection

def _download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None):
"""
Downloads an ALF object
:param obj: object name, str between 'spikes', 'clusters' or 'channels'
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types
:return:
"""
if len(self.collections) == 0:
return {}, {}, {}
self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
_logger.debug(f"loading spike sorting from {self.collection}")
spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes, 'channels': None}
self.files[obj] = self.one.load_object(self.eid, obj=obj, attribute=attributes[obj],
collection=self.collection, download_only=True)

def download_spike_sorting(self, **kwargs):
"""spike_sorter='pykilosort', dataset_types=None"""
"""
Downloads spikes, clusters and channels
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types
:return:
"""
for obj in ['spikes', 'clusters', 'channels']:
self._download_spike_sorting_object(obj=obj, **kwargs)
self.spike_sorting_path = self.files['spikes'][0].parent

def load_spike_sorting(self, **kwargs):
"""spike_sorter='pykilosort', dataset_types=None"""
"""
Loads spikes, clusters and channels
There could be several spike sorting collections, by default the loader will get the pykilosort collection
The channel locations can come from several sources, it will load the most advanced version of the histology available,
regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
- alf: the final version of channel locations, same as resolved with the difference that data is on file
- resolved: channel locations alignments have been agreed upon
- aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
- traced: the histology track has been recovered from microscopy, however the depths may not match, inacurate data
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types
:return:
"""
if len(self.collections) == 0:
return {}, {}, {}
self.download_spike_sorting(**kwargs)
channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
if 'brainLocationIds_ccf_2017' not in channels:
- channels, self.histology = _load_channel_locations_traj(
+ _channels, self.histology = _load_channel_locations_traj(
self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True)
- channels = channels[self.pname]
+ if _channels:
+     channels = _channels[self.pname]
else:
channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
self.histology = 'alf'
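Taken together with the deprecation warnings above, a migration sketch from the old loaders to SpikeSortingLoader (names from this diff; the pid value is a placeholder):

from one.api import ONE
from ibllib.atlas import AllenAtlas
from brainbox.io.one import SpikeSortingLoader

one = ONE(base_url='https://openalyx.internationalbrainlab.org')

# before (deprecated):
#   spikes, clusters, channels = load_spike_sorting_with_channel(eid, one=one, probe='probe01')
sl = SpikeSortingLoader(pid='<pid>', one=one, atlas=AllenAtlas())
spikes, clusters, channels = sl.load_spike_sorting()

# the histology attribute reports the provenance of the channel locations
assert sl.histology in ('alf', 'resolved', 'aligned', 'traced')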
2 changes: 1 addition & 1 deletion ibllib/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.9.0"
__version__ = "2.9.1"
import warnings

from ibllib.misc import logger_config
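A quick post-upgrade check (assumes the package was updated, e.g. with pip install -U ibllib):

import ibllib
print(ibllib.__version__)  # '2.9.1' for this release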