Merge pull request #522 from int-brain-lab/release/2.17.0
Release/2.17.0
oliche authored Oct 4, 2022
2 parents 9b3e1ba + da84714 commit 648a0bc
Showing 5 changed files with 86 additions and 70 deletions.
35 changes: 26 additions & 9 deletions brainbox/io/one.py
@@ -31,7 +31,7 @@


SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']


def load_lfp(eid, one=None, dataset_types=None, **kwargs):
@@ -952,13 +952,14 @@ def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
_logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
return collection

def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None):
def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, **kwargs):
"""
Downloads an ALF object
:param obj: object name, one of 'spikes', 'clusters' or 'channels'
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types, for example ['spikes.samples']
:param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
:param kwargs: additional arguments to be passed to one.api.One.load_object
:return:
"""
if len(self.collections) == 0:
@@ -969,7 +970,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes, 'channels': None,
'templates': None, 'spikes_subset': None}
self.files[obj] = self.one.load_object(self.eid, obj=obj, attribute=attributes[obj],
collection=self.collection, download_only=True)
collection=self.collection, download_only=True, **kwargs)

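As an aside, a minimal hedged sketch of how the new `**kwargs` pass-through might be used. The loader construction (`pid`, `one`) and the `revision` keyword are assumptions for illustration; only the forwarding to `one.api.One.load_object` is shown in this hunk.

```python
# Hypothetical usage sketch: extra keyword arguments are forwarded to one.api.One.load_object.
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='<probe-insertion-uuid>', one=one)  # placeholder insertion id
# e.g. request a specific dataset revision (keyword assumed to be supported by load_object)
ssl.download_spike_sorting_object('spikes', spike_sorter='pykilosort', revision='2022-09-01')
```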
def download_spike_sorting(self, **kwargs):
"""
@@ -1006,6 +1007,7 @@ def load_spike_sorting(self, **kwargs):
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
if 'brainLocationIds_ccf_2017' not in channels:
_logger.debug(f"loading channels from alyx for {self.files['channels']}")
_channels, self.histology = _load_channel_locations_traj(
self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
if _channels:
@@ -1016,8 +1018,24 @@
return spikes, clusters, channels

@staticmethod
def merge_clusters(spikes, clusters, channels, cache_dir=None):
"""merge metrics and channels info - optionally saves a clusters.pqt dataframe"""
def compute_metrics(spikes, clusters=None):
nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
metrics = pd.DataFrame(quick_unit_metrics(
spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
return metrics

@staticmethod
def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
"""
Merge the metrics and the channel information into the clusters dictionary
:param spikes:
:param clusters:
:param channels:
:param cache_dir: if specified, the merged clusters table is cached as a parquet file to speed up
subsequent loading; intended for analysis applications (defaults to None)
:param compute_metrics: if True, will explicitly recompute metrics (defaults to False)
:return: cluster dictionary containing metrics and histology
"""
if spikes == {}:
return
nc = clusters['channels'].size
@@ -1027,18 +1045,17 @@ def merge_clusters(spikes, clusters, channels, cache_dir=None):
metrics = clusters.pop('metrics')
if metrics.shape[0] != nc:
metrics = None
if metrics is None:
if metrics is None or compute_metrics is True:
_logger.debug("recompute clusters metrics")
metrics = pd.DataFrame(quick_unit_metrics(
spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
if isinstance(cache_dir, Path):
metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
for k in metrics.keys():
clusters[k] = metrics[k].to_numpy()

for k in channels.keys():
clusters[k] = channels[k][clusters['channels']]
if cache_dir:
_logger.debug(f'caching clusters metrics in {cache_dir}')
pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
return clusters

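Taken together, a hedged end-to-end sketch of the updated loader API. The constructor argument (`pid`) and the cache directory are assumptions for illustration; `cache_dir` and `compute_metrics` behave as in the code above.

```python
# Hypothetical workflow: load spike sorting, then merge quality metrics and channel info,
# forcing a recompute with the updated algorithms and caching the merged table.
from pathlib import Path
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='<probe-insertion-uuid>', one=one)  # placeholder insertion id
spikes, clusters, channels = ssl.load_spike_sorting()

cache_dir = Path('/tmp/ibl_cache')
cache_dir.mkdir(parents=True, exist_ok=True)
clusters = SpikeSortingLoader.merge_clusters(
    spikes, clusters, channels,
    cache_dir=cache_dir,      # writes clusters.metrics.pqt and clusters.pqt here
    compute_metrics=True)     # recompute metrics even if a metrics table is already present
```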
108 changes: 51 additions & 57 deletions brainbox/metrics/single_units.py
@@ -3,7 +3,7 @@
Run the following to set-up the workspace to run the docstring examples:
>>> import brainbox as bb
>>> import alf.io as aio
>>> import one.alf.io as aio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
@@ -27,6 +27,8 @@
from phylib.stats import correlograms
from iblutil.util import Bunch
from iblutil.numerical import ismember, between_sorted
from slidingRP import metrics

from brainbox import singlecell
from brainbox.io.spikeglx import extract_waveforms
from brainbox.processing import bincount2D
@@ -37,20 +39,15 @@

# Parameters to be used in `quick_unit_metrics`
METRICS_PARAMS = {
'noise_cutoff': dict(quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10),
'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50),
'acceptable_contamination': 0.1,
'bin_size': 0.25,
'med_amp_thresh_uv': 50,
'min_isi': 0.0001,
'min_num_bins_for_missed_spks_est': 50,
'nc_bins': 100,
'nc_n_low_bins': 2,
'nc_quartile_length': 0.2,
'nc_thresh': 20,
'presence_window': 10,
'refractory_period': 0.0015,
'RPslide_thresh': 0.1,
'spks_per_bin_for_missed_spks_est': 10,
'std_smoothing_kernel_for_missed_spks_est': 4,
}


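To make the refactor above concrete: per-metric options now live in nested dictionaries keyed by function name and are unpacked straight into the corresponding function (noise_cutoff follows the same pattern further down). A minimal sketch on synthetic amplitudes, which are placeholders rather than IBL data:

```python
import numpy as np
from brainbox.metrics.single_units import METRICS_PARAMS, missed_spikes_est

amps = np.random.default_rng(0).normal(100e-6, 20e-6, 5000)  # synthetic spike amplitudes (V)
# the nested dict is forwarded wholesale, replacing the old individually named top-level keys
frac_missing, _, _ = missed_spikes_est(amps, **METRICS_PARAMS['missed_spikes_est'])
```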
@@ -757,69 +754,75 @@ def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1):
return didpass


def noise_cutoff(amps, quartile_length=.2, n_bins=100, n_low_bins=2):
def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10):
"""
A metric to determine whether a unit's amplitude distribution is cut off
A new metric to determine whether a unit's amplitude distribution is cut off
(at floor), without assuming a Gaussian distribution.
This metric takes the amplitude distribution, computes the mean and std
of an upper quartile of the distribution, and determines how many standard
deviations away from that mean a lower quartile lies.
Parameters
----------
amps : ndarray_like
The amplitudes (in uV) of the spikes.
quartile_length : float
quantile_length : float
The size of the upper quartile of the amplitude distribution.
n_bins : int
The number of bins used to compute a histogram of the amplitude
distribution.
n_low_bins : int
The number of bins used in the lower part of the distribution (where
cutoff is determined).
nc_threshold: float
the unit fails the cutoff criterion if the computed cutoff exceeds this threshold
percent_threshold: float
the first low bin must also exceed this fraction of the peak bin height for the unit to fail
Returns
-------
nc_pass : bool
True if the unit passes the noise cutoff criterion.
cutoff : float
Number of standard deviations by which the first low bin exceeds the mean of the upper quantile.
first_low_quantile : float
Height (spike count) of the first non-empty low bin of the amplitude histogram.
See Also
--------
missed_spikes_est
Examples
--------
1) Compute whether a unit's amplitude distribution is cut off
>>> amps = spks_b['amps'][unit_idxs]
>>> cutoff = bb.metrics.noise_cutoff(amps, quartile_length=.2,
n_bins=100, n_low_bins=2)
>>> nc_pass, cutoff, first_low_quantile = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100)
"""

if amps.size > 1:
bins_list = np.linspace(0, np.max(amps), n_bins)
n, bins = np.histogram(amps, bins=bins_list)
dx = np.diff(n)
idx_nz = np.nonzero(dx) # indices of nonzeros
idx_peak = np.argmax(n)
length_top_half = idx_nz[0][-1] - idx_peak
high_quartile = 1 - (2 * quartile_length)

high_quartile_start_ind = int(np.ceil(high_quartile * length_top_half + idx_peak))
xx = idx_nz[0][idx_nz[0] > high_quartile_start_ind]
if len(n[xx]) > 0:
mean_high_quartile = np.mean(n[xx])
std_high_quartile = np.std(n[xx])
first_low_quartile = np.mean(n[idx_nz[0][1:n_low_bins]])
if std_high_quartile > 0:
cutoff = (first_low_quartile - mean_high_quartile) / std_high_quartile
else:
cutoff = np.float64(np.nan)
else:
cutoff = np.float64(np.nan)
else:
cutoff = np.float64(np.nan)
return cutoff
cutoff = np.float64(np.nan)
first_low_quantile = np.float64(np.nan)
fail_criteria = np.ones(1).astype(bool)[0]

if amps.size > 1: # ensure there are amplitudes available to analyze
bins_list = np.linspace(0, np.max(amps), n_bins) # list of bins to compute the amplitude histogram
n, bins = np.histogram(amps, bins=bins_list) # construct amplitude histogram
idx_peak = np.argmax(n) # peak of amplitude distribution
# compute the length of the top half of the distribution, ignoring empty (zero) bins
length_top_half = len(np.where(n[idx_peak:-1] > 0)[0])
# the remaining part of the distribution, which we will compare the low quantile to
high_quantile = 2 * quantile_length
# the first bin (index) of the high quantile part of the distribution
high_quantile_start_ind = int(np.ceil(high_quantile * length_top_half + idx_peak))
# bins to consider in the high quantile (of all non-zero bins)
indices_bins_high_quantile = np.arange(high_quantile_start_ind, len(n))
idx_use = np.where(n[indices_bins_high_quantile] >= 1)[0]

if len(n[indices_bins_high_quantile]) > 0: # ensure there are amplitudes in these bins
# mean and standard deviation of the spike counts in the non-empty high-quantile bins
mean_high_quantile = np.mean(n[indices_bins_high_quantile][idx_use])
std_high_quantile = np.std(n[indices_bins_high_quantile][idx_use])
if std_high_quantile > 0:
first_low_quantile = n[(n != 0)][1] # take the second non-empty bin
cutoff = (first_low_quantile - mean_high_quantile) / std_high_quantile
peak_bin_height = np.max(n)
percent_of_peak = percent_threshold * peak_bin_height

fail_criteria = (cutoff > nc_threshold) & (first_low_quantile > percent_of_peak)

nc_pass = ~fail_criteria
return nc_pass, cutoff, first_low_quantile


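A short, hedged illustration of the new return signature: a pass/fail boolean comes first, followed by the continuous cutoff value (the one stored by quick_unit_metrics) and the height of the first low bin. The truncated synthetic amplitudes below are placeholders:

```python
import numpy as np
from brainbox.metrics.single_units import noise_cutoff

# Synthetic amplitudes truncated at 40 uV to mimic a unit whose small spikes were missed.
rng = np.random.default_rng(1)
amps = rng.normal(100e-6, 30e-6, 20_000)
amps = amps[amps > 40e-6]

nc_pass, cutoff, first_low_bin = noise_cutoff(amps, quantile_length=.25, n_bins=100,
                                              nc_threshold=5, percent_threshold=0.10)
# The unit fails (nc_pass is False) only when cutoff > nc_threshold AND the first low bin
# is taller than percent_threshold times the peak bin.
print(nc_pass, cutoff, first_low_bin)
```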
def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
@@ -977,6 +980,9 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
# this is the geometric median
r.amp_median[ir] = np.array(10 ** (camp['log_amps'].median() / 20))
r.amp_std_dB[ir] = np.array(camp['log_amps'].std())
srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
**{'sampleRate': 30000, 'binSizeCorr': 1 / 30000})
r.slidingRP_viol[srp['cidx']] = srp['value']

# loop over each cluster to compute the rest of the metrics
for ic in np.arange(nclust):
@@ -987,24 +993,12 @@
ts = spike_times[ispikes]
amps = spike_amps[ispikes]
depths = spike_depths[ispikes]

# compute metrics
r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period'])
r.contamination[ic], _ = contamination(
ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
r.slidingRP_viol[ic] = slidingRP_viol(ts,
bin_size=params['bin_size'],
thresh=params['RPslide_thresh'],
acceptThresh=params['acceptable_contamination'])
r.noise_cutoff[ic] = noise_cutoff(amps,
quartile_length=params['nc_quartile_length'],
n_bins=params['nc_bins'],
n_low_bins=params['nc_n_low_bins'])
r.missed_spikes_est[ic], _, _ = missed_spikes_est(
amps, spks_per_bin=params['spks_per_bin_for_missed_spks_est'],
sigma=params['std_smoothing_kernel_for_missed_spks_est'],
min_num_bins=params['min_num_bins_for_missed_spks_est'])

_, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff'])
r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
# wonder if there is a need to low-cut this
r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

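A hedged sketch of the updated flow: refractory-period violations are computed once for all clusters via slidingRP.metrics.slidingRP_all, while noise_cutoff and missed_spikes_est still run per cluster inside the loop; compute_labels then aggregates the criteria. All inputs below are synthetic placeholders, and the meaning of the returned labels is assumed rather than shown in this hunk.

```python
import numpy as np
from brainbox.metrics.single_units import quick_unit_metrics, compute_labels

# Synthetic spike sorting output for two clusters, purely for illustration.
rng = np.random.default_rng(42)
n_spikes = 20_000
spike_clusters = rng.integers(0, 2, n_spikes)
spike_times = np.sort(rng.uniform(0, 3600, n_spikes))   # seconds
spike_amps = rng.normal(100e-6, 20e-6, n_spikes)         # volts
spike_depths = rng.uniform(0, 3840, n_spikes)            # micrometres

r = quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths)
labels = compute_labels(r)  # assumed: fraction of QC criteria passed per unit, 1.0 = all pass
print(r.noise_cutoff, r.slidingRP_viol, labels)
```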
@@ -1023,7 +1017,7 @@ def compute_labels(r, params=METRICS_PARAMS, return_details=False):
# we could eventually do a bitwise qc
labels = np.c_[
r.slidingRP_viol,
r.noise_cutoff < params['nc_thresh'],
r.noise_cutoff < params['noise_cutoff']['nc_threshold'],
r.amp_median > params['med_amp_thresh_uv'] / 1e6,
]
if not return_details:
2 changes: 1 addition & 1 deletion ibllib/__init__.py
@@ -1,5 +1,5 @@
"""Library implementing the International Brain Laboratory data pipeline."""
__version__ = "2.16.1"
__version__ = "2.17.0"
import warnings

from iblutil.util import get_logger
4 changes: 4 additions & 0 deletions release_notes.md
@@ -1,3 +1,7 @@
## Release 2.17.0 (unreleased)
### features
- unit quality metrics use the latest algorithms for refractory period violations and noise cut-off

## Release 2.16.1
### Release Notes 2.16.1 2022-09-28
### bugfixes
7 changes: 4 additions & 3 deletions requirements.txt
@@ -23,8 +23,9 @@ scipy>=1.3.0
seaborn>=0.9.0
tqdm>=4.32.1
# ibl libraries
iblutil>=1.3.0
wfield>0.2.2 # widefield extractor
labcams # widefield extractor
ibl-neuropixel>=0.3.1
iblutil>=1.3.0
labcams # widefield extractor
ONE-api>=1.8.1
slidingRP # steinmetz lab refractory period metrics
wfield>0.2.2 # widefield extractor
