Merge branch 'release/2.5.0'
juhuntenburg committed Nov 24, 2021
2 parents cab64bd + 850c3c4 commit e657db6
Showing 22 changed files with 901 additions and 149 deletions.
2 changes: 1 addition & 1 deletion ibllib/dsp/fourier.py
@@ -20,7 +20,7 @@ def convolve(x, w, mode='full'):
ns = ns_optim_fft(nsx + nsw)
x_ = np.concatenate((x, np.zeros([*x.shape[:-1], ns - nsx], dtype=x.dtype)), axis=-1)
w_ = np.concatenate((w, np.zeros([*w.shape[:-1], ns - nsw], dtype=w.dtype)), axis=-1)
xw = np.fft.irfft(np.fft.rfft(x_, axis=-1) * np.fft.rfft(w_, axis=-1), axis=-1)
xw = np.real(np.fft.irfft(np.fft.rfft(x_, axis=-1) * np.fft.rfft(w_, axis=-1), axis=-1))
xw = xw[..., :(nsx + nsw)] # remove 0 padding
if mode == 'full':
return xw
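
For context, the convolve() change above only wraps the inverse FFT in np.real (a defensive cast, since irfft already returns a real array). The padded rfft/irfft convolution pattern it touches can be sketched in isolation as follows; the function name and test signals below are illustrative, not part of ibllib:

import numpy as np

def fft_convolve_full(x, w):
    # linear convolution of two 1-D signals via zero-padded real FFTs
    nsx, nsw = x.shape[-1], w.shape[-1]
    ns = nsx + nsw                      # pad length covering the full convolution
    x_ = np.concatenate((x, np.zeros(ns - nsx, dtype=x.dtype)))
    w_ = np.concatenate((w, np.zeros(ns - nsw, dtype=w.dtype)))
    # irfft returns a real array; np.real mirrors the defensive cast added in the diff
    xw = np.real(np.fft.irfft(np.fft.rfft(x_) * np.fft.rfft(w_), n=ns))
    return xw[:nsx + nsw - 1]           # trim the zero padding

x, w = np.random.randn(100), np.hanning(11)
assert np.allclose(fft_convolve_full(x, w), np.convolve(x, w, mode='full'))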
391 changes: 329 additions & 62 deletions ibllib/dsp/voltage.py

Large diffs are not rendered by default.

59 changes: 37 additions & 22 deletions ibllib/ephys/ephysqc.py
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import pandas as pd
from scipy import signal
from scipy import signal, stats
from tqdm import tqdm
import one.alf.io as alfio
from iblutil.util import Bunch
@@ -105,7 +105,10 @@ def _compute_metrics_array(raw, fs, h):
rms_pre_proc = dsp.rms(destripe)
detections = spikes.detection(data=destripe.T, fs=fs, h=h, detect_threshold=SPIKE_THRESHOLD_UV * 1e-6)
spike_rate = np.bincount(detections.trace, minlength=raw.shape[0]).astype(np.float32)
return rms_raw, rms_pre_proc, spike_rate
channel_labels, _ = dsp.voltage.detect_bad_channels(raw, fs=fs)
_, psd = signal.welch(destripe, fs=fs, window='hanning', nperseg=WELCH_WIN_LENGTH_SAMPLES,
detrend='constant', return_onesided=True, scaling='density', axis=-1)
return rms_raw, rms_pre_proc, spike_rate, channel_labels, psd

def run(self, update: bool = False, overwrite: bool = True, stream: bool = None, **kwargs) -> (str, dict):
"""
@@ -124,14 +127,18 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
self.load_data()
qc_files = []
# If ap meta file present, calculate median RMS per channel before and after destriping
# TODO: This should go in a separate function once we have a spikeglx.Streamer that behaves like the Reader
# NB: ideally this should go in a separate function once we have a spikeglx.Streamer that behaves like the Reader
if self.data.ap_meta:
rms_file = self.probe_path.joinpath("_iblqc_ephysChannels.apRMS.npy")
spike_rate_file = self.probe_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy")
if all([rms_file.exists(), spike_rate_file.exists()]) and not overwrite:
files = {'rms': self.probe_path.joinpath("_iblqc_ephysChannels.apRMS.npy"),
'spike_rate': self.probe_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy"),
'channel_labels': self.probe_path.joinpath("_iblqc_ephysChannels.labels.npy"),
'ap_freqs': self.probe_path.joinpath("_iblqc_ephysSpectralDensityAP.freqs.npy"),
'ap_power': self.probe_path.joinpath("_iblqc_ephysSpectralDensityAP.power.npy"),
}
if all([files[k].exists() for k in files]) and not overwrite:
_logger.warning(f'RMS map already exists for .ap data in {self.probe_path}, skipping. '
f'Use overwrite option.')
median_rms = np.load(rms_file)
results = {k: np.load(files[k]) for k in files}
else:
rl = self.data.ap_meta.fileTimeSecs
nsync = len(spikeglx._get_sync_trace_indices_from_meta(self.data.ap_meta))
@@ -145,14 +152,18 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
raise ValueError("Wrong Neuropixel channel mapping used - ABORT")
t0s = np.arange(TMIN, rl - SAMPLE_LENGTH, BATCHES_SPACING)
all_rms = np.zeros((2, nc, t0s.shape[0]))
all_srs = np.zeros((nc, t0s.shape[0]))
all_srs, channel_ok = (np.zeros((nc, t0s.shape[0])) for _ in range(2))
psds = np.zeros((nc, dsp.fscale(WELCH_WIN_LENGTH_SAMPLES, 1, one_sided=True).size))
# If the ap.bin file is not present locally, stream it
if self.data.ap is None and self.stream is True:
_logger.warning(f'Streaming .ap data to compute RMS samples for probe {self.pid}')
for i, t0 in enumerate(tqdm(t0s)):
sr, _ = sglx_streamer(self.pid, t0=t0, nsecs=1, one=self.one, remove_cached=True)
raw = sr[:, :-nsync].T
all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, sr.fs, h)
all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i], channel_ok[:, i], psd =\
self._compute_metrics_array(raw, sr.fs, h)
psds += psd
fs = sr.fs
elif self.data.ap is None and self.stream is not True:
_logger.warning('Raw .ap data is not available locally. Run with stream=True in order to stream '
'data for calculating RMS samples.')
@@ -161,23 +172,27 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
for i, t0 in enumerate(t0s):
sl = slice(int(t0 * self.data.ap.fs), int((t0 + SAMPLE_LENGTH) * self.data.ap.fs))
raw = self.data.ap[sl, :-nsync].T
all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, self.data.ap.fs, h)
all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i], channel_ok[:, i], psd =\
self._compute_metrics_array(raw, self.data.ap.fs, h)
fs = self.data.ap.fs
psds += psd
# Calculate the median RMS across all samples per channel
median_rms = np.median(all_rms, axis=-1)
median_spike_rate = np.median(all_srs, axis=-1)
np.save(rms_file, median_rms)
np.save(spike_rate_file, median_spike_rate)
qc_files.extend([rms_file, spike_rate_file])

results = {'rms': np.median(all_rms, axis=-1),
'spike_rate': np.median(all_srs, axis=-1),
'channel_labels': stats.mode(channel_ok, axis=1)[0],
'ap_freqs': dsp.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / fs, one_sided=True),
'ap_power': psds.T / len(t0s), # shape: (nfreqs, nchannels)
}
for k in files:
np.save(files[k], results[k])
qc_files.extend([files[k] for k in files])
for p in [10, 90]:
self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(np.percentile(median_rms[0, :], p),
precision=2)
self.metrics[f'apRms_p{p}_proc'] = np.format_float_scientific(np.percentile(median_rms[1, :], p),
precision=2)
self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(
np.percentile(results['rms'][0, :], p), precision=2)
self.metrics[f'apRms_p{p}_proc'] = np.format_float_scientific(
np.percentile(results['rms'][1, :], p), precision=2)
if update:
self.update_extended_qc(self.metrics)
# self.update(outcome)

# If lf meta and bin file present, run the old qc on LF data
if self.data.lf_meta and self.data.lf:
qc_files.extend(extract_rmsmap(self.data.lf, out_folder=self.probe_path, overwrite=overwrite))
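
The extended AP-band QC above accumulates a Welch power spectral density per channel over 1-second snippets and keeps the modal bad-channel label across snippets. A standalone sketch of that aggregation, with placeholder sizes and random data standing in for the destriped snippets and for dsp.voltage.detect_bad_channels (note 'hann' is the current SciPy name for the 'hanning' window used in the diff):

import numpy as np
from scipy import signal, stats

fs = 30000                          # assumed AP sample rate
nperseg = 1024                      # placeholder for WELCH_WIN_LENGTH_SAMPLES
n_channels, n_snippets = 384, 10

rng = np.random.default_rng(0)
psds = np.zeros((n_channels, nperseg // 2 + 1))
labels = np.zeros((n_channels, n_snippets))
for i in range(n_snippets):
    snippet = rng.standard_normal((n_channels, fs))        # stand-in for one destriped 1-s chunk
    _, psd = signal.welch(snippet, fs=fs, window='hann', nperseg=nperseg,
                          detrend='constant', return_onesided=True,
                          scaling='density', axis=-1)
    psds += psd
    labels[:, i] = rng.integers(0, 2, n_channels)           # stand-in for detect_bad_channels labels

ap_power = psds.T / n_snippets                  # (n_freqs, n_channels), averaged across snippets
channel_labels = stats.mode(labels, axis=1)[0]  # modal label per channel (output shape varies with SciPy version)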
6 changes: 4 additions & 2 deletions ibllib/ephys/spikes.py
@@ -143,6 +143,8 @@ def _sr(ap_file):
out_files.extend([f for f in out_path.glob("*.*") if
f.name.startswith(('channels.', 'drift', 'clusters.', 'spikes.', 'templates.',
'_kilosort_', '_phy_spikes_subset', '_ibl_log.info'))])
# the QC files computed during spike sorting stay within the raw ephys data folder
out_files.extend(list(ap_file.parent.glob('_iblqc_*AP.*.npy')))
return out_files, 0


@@ -159,7 +161,7 @@ def ks2_to_alf(ks_path, bin_path, out_path, bin_file=None, ampfactor=1, label=No
ac.convert(out_path, label=label, force=force, ampfactor=ampfactor)


def ks2_to_tar(ks_path, out_path):
def ks2_to_tar(ks_path, out_path, force=False):
"""
Compress output from kilosort 2 into tar file in order to register to flatiron and move to
spikesorters/ks2_matlab/probexx path. Output file to register
Expand Down Expand Up @@ -199,7 +201,7 @@ def ks2_to_tar(ks_path, out_path):
'whitening_mat_inv.npy']

out_file = Path(out_path).joinpath('_kilosort_raw.output.tar')
if out_file.exists():
if out_file.exists() and not force:
_logger.info(f"Already converted ks2 to tar: for {ks_path}, skipping.")
return [out_file]
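
The new force flag above means an existing _kilosort_raw.output.tar is reused unless the caller explicitly asks to rebuild it. A minimal sketch of that compress-or-skip pattern, assuming a plain uncompressed tar as the file name suggests (function name and glob patterns are illustrative):

import tarfile
from pathlib import Path

def tar_outputs(src_dir, out_file, patterns=('*.npy', 'params.py'), force=False):
    out_file = Path(out_file)
    if out_file.exists() and not force:
        return [out_file]                      # already converted, skip unless forced
    with tarfile.open(out_file, 'w') as tar:   # 'w' writes an uncompressed tar
        for pattern in patterns:
            for f in sorted(Path(src_dir).glob(pattern)):
                tar.add(f, arcname=f.name)
    return [out_file]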

3 changes: 2 additions & 1 deletion ibllib/ephys/sync_probes.py
@@ -140,7 +140,8 @@ def version3B(ses_path, display=True, type=None, tol=2.5):
sync_probe = get_sync_fronts(ef.sync, ef.sync_map['imec_sync'])
sr = _get_sr(ef)
try:
assert(sync_nidq.times.size == sync_probe.times.size)
# we say that the number of pulses should be within 10 %
assert(np.isclose(sync_nidq.times.size, sync_probe.times.size, rtol=0.1))
except AssertionError:
raise Neuropixel3BSyncFrontsNonMatching(f"{ses_path}")
# if the qc of the diff finds anomalies, do not attempt to smooth the interp function
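
The exact-equality assert is relaxed to np.isclose with rtol=0.1, so the nidq and probe pulse counts may differ by up to 10% of the probe count before Neuropixel3BSyncFrontsNonMatching is raised. With hypothetical counts:

import numpy as np

assert np.isclose(1000, 950, rtol=0.1)        # |1000 - 950| <= 0.1 * 950 -> tolerated
assert not np.isclose(1000, 880, rtol=0.1)    # |1000 - 880| >  0.1 * 880 -> sync error raised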
4 changes: 4 additions & 0 deletions ibllib/oneibl/aws.py
@@ -32,6 +32,7 @@ def __init__(self, s3_bucket_name=None, one=None):

def _download_datasets(self, datasets):

files = []
for _, d in datasets.iterrows():
rel_file_path = Path(d['session_path']).joinpath(d['rel_path'])
file_path = Path(self.one.cache_dir).joinpath(rel_file_path)
@@ -54,5 +55,8 @@ def _download_datasets(self, datasets):
_logger.info(f'Downloading {aws_path} to {file_path}')
self.bucket.download_file(aws_path, file_path.as_posix())
_logger.debug(f'Complete. Time elapsed {time() - ts} for {file_path}')
files.append(file_path)
else:
_logger.warning(f'{aws_path} not found on s3 bucket: {self.bucket.name}')

return files
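
_download_datasets now returns the local paths it actually fetched and warns when a key is missing from the bucket. A simplified standalone sketch of that pattern with boto3, with the session/dataset bookkeeping reduced to a list of keys (names are illustrative):

import logging
from pathlib import Path
import boto3

_logger = logging.getLogger(__name__)

def download_keys(bucket_name, keys, cache_dir):
    bucket = boto3.resource('s3').Bucket(bucket_name)
    files = []
    for key in keys:
        local = Path(cache_dir).joinpath(key)
        local.parent.mkdir(parents=True, exist_ok=True)
        if any(bucket.objects.filter(Prefix=key)):         # does the object exist on s3?
            bucket.download_file(key, local.as_posix())
            files.append(local)
        else:
            _logger.warning(f'{key} not found on s3 bucket: {bucket.name}')
    return files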
99 changes: 87 additions & 12 deletions ibllib/oneibl/data_handlers.py
@@ -7,8 +7,9 @@
import abc
from time import time

from one.api import ONE
from one.util import filter_datasets
from one.alf.files import add_uuid_string
from one.alf.files import add_uuid_string, session_path_parts
from iblutil.io.parquet import np2str
from ibllib.oneibl.registration import register_dataset
from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH
@@ -36,15 +37,17 @@ def setUp(self):
"""
pass

def getData(self):
def getData(self, one=None):
"""
Finds the datasets required for task based on input signatures
:return:
"""
if self.one is None:
if self.one is None and one is None:
return
session_datasets = self.one.list_datasets(self.one.path2eid(self.session_path), details=True)
df = pd.DataFrame(columns=self.one._cache.datasets.columns)

one = one or self.one
session_datasets = one.list_datasets(one.path2eid(self.session_path), details=True)
df = pd.DataFrame(columns=one._cache.datasets.columns)
for file in self.signature['input_files']:
df = df.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
wildcards=True, assert_unique=False))
@@ -131,14 +134,21 @@ def __init__(self, session_path, signatures, one=None):
else:
self.lab = labs[0]

self.globus.add_endpoint(f'flatiron_{self.lab}')
# For cortex lab we need to get the endpoint from the ibl alyx
if self.lab == 'cortexlab':
self.globus.add_endpoint(f'flatiron_{self.lab}', one=ONE(base_url='https://alyx.internationalbrainlab.org'))
else:
self.globus.add_endpoint(f'flatiron_{self.lab}')

def setUp(self):
"""
Function to download necessary data to run tasks using globus-sdk
:return:
"""
df = super().getData()
if self.lab == 'cortexlab':
df = super().getData(one=ONE(base_url='https://alyx.internationalbrainlab.org'))
else:
df = super().getData()

if len(df) == 0:
# If no datasets are found in the cache, work only off the local file system and do not attempt to download any missing data
@@ -225,24 +235,29 @@ def uploadData(self, outputs, version, **kwargs):


class RemoteAwsDataHandler(DataHandler):
def __init__(self, session_path, signature, one=None):
def __init__(self, task, session_path, signature, one=None):
"""
Data handler for running tasks on remote compute node. Will download missing data from private ibl s3 AWS data bucket
:param session_path: path to session
:param signature: input and output file signatures
:param one: ONE instance
"""
from one.globus import Globus # noqa
super().__init__(session_path, signature, one=one)
self.task = task
self.aws = AWS(one=self.one)
self.globus = Globus(client_name='server')
self.lab = session_path_parts(self.session_path, as_dict=True)['lab']
self.globus.add_endpoint(f'flatiron_{self.lab}')

def setUp(self):
"""
Function to download necessary data to run tasks using AWS boto3
:return:
"""
df = super().getData()
self.aws._download_datasets(df)
self.local_paths = self.aws._download_datasets(df)

def uploadData(self, outputs, version, **kwargs):
"""
@@ -251,10 +266,70 @@ def uploadData(self, outputs, version, **kwargs):
:param version: ibllib version
:return: output info of registered datasets
"""

# register datasets
versions = super().uploadData(outputs, version)
ftp_patcher = FTPPatcher(one=self.one)
return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
versions=versions, **kwargs)
response = register_dataset(outputs, one=self.one, server_only=True, versions=versions, **kwargs)

# upload directly via globus
source_paths = []
target_paths = []
collections = {}

for dset, out in zip(response, outputs):
assert (Path(out).name == dset['name'])
# set flag to false
fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository'])
collection = '/'.join(fr['relative_path'].split('/')[:-1])
if collection in collections.keys():
collections[collection].update({f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}})
else:
collections[collection] = {f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}}

# Set all exists status to false for server file records
self.one.alyx.rest('files', 'partial_update', id=fr['id'], data={'exists': False})

source_paths.append(out)
target_paths.append(add_uuid_string(fr['relative_path'], dset['id']))

if len(target_paths) != 0:
ts = time()
for sp, tp in zip(source_paths, target_paths):
_logger.info(f'Uploading {sp} to {tp}')
self.globus.mv('local', f'flatiron_{self.lab}', source_paths, target_paths)
_logger.debug(f'Complete. Time elapsed {time() - ts}')

for collection, files in collections.items():
globus_files = self.globus.ls(f'flatiron_{self.lab}', collection, remove_uuid=True, return_size=True)
file_names = [gl[0] for gl in globus_files]
file_sizes = [gl[1] for gl in globus_files]

for name, details in files.items():
try:
idx = file_names.index(name)
size = file_sizes[idx]
if size == details['size']:
# update the file record if sizes match
self.one.alyx.rest('files', 'partial_update', id=details['fr_id'], data={'exists': True})
else:
_logger.warning(f'File {name} found on SDSC but sizes do not match')
except ValueError:
_logger.warning(f'File {name} not found on SDSC')

return response

# ftp_patcher = FTPPatcher(one=self.one)
# return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
# versions=versions, **kwargs)

def cleanUp(self):
"""
Clean up, remove the files that were downloaded from globus once task has completed
:return:
"""
if self.task.status == 0:
for file in self.local_paths:
os.unlink(file)


class RemoteGlobusDataHandler(DataHandler):
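
After the Globus transfer, uploadData cross-checks every registered dataset against a remote directory listing and only flips the Alyx file record's 'exists' flag back to True when the sizes match. That bookkeeping can be sketched independently of the ONE/Globus APIs; the inputs below mirror what the registration response and globus.ls(..., return_size=True) provide, and the function name is illustrative:

import logging

_logger = logging.getLogger(__name__)

def confirm_uploads(expected, listed):
    # expected: {file name: {'fr_id': Alyx file-record id, 'size': bytes}} from the registration response
    # listed:   [(file name, size), ...] from the remote directory listing
    names = [name for name, _ in listed]
    sizes = [size for _, size in listed]
    confirmed = []
    for name, details in expected.items():
        try:
            idx = names.index(name)
        except ValueError:
            _logger.warning(f'File {name} not found on SDSC')
            continue
        if sizes[idx] == details['size']:
            confirmed.append(details['fr_id'])     # safe to PATCH exists=True on this file record
        else:
            _logger.warning(f'File {name} found on SDSC but sizes do not match')
    return confirmed

# confirm_uploads({'spikes.times.npy': {'fr_id': 'abc', 'size': 120}}, [('spikes.times.npy', 120)]) -> ['abc']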
[Diffs for the remaining 15 changed files are not rendered]
