diff --git a/CHANGELOG.md b/CHANGELOG.md index 04a12c51..ff769b62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.1.0b4] - 2021-11-29 +### Added ++ Processing with Kilosort and pyKilosort for Open Ephys and SpikeGLX + + ## [0.1.0b0] - 2021-05-07 ### Added + First beta release diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index b11cbf65..3c389614 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1,2 +1,14 @@ -# ephys_acute as default -import element_array_ephys.ephys_acute as ephys \ No newline at end of file +import datajoint as dj +import logging +import os + + +dj.config['enable_python_native_blobs'] = True + + +def get_logger(name): + log = logging.getLogger(name) + log.setLevel(os.getenv('LOGLEVEL', 'INFO')) + return log + +from . import ephys_acute as ephys diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py index 295279fb..df7ebd73 100644 --- a/element_array_ephys/ephys_acute.py +++ b/element_array_ephys/ephys_acute.py @@ -4,10 +4,16 @@ import numpy as np import inspect import importlib +import gc +from decimal import Decimal + from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid from .readers import spikeglx, kilosort, openephys -from . import probe +from . import probe, get_logger + + +log = get_logger(__name__) schema = dj.schema() @@ -26,7 +32,7 @@ def activate(ephys_schema_name, probe_schema_name=None, *, create_schema=True, :param linking_module: a module name or a module containing the required dependencies to activate the `ephys` element: Upstream tables: - + Session: parent table to ProbeInsertion, typically identifying a recording session + + Session: table referenced by EphysRecording, typically identifying a recording session + SkullReference: Reference table for InsertionLocation, specifying the skull reference used for probe insertion location (e.g. 
Bregma, Lambda) Functions: @@ -37,6 +43,9 @@ def activate(ephys_schema_name, probe_schema_name=None, *, create_schema=True, Retrieve the session directory containing the recorded Neuropixels data for a given Session :param session_key: a dictionary of one Session `key` :return: a string for full path to the session directory + + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data """ if isinstance(linking_module, str): @@ -70,7 +79,14 @@ def get_ephys_root_data_dir() -> list: :return: a string for full path to the ephys root data directory, or list of strings for possible root data directories """ - return _linking_module.get_ephys_root_data_dir() + root_directories = _linking_module.get_ephys_root_data_dir() + if isinstance(root_directories, (str, pathlib.Path)): + root_directories = [root_directories] + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + root_directories.append(_linking_module.get_processed_root_data_dir()) + + return root_directories def get_session_directory(session_key: dict) -> str: @@ -84,6 +100,18 @@ def get_session_directory(session_key: dict) -> str: return _linking_module.get_session_directory(session_key) +def get_processed_root_data_dir() -> str: + """ + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data + """ + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + return _linking_module.get_processed_root_data_dir() + else: + return get_ephys_root_data_dir()[0] + # ----------------------------- Table declarations ---------------------- @@ -105,6 +133,63 @@ class ProbeInsertion(dj.Manual): -> probe.Probe """ + @classmethod + def auto_generate_entries(cls, session_key): + """ + Method to auto-generate ProbeInsertion entries for a particular session + Probe information is inferred from the meta data found in the session data directory + """ + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(session_key)) + # search session dir and determine acquisition software + for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], + ['SpikeGLX', 'Open Ephys']): + ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) + if ephys_meta_filepaths: + acq_software = ephys_acq_type + break + else: + raise FileNotFoundError( + f'Ephys recording data not found!' 
+ f' Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}') + + probe_list, probe_insertion_list = [], [] + if acq_software == 'SpikeGLX': + for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths): + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + + probe_key = {'probe_type': spikeglx_meta.probe_model, + 'probe': spikeglx_meta.probe_SN} + if (probe_key['probe'] not in [p['probe'] for p in probe_list] + and probe_key not in probe.Probe()): + probe_list.append(probe_key) + + probe_dir = meta_filepath.parent + try: + probe_number = re.search('(imec)?\d{1}$', probe_dir.name).group() + probe_number = int(probe_number.replace('imec', '')) + except AttributeError: + probe_number = meta_fp_idx + + probe_insertion_list.append({**session_key, + 'probe': spikeglx_meta.probe_SN, + 'insertion_number': int(probe_number)}) + elif acq_software == 'Open Ephys': + loaded_oe = openephys.OpenEphys(session_dir) + for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()): + probe_key = {'probe_type': oe_probe.probe_model, 'probe': oe_probe.probe_SN} + if (probe_key['probe'] not in [p['probe'] for p in probe_list] + and probe_key not in probe.Probe()): + probe_list.append(probe_key) + probe_insertion_list.append({**session_key, + 'probe': oe_probe.probe_SN, + 'insertion_number': probe_idx}) + else: + raise NotImplementedError(f'Unknown acquisition software: {acq_software}') + + probe.Probe.insert(probe_list) + cls.insert(probe_insertion_list, skip_duplicates=True) + @schema class InsertionLocation(dj.Manual): @@ -130,7 +215,7 @@ class EphysRecording(dj.Imported): --- -> probe.ElectrodeConfig -> AcquisitionSoftware - sampling_rate: float # (Hz) + sampling_rate: float # (Hz) recording_datetime: datetime # datetime of the recording from this probe recording_duration: float # (seconds) duration of the recording from this probe """ @@ -143,16 +228,14 @@ class EphysFile(dj.Part): """ def make(self, key): - - session_dir = find_full_path(get_ephys_root_data_dir(), - get_session_directory(key)) - + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1('probe') # search session dir and determine acquisition software for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], ['SpikeGLX', 'Open Ephys']): - ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)] + ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) if ephys_meta_filepaths: acq_software = ephys_acq_type break @@ -162,6 +245,8 @@ def make(self, key): f' Neither SpikeGLX nor Open Ephys recording files found' f' in {session_dir}') + supported_probe_types = probe.ProbeType.fetch('probe_type') + if acq_software == 'SpikeGLX': for meta_filepath in ephys_meta_filepaths: spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) @@ -171,7 +256,7 @@ def make(self, key): raise FileNotFoundError( 'No SpikeGLX data found for probe insertion: {}'.format(key)) - if re.search('(1.0|2.0)', spikeglx_meta.probe_model): + if spikeglx_meta.probe_model in supported_probe_types: probe_type = spikeglx_meta.probe_model electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} @@ -188,13 +273,14 @@ def make(self, key): 'Processing for neuropixels probe model' ' {} not yet implemented'.format(spikeglx_meta.probe_model)) - self.insert1({**key, - **generate_electrode_config(probe_type, electrode_group_members), - 'acq_software': acq_software, - 'sampling_rate': spikeglx_meta.meta['imSampRate'], - 
'recording_datetime': spikeglx_meta.recording_time, - 'recording_duration': (spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath))}) + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': spikeglx_meta.meta['imSampRate'], + 'recording_datetime': spikeglx_meta.recording_time, + 'recording_duration': (spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath))}) root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) @@ -210,7 +296,10 @@ def make(self, key): raise FileNotFoundError( 'No Open Ephys data found for probe insertion: {}'.format(key)) - if re.search('(1.0|2.0)', probe_data.probe_model): + if not probe_data.ap_meta: + raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory') + + if probe_data.probe_model in supported_probe_types: probe_type = probe_data.probe_model electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} @@ -219,24 +308,29 @@ def make(self, key): electrode_group_members = [ probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta['channels_ids']] + for channel_idx in probe_data.ap_meta['channels_indices']] else: raise NotImplementedError( 'Processing for neuropixels' ' probe model {} not yet implemented'.format(probe_data.probe_model)) - self.insert1({**key, - **generate_electrode_config(probe_type, electrode_group_members), - 'acq_software': acq_software, - 'sampling_rate': probe_data.ap_meta['sample_rate'], - 'recording_datetime': probe_data.recording_info['recording_datetimes'][0], - 'recording_duration': np.sum(probe_data.recording_info['recording_durations'])}) + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': probe_data.ap_meta['sample_rate'], + 'recording_datetime': probe_data.recording_info['recording_datetimes'][0], + 'recording_duration': np.sum(probe_data.recording_info['recording_durations'])}) root_dir = find_root_directory(get_ephys_root_data_dir(), probe_data.recording_info['recording_files'][0]) self.EphysFile.insert([{**key, 'file_path': fp.relative_to(root_dir).as_posix()} for fp in probe_data.recording_info['recording_files']]) + # explicitly garbage collect "dataset" + # as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() else: raise NotImplementedError(f'Processing ephys files from' f' acquisition software of type {acq_software} is' @@ -267,8 +361,7 @@ class Electrode(dj.Part): _skip_channel_counts = 9 def make(self, key): - acq_software, probe_sn = (EphysRecording - * ProbeInsertion & key).fetch1('acq_software', 'probe') + acq_software = (EphysRecording * ProbeInsertion & key).fetch1('acq_software') electrode_keys, lfp = [], [] @@ -301,14 +394,10 @@ def make(self, key): shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap['data'][recorded_site] electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) elif acq_software == 'Open Ephys': - session_dir = find_full_path(get_ephys_root_data_dir(), - get_session_directory(key)) - - loaded_oe = openephys.OpenEphys(session_dir) - oe_probe = loaded_oe.probes[probe_sn] + oe_probe = get_openephys_probe_data(key) - lfp_channel_ind = np.arange( - len(oe_probe.lfp_meta['channels_ids']))[-1::-self._skip_channel_counts] + lfp_channel_ind = np.r_[ + 
len(oe_probe.lfp_meta['channels_indices'])-1:0:-self._skip_channel_counts] lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) lfp = (lfp * np.array(oe_probe.lfp_meta['channels_gains'])[lfp_channel_ind]).T # (channel x sample) @@ -325,8 +414,8 @@ def make(self, key): probe_electrodes = {key['electrode']: key for key in electrode_query.fetch('KEY')} - for channel_idx in np.array(oe_probe.lfp_meta['channels_ids'])[lfp_channel_ind]: - electrode_keys.append(probe_electrodes[channel_idx]) + electrode_keys.extend(probe_electrodes[channel_idx] + for channel_idx in lfp_channel_ind) else: raise NotImplementedError(f'LFP extraction from acquisition software' f' of type {acq_software} is not yet implemented') @@ -347,8 +436,9 @@ class ClusteringMethod(dj.Lookup): clustering_method_desc: varchar(1000) """ - contents = [('kilosort', 'kilosort clustering method'), - ('kilosort2', 'kilosort2 clustering method')] + contents = [('kilosort2', 'kilosort2 clustering method'), + ('kilosort2.5', 'kilosort2.5 clustering method'), + ('kilosort3', 'kilosort3 clustering method')] @schema @@ -365,13 +455,18 @@ class ClusteringParamSet(dj.Lookup): """ @classmethod - def insert_new_params(cls, processing_method: str, paramset_idx: int, - paramset_desc: str, params: dict): - param_dict = {'clustering_method': processing_method, + def insert_new_params(cls, clustering_method: str, paramset_desc: str, + params: dict, paramset_idx: int = None): + if paramset_idx is None: + paramset_idx = (dj.U().aggr(cls, n='max(paramset_idx)').fetch1('n') or 0) + 1 + + param_dict = {'clustering_method': clustering_method, 'paramset_idx': paramset_idx, 'paramset_desc': paramset_desc, 'params': params, - 'param_set_hash': dict_to_uuid(params)} + 'param_set_hash': dict_to_uuid( + {**params, 'clustering_method': clustering_method}) + } param_query = cls & {'param_set_hash': param_dict['param_set_hash']} if param_query: # If the specified param-set already exists @@ -380,9 +475,13 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, return else: # If not same name: human error, trying to add the same paramset with different name raise dj.DataJointError( - 'The specified param-set' - ' already exists - paramset_idx: {}'.format(existing_paramset_idx)) + f'The specified param-set already exists' + f' - with paramset_idx: {existing_paramset_idx}') else: + if {'paramset_idx': paramset_idx} in cls.proj(): + raise dj.DataJointError( + f'The specified paramset_idx {paramset_idx} already exists,' + f' please pick a different one.') cls.insert1(param_dict) @@ -390,7 +489,7 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, class ClusterQualityLabel(dj.Lookup): definition = """ # Quality - cluster_quality_label: varchar(100) + cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. 
--- cluster_quality_description: varchar(4000) """ @@ -409,10 +508,62 @@ class ClusteringTask(dj.Manual): -> EphysRecording -> ClusteringParamSet --- - clustering_output_dir: varchar(255) # clustering output directory relative to the clustering root data directory + clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation """ + @classmethod + def infer_output_dir(cls, key, relative=False, mkdir=False): + """ + Given a 'key' to an entry in this table + Return the expected clustering_output_dir based on the following convention: + processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 + """ + processed_dir = pathlib.Path(get_processed_root_data_dir()) + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) + root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir) + + method = (ClusteringParamSet * ClusteringMethod & key).fetch1( + 'clustering_method').replace(".", "-") + + output_dir = (processed_dir + / session_dir.relative_to(root_dir) + / f'probe_{key["insertion_number"]}' + / f'{method}_{key["paramset_idx"]}') + + if mkdir: + output_dir.mkdir(parents=True, exist_ok=True) + log.info(f'{output_dir} created!') + + return output_dir.relative_to(processed_dir) if relative else output_dir + + @classmethod + def auto_generate_entries(cls, ephys_recording_key, paramset_idx=0): + """ + Method to auto-generate ClusteringTask entries for a particular ephys recording + Output directory is auto-generated based on the convention + defined in `ClusteringTask.infer_output_dir()` + Default parameter set used: paramset_idx = 0 + """ + key = {**ephys_recording_key, 'paramset_idx': paramset_idx} + + processed_dir = get_processed_root_data_dir() + output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True) + + try: + kilosort.Kilosort(output_dir) # check if the directory is a valid Kilosort output + except FileNotFoundError: + task_mode = 'trigger' + else: + task_mode = 'load' + + cls.insert1({ + **key, + 'clustering_output_dir': output_dir.relative_to(processed_dir).as_posix(), + 'task_mode': task_mode}) + @schema class Clustering(dj.Imported): @@ -433,17 +584,85 @@ class Clustering(dj.Imported): def make(self, key): task_mode, output_dir = (ClusteringTask & key).fetch1( 'task_mode', 'clustering_output_dir') + + if not output_dir: + output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + # update clustering_output_dir + ClusteringTask.update1({**key, 'clustering_output_dir': output_dir.as_posix()}) + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) if task_mode == 'load': - kilosort_dataset = kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) + kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output elif task_mode == 'trigger': - raise NotImplementedError('Automatic triggering of' - ' clustering analysis is not yet supported') + acq_software, clustering_method, params = (ClusteringTask * EphysRecording + * ClusteringParamSet & key).fetch1( + 'acq_software', 'clustering_method', 'params') + + if 'kilosort' in clustering_method: + from element_array_ephys.readers import kilosort_triggering + + # add additional probe-recording 
and channels details into `params` + params = {**params, **get_recording_channels_details(key)} + params['fs'] = params['sample_rate'] + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file('ap') + + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=spikeglx_recording.root_dir / ( + spikeglx_recording.root_name + '.ap.bin'), + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True) + run_kilosort.run_modules() + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(key) + + assert len(oe_probe.recording_info['recording_files']) == 1 + + # run kilosort + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=pathlib.Path(oe_probe.recording_info['recording_files'][0]) / 'continuous.dat', + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info['recording_files'][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}') + run_kilosort.run_modules() + else: + raise NotImplementedError(f'Automatic triggering of {clustering_method}' + f' clustering analysis is not yet supported') + else: raise ValueError(f'Unknown task mode: {task_mode}') + creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) self.insert1({**key, 'clustering_time': creation_time}) @@ -511,7 +730,10 @@ def make(self, key): kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software = (EphysRecording & key).fetch1('acq_software') + acq_software, sample_rate = (EphysRecording & key).fetch1( + 'acq_software', 'sampling_rate') + + sample_rate = kilosort_dataset.data['params'].get('sample_rate', sample_rate) # ---------- Unit ---------- # -- Remove 0-spike units @@ -541,7 +763,7 @@ def make(self, key): if (kilosort_dataset.data['spike_clusters'] == unit).any(): unit_channel, _ = kilosort_dataset.get_best_channel(unit) unit_spike_times = (spike_times[kilosort_dataset.data['spike_clusters'] == unit] - / kilosort_dataset.data['params']['sample_rate']) + / sample_rate) spike_count = len(unit_spike_times) units.append({ @@ -658,8 +880,10 @@ def yield_unit_waveforms(): # insert waveform on a per-unit basis to mitigate potential memory issue self.insert1(key) for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) + if unit_peak_waveform: + 
self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) + if unit_electrode_waveforms: + self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) # ---------------- HELPER FUNCTIONS ---------------- @@ -694,6 +918,22 @@ def get_spikeglx_meta_filepath(ephys_recording_key): return spikeglx_meta_filepath +def get_openephys_probe_data(ephys_recording_key): + inserted_probe_serial_number = (ProbeInsertion * probe.Probe + & ephys_recording_key).fetch1('probe') + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(ephys_recording_key)) + loaded_oe = openephys.OpenEphys(session_dir) + probe_data = loaded_oe.probes[inserted_probe_serial_number] + + # explicitly garbage collect "loaded_oe" + # as these may have large memory footprint and may not be cleared fast enough + del loaded_oe + gc.collect() + + return probe_data + + def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): if acq_software == 'SpikeGLX': spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) @@ -714,11 +954,7 @@ def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): for recorded_site, (shank, shank_col, shank_row, _) in enumerate( spikeglx_meta.shankmap['data'])} elif acq_software == 'Open Ephys': - session_dir = find_full_path(get_ephys_root_data_dir(), - get_session_directory(ephys_recording_key)) - openephys_dataset = openephys.OpenEphys(session_dir) - probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1('probe') - probe_dataset = openephys_dataset.probes[probe_serial_number] + probe_dataset = get_openephys_probe_data(ephys_recording_key) electrode_query = (probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode @@ -729,7 +965,7 @@ def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): channel2electrode_map = { channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta['channels_ids']} + for channel_idx in probe_dataset.ap_meta['channels_indices']} return channel2electrode_map @@ -763,3 +999,39 @@ def generate_electrode_config(probe_type: str, electrodes: list): return electrode_config_key + +def get_recording_channels_details(ephys_recording_key): + channels_details = {} + + acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1('acq_software', + 'sampling_rate') + + probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1('probe_type') + channels_details['probe_type'] = {'neuropixels 1.0 - 3A': '3A', + 'neuropixels 1.0 - 3B': 'NP1', + 'neuropixels UHD': 'NP1100', + 'neuropixels 2.0 - SS': 'NP21', + 'neuropixels 2.0 - MS': 'NP24'}[probe_type] + + electrode_config_key = (probe.ElectrodeConfig * EphysRecording & ephys_recording_key).fetch1('KEY') + channels_details['channel_ind'], channels_details['x_coords'], channels_details[ + 'y_coords'], channels_details['shank_ind'] = ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key).fetch('electrode', 'x_coord', 'y_coord', 'shank') + channels_details['sample_rate'] = sample_rate + channels_details['num_channels'] = len(channels_details['channel_ind']) + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + channels_details['uVPerBit'] = spikeglx_recording.get_channel_bit_volts('ap')[0] + channels_details['connected'] = np.array( + [v for *_, v in spikeglx_recording.apmeta.shankmap['data']]) + elif 
acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(ephys_recording_key) + channels_details['uVPerBit'] = oe_probe.ap_meta['channels_gains'][0] + channels_details['connected'] = np.array([ + int(v == 1) for c, v in oe_probe.channels_connected.items() + if c in channels_details['channel_ind']]) + + return channels_details diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py index a0ea2e11..2f6948e6 100644 --- a/element_array_ephys/ephys_chronic.py +++ b/element_array_ephys/ephys_chronic.py @@ -4,10 +4,16 @@ import numpy as np import inspect import importlib +import gc +from decimal import Decimal + from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid from .readers import spikeglx, kilosort, openephys -from . import probe +from . import probe, get_logger + + +log = get_logger(__name__) schema = dj.schema() @@ -38,6 +44,9 @@ def activate(ephys_schema_name, probe_schema_name=None, *, create_schema=True, Retrieve the session directory containing the recorded Neuropixels data for a given Session :param session_key: a dictionary of one Session `key` :return: a string for full path to the session directory + + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data """ if isinstance(linking_module, str): @@ -70,7 +79,14 @@ def get_ephys_root_data_dir() -> list: :return: a string for full path to the ephys root data directory, or list of strings for possible root data directories """ - return _linking_module.get_ephys_root_data_dir() + root_directories = _linking_module.get_ephys_root_data_dir() + if isinstance(root_directories, (str, pathlib.Path)): + root_directories = [root_directories] + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + root_directories.append(_linking_module.get_processed_root_data_dir()) + + return root_directories def get_session_directory(session_key: dict) -> str: @@ -84,6 +100,18 @@ def get_session_directory(session_key: dict) -> str: return _linking_module.get_session_directory(session_key) +def get_processed_root_data_dir() -> str: + """ + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data + """ + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + return _linking_module.get_processed_root_data_dir() + else: + return get_ephys_root_data_dir()[0] + # ----------------------------- Table declarations ---------------------- @@ -153,7 +181,7 @@ def make(self, key): # search session dir and determine acquisition software for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], ['SpikeGLX', 'Open Ephys']): - ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)] + ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) if ephys_meta_filepaths: acq_software = ephys_acq_type break @@ -163,6 +191,8 @@ def make(self, key): f' Neither SpikeGLX nor Open Ephys recording files found' f' in {session_dir}') + supported_probe_types = probe.ProbeType.fetch('probe_type') + if acq_software == 'SpikeGLX': for meta_filepath in ephys_meta_filepaths: spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) @@ -173,7 +203,7 @@ def make(self, key): f'No SpikeGLX data found for probe insertion: {key}' + ' The probe serial number does not match.') - if re.search('(1.0|2.0)', 
spikeglx_meta.probe_model): + if spikeglx_meta.probe_model in supported_probe_types: probe_type = spikeglx_meta.probe_model electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} @@ -190,13 +220,14 @@ def make(self, key): 'Processing for neuropixels probe model' ' {} not yet implemented'.format(spikeglx_meta.probe_model)) - self.insert1({**key, - **generate_electrode_config(probe_type, electrode_group_members), - 'acq_software': acq_software, - 'sampling_rate': spikeglx_meta.meta['imSampRate'], - 'recording_datetime': spikeglx_meta.recording_time, - 'recording_duration': (spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath))}) + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': spikeglx_meta.meta['imSampRate'], + 'recording_datetime': spikeglx_meta.recording_time, + 'recording_duration': (spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath))}) root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) @@ -212,7 +243,10 @@ def make(self, key): raise FileNotFoundError( 'No Open Ephys data found for probe insertion: {}'.format(key)) - if re.search('(1.0|2.0)', probe_data.probe_model): + if not probe_data.ap_meta: + raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory') + + if probe_data.probe_model in supported_probe_types: probe_type = probe_data.probe_model electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} @@ -221,24 +255,29 @@ def make(self, key): electrode_group_members = [ probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta['channels_ids']] + for channel_idx in probe_data.ap_meta['channels_indices']] else: raise NotImplementedError( 'Processing for neuropixels' ' probe model {} not yet implemented'.format(probe_data.probe_model)) - self.insert1({**key, - **generate_electrode_config(probe_type, electrode_group_members), - 'acq_software': acq_software, - 'sampling_rate': probe_data.ap_meta['sample_rate'], - 'recording_datetime': probe_data.recording_info['recording_datetimes'][0], - 'recording_duration': np.sum(probe_data.recording_info['recording_durations'])}) + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': probe_data.ap_meta['sample_rate'], + 'recording_datetime': probe_data.recording_info['recording_datetimes'][0], + 'recording_duration': np.sum(probe_data.recording_info['recording_durations'])}) root_dir = find_root_directory(get_ephys_root_data_dir(), probe_data.recording_info['recording_files'][0]) self.EphysFile.insert([{**key, 'file_path': fp.relative_to(root_dir).as_posix()} for fp in probe_data.recording_info['recording_files']]) + # explicitly garbage collect "dataset" + # as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() else: raise NotImplementedError(f'Processing ephys files from' f' acquisition software of type {acq_software} is' @@ -269,8 +308,7 @@ class Electrode(dj.Part): _skip_channel_counts = 9 def make(self, key): - acq_software, probe_sn = (EphysRecording - * ProbeInsertion & key).fetch1('acq_software', 'probe') + acq_software = (EphysRecording * ProbeInsertion & key).fetch1('acq_software') electrode_keys, lfp = [], [] @@ -303,13 +341,10 @@ def make(self, key): shank, shank_col, shank_row, _ = 
spikeglx_recording.apmeta.shankmap['data'][recorded_site] electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) elif acq_software == 'Open Ephys': - session_dir = find_full_path(get_ephys_root_data_dir(), - get_session_directory(key)) - loaded_oe = openephys.OpenEphys(session_dir) - oe_probe = loaded_oe.probes[probe_sn] + oe_probe = get_openephys_probe_data(key) - lfp_channel_ind = np.arange( - len(oe_probe.lfp_meta['channels_ids']))[-1::-self._skip_channel_counts] + lfp_channel_ind = np.r_[ + len(oe_probe.lfp_meta['channels_indices'])-1:0:-self._skip_channel_counts] lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) lfp = (lfp * np.array(oe_probe.lfp_meta['channels_gains'])[lfp_channel_ind]).T # (channel x sample) @@ -326,8 +361,8 @@ def make(self, key): probe_electrodes = {key['electrode']: key for key in electrode_query.fetch('KEY')} - for channel_idx in np.array(oe_probe.lfp_meta['channels_ids'])[lfp_channel_ind]: - electrode_keys.append(probe_electrodes[channel_idx]) + electrode_keys.extend(probe_electrodes[channel_idx] + for channel_idx in lfp_channel_ind) else: raise NotImplementedError(f'LFP extraction from acquisition software' f' of type {acq_software} is not yet implemented') @@ -348,8 +383,9 @@ class ClusteringMethod(dj.Lookup): clustering_method_desc: varchar(1000) """ - contents = [('kilosort', 'kilosort clustering method'), - ('kilosort2', 'kilosort2 clustering method')] + contents = [('kilosort2', 'kilosort2 clustering method'), + ('kilosort2.5', 'kilosort2.5 clustering method'), + ('kilosort3', 'kilosort3 clustering method')] @schema @@ -366,13 +402,18 @@ class ClusteringParamSet(dj.Lookup): """ @classmethod - def insert_new_params(cls, processing_method: str, paramset_idx: int, - paramset_desc: str, params: dict): - param_dict = {'clustering_method': processing_method, + def insert_new_params(cls, clustering_method: str, paramset_desc: str, + params: dict, paramset_idx: int = None): + if paramset_idx is None: + paramset_idx = (dj.U().aggr(cls, n='max(paramset_idx)').fetch1('n') or 0) + 1 + + param_dict = {'clustering_method': clustering_method, 'paramset_idx': paramset_idx, 'paramset_desc': paramset_desc, 'params': params, - 'param_set_hash': dict_to_uuid(params)} + 'param_set_hash': dict_to_uuid( + {**params, 'clustering_method': clustering_method}) + } param_query = cls & {'param_set_hash': param_dict['param_set_hash']} if param_query: # If the specified param-set already exists @@ -381,9 +422,13 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, return else: # If not same name: human error, trying to add the same paramset with different name raise dj.DataJointError( - 'The specified param-set' - ' already exists - paramset_idx: {}'.format(existing_paramset_idx)) + f'The specified param-set already exists' + f' - with paramset_idx: {existing_paramset_idx}') else: + if {'paramset_idx': paramset_idx} in cls.proj(): + raise dj.DataJointError( + f'The specified paramset_idx {paramset_idx} already exists,' + f' please pick a different one.') cls.insert1(param_dict) @@ -391,7 +436,7 @@ def insert_new_params(cls, processing_method: str, paramset_idx: int, class ClusterQualityLabel(dj.Lookup): definition = """ # Quality - cluster_quality_label: varchar(100) + cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. 
--- cluster_quality_description: varchar(4000) """ @@ -410,10 +455,62 @@ class ClusteringTask(dj.Manual): -> EphysRecording -> ClusteringParamSet --- - clustering_output_dir: varchar(255) # clustering output directory relative to the clustering root data directory + clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation """ + @classmethod + def infer_output_dir(cls, key, relative=False, mkdir=False): + """ + Given a 'key' to an entry in this table + Return the expected clustering_output_dir based on the following convention: + processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 + """ + processed_dir = pathlib.Path(get_processed_root_data_dir()) + sess_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) + root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir) + + method = (ClusteringParamSet * ClusteringMethod & key).fetch1( + 'clustering_method').replace(".", "-") + + output_dir = (processed_dir + / sess_dir.relative_to(root_dir) + / f'probe_{key["insertion_number"]}' + / f'{method}_{key["paramset_idx"]}') + + if mkdir: + output_dir.mkdir(parents=True, exist_ok=True) + log.info(f'{output_dir} created!') + + return output_dir.relative_to(processed_dir) if relative else output_dir + + @classmethod + def auto_generate_entries(cls, ephys_recording_key, paramset_idx=0): + """ + Method to auto-generate ClusteringTask entries for a particular ephys recording + Output directory is auto-generated based on the convention + defined in `ClusteringTask.infer_output_dir()` + Default parameter set used: paramset_idx = 0 + """ + key = {**ephys_recording_key, 'paramset_idx': paramset_idx} + + processed_dir = get_processed_root_data_dir() + output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True) + + try: + kilosort.Kilosort(output_dir) # check if the directory is a valid Kilosort output + except FileNotFoundError: + task_mode = 'trigger' + else: + task_mode = 'load' + + cls.insert1({ + **key, + 'clustering_output_dir': output_dir.relative_to(processed_dir).as_posix(), + 'task_mode': task_mode}) + @schema class Clustering(dj.Imported): @@ -434,17 +531,85 @@ class Clustering(dj.Imported): def make(self, key): task_mode, output_dir = (ClusteringTask & key).fetch1( 'task_mode', 'clustering_output_dir') + + if not output_dir: + output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + # update clustering_output_dir + ClusteringTask.update1({**key, 'clustering_output_dir': output_dir.as_posix()}) + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) if task_mode == 'load': - kilosort_dataset = kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) + kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output elif task_mode == 'trigger': - raise NotImplementedError('Automatic triggering of' - ' clustering analysis is not yet supported') + acq_software, clustering_method, params = (ClusteringTask * EphysRecording + * ClusteringParamSet & key).fetch1( + 'acq_software', 'clustering_method', 'params') + + if 'kilosort' in clustering_method: + from element_array_ephys.readers import kilosort_triggering + + # add additional probe-recording and 
channels details into `params` + params = {**params, **get_recording_channels_details(key)} + params['fs'] = params['sample_rate'] + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file('ap') + + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=spikeglx_recording.root_dir / ( + spikeglx_recording.root_name + '.ap.bin'), + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True) + run_kilosort.run_modules() + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(key) + + assert len(oe_probe.recording_info['recording_files']) == 1 + + # run kilosort + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=pathlib.Path(oe_probe.recording_info['recording_files'][0]) / 'continuous.dat', + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info['recording_files'][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}') + run_kilosort.run_modules() + else: + raise NotImplementedError(f'Automatic triggering of {clustering_method}' + f' clustering analysis is not yet supported') + else: raise ValueError(f'Unknown task mode: {task_mode}') + creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) self.insert1({**key, 'clustering_time': creation_time}) @@ -510,7 +675,10 @@ def make(self, key): kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software = (EphysRecording & key).fetch1('acq_software') + acq_software, sample_rate = (EphysRecording & key).fetch1( + 'acq_software', 'sampling_rate') + + sample_rate = kilosort_dataset.data['params'].get('sample_rate', sample_rate) # ---------- Unit ---------- # -- Remove 0-spike units @@ -540,7 +708,7 @@ def make(self, key): if (kilosort_dataset.data['spike_clusters'] == unit).any(): unit_channel, _ = kilosort_dataset.get_best_channel(unit) unit_spike_times = (spike_times[kilosort_dataset.data['spike_clusters'] == unit] - / kilosort_dataset.data['params']['sample_rate']) + / sample_rate) spike_count = len(unit_spike_times) units.append({ @@ -657,8 +825,10 @@ def yield_unit_waveforms(): # insert waveform on a per-unit basis to mitigate potential memory issue self.insert1(key) for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) + if unit_peak_waveform: + 
self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) + if unit_electrode_waveforms: + self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) # ---------------- HELPER FUNCTIONS ---------------- @@ -694,6 +864,22 @@ def get_spikeglx_meta_filepath(ephys_recording_key): return spikeglx_meta_filepath +def get_openephys_probe_data(ephys_recording_key): + inserted_probe_serial_number = (ProbeInsertion * probe.Probe + & ephys_recording_key).fetch1('probe') + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(ephys_recording_key)) + loaded_oe = openephys.OpenEphys(session_dir) + probe_data = loaded_oe.probes[inserted_probe_serial_number] + + # explicitly garbage collect "loaded_oe" + # as these may have large memory footprint and may not be cleared fast enough + del loaded_oe + gc.collect() + + return probe_data + + def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): if acq_software == 'SpikeGLX': spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) @@ -714,11 +900,7 @@ def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): for recorded_site, (shank, shank_col, shank_row, _) in enumerate( spikeglx_meta.shankmap['data'])} elif acq_software == 'Open Ephys': - session_dir = find_full_path(get_ephys_root_data_dir(), - get_session_directory(ephys_recording_key)) - openephys_dataset = openephys.OpenEphys(session_dir) - probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1('probe') - probe_dataset = openephys_dataset.probes[probe_serial_number] + probe_dataset = get_openephys_probe_data(ephys_recording_key) electrode_query = (probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode @@ -729,7 +911,7 @@ def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): channel2electrode_map = { channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta['channels_ids']} + for channel_idx in probe_dataset.ap_meta['channels_indices']} return channel2electrode_map @@ -762,3 +944,40 @@ def generate_electrode_config(probe_type: str, electrodes: list): for electrode in electrodes) return electrode_config_key + + +def get_recording_channels_details(ephys_recording_key): + channels_details = {} + + acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1('acq_software', + 'sampling_rate') + + probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1('probe_type') + channels_details['probe_type'] = {'neuropixels 1.0 - 3A': '3A', + 'neuropixels 1.0 - 3B': 'NP1', + 'neuropixels UHD': 'NP1100', + 'neuropixels 2.0 - SS': 'NP21', + 'neuropixels 2.0 - MS': 'NP24'}[probe_type] + + electrode_config_key = (probe.ElectrodeConfig * EphysRecording & ephys_recording_key).fetch1('KEY') + channels_details['channel_ind'], channels_details['x_coords'], channels_details[ + 'y_coords'], channels_details['shank_ind'] = ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key).fetch('electrode', 'x_coord', 'y_coord', 'shank') + channels_details['sample_rate'] = sample_rate + channels_details['num_channels'] = len(channels_details['channel_ind']) + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + channels_details['uVPerBit'] = spikeglx_recording.get_channel_bit_volts('ap')[0] + channels_details['connected'] = np.array( + [v for *_, v in 
spikeglx_recording.apmeta.shankmap['data']]) + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(ephys_recording_key) + channels_details['uVPerBit'] = oe_probe.ap_meta['channels_gains'][0] + channels_details['connected'] = np.array([ + int(v == 1) for c, v in oe_probe.channels_connected.items() + if c in channels_details['channel_ind']]) + + return channels_details diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py new file mode 100644 index 00000000..f0988917 --- /dev/null +++ b/element_array_ephys/ephys_no_curation.py @@ -0,0 +1,995 @@ +import datajoint as dj +import pathlib +import re +import numpy as np +import inspect +import importlib +import gc +from decimal import Decimal + +from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid + +from .readers import spikeglx, kilosort, openephys +from . import probe, get_logger + + +log = get_logger(__name__) + +schema = dj.schema() + +_linking_module = None + + +def activate(ephys_schema_name, probe_schema_name=None, *, create_schema=True, + create_tables=True, linking_module=None): + """ + activate(ephys_schema_name, probe_schema_name=None, *, create_schema=True, create_tables=True, linking_module=None) + :param ephys_schema_name: schema name on the database server to activate the `ephys` element + :param probe_schema_name: schema name on the database server to activate the `probe` element + - may be omitted if the `probe` element is already activated + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. + :param linking_module: a module name or a module containing the + required dependencies to activate the `ephys` element: + Upstream tables: + + Session: table referenced by EphysRecording, typically identifying a recording session + + SkullReference: Reference table for InsertionLocation, specifying the skull reference + used for probe insertion location (e.g. Bregma, Lambda) + Functions: + + get_ephys_root_data_dir() -> list + Retrieve the root data directory - e.g. containing the raw ephys recording files for all subject/sessions. 
+ :return: a string for full path to the root data directory + + get_session_directory(session_key: dict) -> str + Retrieve the session directory containing the recorded Neuropixels data for a given Session + :param session_key: a dictionary of one Session `key` + :return: a string for full path to the session directory + + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data + """ + + if isinstance(linking_module, str): + linking_module = importlib.import_module(linking_module) + assert inspect.ismodule(linking_module),\ + "The argument 'dependency' must be a module's name or a module" + + global _linking_module + _linking_module = linking_module + + # activate + probe.activate(probe_schema_name, create_schema=create_schema, + create_tables=create_tables) + schema.activate(ephys_schema_name, create_schema=create_schema, + create_tables=create_tables, add_objects=_linking_module.__dict__) + + +# -------------- Functions required by the elements-ephys --------------- + +def get_ephys_root_data_dir() -> list: + """ + All data paths, directories in DataJoint Elements are recommended to be stored as + relative paths, with respect to some user-configured "root" directory, + which varies from machine to machine (e.g. different mounted drive locations) + + get_ephys_root_data_dir() -> list + This user-provided function retrieves the possible root data directories + containing the ephys data for all subjects/sessions + (e.g. acquired SpikeGLX or Open Ephys raw files, + output files from spike sorting routines, etc.) + :return: a string for full path to the ephys root data directory, + or list of strings for possible root data directories + """ + root_directories = _linking_module.get_ephys_root_data_dir() + if isinstance(root_directories, (str, pathlib.Path)): + root_directories = [root_directories] + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + root_directories.append(_linking_module.get_processed_root_data_dir()) + + return root_directories + + +def get_session_directory(session_key: dict) -> str: + """ + get_session_directory(session_key: dict) -> str + Retrieve the session directory containing the + recorded Neuropixels data for a given Session + :param session_key: a dictionary of one Session `key` + :return: a string for full path to the session directory + """ + return _linking_module.get_session_directory(session_key) + + +def get_processed_root_data_dir() -> str: + """ + get_processed_root_data_dir() -> str: + Retrieves the root directory for all processed data to be found from or written to + :return: a string for full path to the root directory for processed data + """ + + if hasattr(_linking_module, 'get_processed_root_data_dir'): + return _linking_module.get_processed_root_data_dir() + else: + return get_ephys_root_data_dir()[0] + +# ----------------------------- Table declarations ---------------------- + + +@schema +class AcquisitionSoftware(dj.Lookup): + definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys + acq_software: varchar(24) + """ + contents = zip(['SpikeGLX', 'Open Ephys']) + + +@schema +class ProbeInsertion(dj.Manual): + definition = """ + # Probe insertion implanted into an animal for a given session. 
+ -> Session + insertion_number: tinyint unsigned + --- + -> probe.Probe + """ + + @classmethod + def auto_generate_entries(cls, session_key): + """ + Method to auto-generate ProbeInsertion entries for a particular session + Probe information is inferred from the meta data found in the session data directory + """ + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(session_key)) + # search session dir and determine acquisition software + for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], + ['SpikeGLX', 'Open Ephys']): + ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) + if ephys_meta_filepaths: + acq_software = ephys_acq_type + break + else: + raise FileNotFoundError( + f'Ephys recording data not found!' + f' Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}') + + probe_list, probe_insertion_list = [], [] + if acq_software == 'SpikeGLX': + for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths): + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + + probe_key = {'probe_type': spikeglx_meta.probe_model, + 'probe': spikeglx_meta.probe_SN} + if (probe_key['probe'] not in [p['probe'] for p in probe_list] + and probe_key not in probe.Probe()): + probe_list.append(probe_key) + + probe_dir = meta_filepath.parent + try: + probe_number = re.search('(imec)?\d{1}$', probe_dir.name).group() + probe_number = int(probe_number.replace('imec', '')) + except AttributeError: + probe_number = meta_fp_idx + + probe_insertion_list.append({**session_key, + 'probe': spikeglx_meta.probe_SN, + 'insertion_number': int(probe_number)}) + elif acq_software == 'Open Ephys': + loaded_oe = openephys.OpenEphys(session_dir) + for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()): + probe_key = {'probe_type': oe_probe.probe_model, 'probe': oe_probe.probe_SN} + if (probe_key['probe'] not in [p['probe'] for p in probe_list] + and probe_key not in probe.Probe()): + probe_list.append(probe_key) + probe_insertion_list.append({**session_key, + 'probe': oe_probe.probe_SN, + 'insertion_number': probe_idx}) + else: + raise NotImplementedError(f'Unknown acquisition software: {acq_software}') + + probe.Probe.insert(probe_list) + cls.insert(probe_insertion_list, skip_duplicates=True) + + +@schema +class InsertionLocation(dj.Manual): + definition = """ + # Brain Location of a given probe insertion. + -> ProbeInsertion + --- + -> SkullReference + ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive + ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive + depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative + theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis + phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis + beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior + """ + + +@schema +class EphysRecording(dj.Imported): + definition = """ + # Ephys recording from a probe insertion for a given session. 
+ -> ProbeInsertion + --- + -> probe.ElectrodeConfig + -> AcquisitionSoftware + sampling_rate: float # (Hz) + recording_datetime: datetime # datetime of the recording from this probe + recording_duration: float # (seconds) duration of the recording from this probe + """ + + class EphysFile(dj.Part): + definition = """ + # Paths of files of a given EphysRecording round. + -> master + file_path: varchar(255) # filepath relative to root data directory + """ + + def make(self, key): + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) + inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1('probe') + + # search session dir and determine acquisition software + for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], + ['SpikeGLX', 'Open Ephys']): + ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) + if ephys_meta_filepaths: + acq_software = ephys_acq_type + break + else: + raise FileNotFoundError( + f'Ephys recording data not found!' + f' Neither SpikeGLX nor Open Ephys recording files found') + + supported_probe_types = probe.ProbeType.fetch('probe_type') + + if acq_software == 'SpikeGLX': + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + 'No SpikeGLX data found for probe insertion: {}'.format(key)) + + if spikeglx_meta.probe_model in supported_probe_types: + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip(*electrode_query.fetch( + 'KEY', 'shank', 'shank_col', 'shank_row'))} + + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap['data']] + else: + raise NotImplementedError( + 'Processing for neuropixels probe model' + ' {} not yet implemented'.format(spikeglx_meta.probe_model)) + + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': spikeglx_meta.meta['imSampRate'], + 'recording_datetime': spikeglx_meta.recording_time, + 'recording_duration': (spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath))}) + + root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + self.EphysFile.insert1({ + **key, + 'file_path': meta_filepath.relative_to(root_dir).as_posix()}) + elif acq_software == 'Open Ephys': + dataset = openephys.OpenEphys(session_dir) + for serial_number, probe_data in dataset.probes.items(): + if str(serial_number) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + 'No Open Ephys data found for probe insertion: {}'.format(key)) + + if not probe_data.ap_meta: + raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory') + + if probe_data.probe_model in supported_probe_types: + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type} + + probe_electrodes = {key['electrode']: key + for key in electrode_query.fetch('KEY')} + + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta['channels_indices']] + else: + raise NotImplementedError( + 'Processing for neuropixels' + ' probe model {} not yet 
implemented'.format(probe_data.probe_model)) + + self.insert1({ + **key, + **generate_electrode_config(probe_type, electrode_group_members), + 'acq_software': acq_software, + 'sampling_rate': probe_data.ap_meta['sample_rate'], + 'recording_datetime': probe_data.recording_info['recording_datetimes'][0], + 'recording_duration': np.sum(probe_data.recording_info['recording_durations'])}) + + root_dir = find_root_directory( + get_ephys_root_data_dir(), + probe_data.recording_info['recording_files'][0]) + self.EphysFile.insert([{**key, + 'file_path': fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info['recording_files']]) + # explicitly garbage collect "dataset" + # as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() + else: + raise NotImplementedError(f'Processing ephys files from' + f' acquisition software of type {acq_software} is' + f' not yet implemented') + + +@schema +class LFP(dj.Imported): + definition = """ + # Acquired local field potential (LFP) from a given Ephys recording. + -> EphysRecording + --- + lfp_sampling_rate: float # (Hz) + lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) + lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) + """ + + class Electrode(dj.Part): + definition = """ + -> master + -> probe.ElectrodeConfig.Electrode + --- + lfp: longblob # (uV) recorded lfp at this electrode + """ + + # Only store LFP for every 9th channel, due to high channel density, + # close-by channels exhibit highly similar LFP + _skip_channel_counts = 9 + + def make(self, key): + acq_software = (EphysRecording * ProbeInsertion & key).fetch1('acq_software') + + electrode_keys, lfp = [], [] + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + + lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ + -1::-self._skip_channel_counts] + + # Extract LFP data at specified channels and convert to uV + lfp = spikeglx_recording.lf_timeseries[:, lfp_channel_ind] # (sample x channel) + lfp = (lfp * spikeglx_recording.get_channel_bit_volts('lf')[lfp_channel_ind]).T # (channel x sample) + + self.insert1(dict(key, + lfp_sampling_rate=spikeglx_recording.lfmeta.meta['imSampRate'], + lfp_time_stamps=(np.arange(lfp.shape[1]) + / spikeglx_recording.lfmeta.meta['imSampRate']), + lfp_mean=lfp.mean(axis=0))) + + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * EphysRecording & key) + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip(*electrode_query.fetch( + 'KEY', 'shank', 'shank_col', 'shank_row'))} + + for recorded_site in lfp_channel_ind: + shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap['data'][recorded_site] + electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(key) + + lfp_channel_ind = np.r_[ + len(oe_probe.lfp_meta['channels_indices'])-1:0:-self._skip_channel_counts] + + lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) + lfp = (lfp * np.array(oe_probe.lfp_meta['channels_gains'])[lfp_channel_ind]).T # (channel x sample) + lfp_timestamps = oe_probe.lfp_timestamps + + self.insert1(dict(key, + lfp_sampling_rate=oe_probe.lfp_meta['sample_rate'], + lfp_time_stamps=lfp_timestamps, + 
lfp_mean=lfp.mean(axis=0))) + + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * EphysRecording & key) + probe_electrodes = {key['electrode']: key + for key in electrode_query.fetch('KEY')} + + electrode_keys.extend(probe_electrodes[channel_idx] + for channel_idx in lfp_channel_ind) + else: + raise NotImplementedError(f'LFP extraction from acquisition software' + f' of type {acq_software} is not yet implemented') + + # single insert in loop to mitigate potential memory issue + for electrode_key, lfp_trace in zip(electrode_keys, lfp): + self.Electrode.insert1({**key, **electrode_key, 'lfp': lfp_trace}) + + +# ------------ Clustering -------------- + +@schema +class ClusteringMethod(dj.Lookup): + definition = """ + # Method for clustering + clustering_method: varchar(16) + --- + clustering_method_desc: varchar(1000) + """ + + contents = [('kilosort2', 'kilosort2 clustering method'), + ('kilosort2.5', 'kilosort2.5 clustering method'), + ('kilosort3', 'kilosort3 clustering method')] + + +@schema +class ClusteringParamSet(dj.Lookup): + definition = """ + # Parameter set to be used in a clustering procedure + paramset_idx: smallint + --- + -> ClusteringMethod + paramset_desc: varchar(128) + param_set_hash: uuid + unique index (param_set_hash) + params: longblob # dictionary of all applicable parameters + """ + + @classmethod + def insert_new_params(cls, clustering_method: str, paramset_desc: str, + params: dict, paramset_idx: int = None): + if paramset_idx is None: + paramset_idx = (dj.U().aggr(cls, n='max(paramset_idx)').fetch1('n') or 0) + 1 + + param_dict = {'clustering_method': clustering_method, + 'paramset_idx': paramset_idx, + 'paramset_desc': paramset_desc, + 'params': params, + 'param_set_hash': dict_to_uuid( + {**params, 'clustering_method': clustering_method}) + } + param_query = cls & {'param_set_hash': param_dict['param_set_hash']} + + if param_query: # If the specified param-set already exists + existing_paramset_idx = param_query.fetch1('paramset_idx') + if existing_paramset_idx == paramset_idx: # If the existing set has the same paramset_idx: job done + return + else: # If not same name: human error, trying to add the same paramset with different name + raise dj.DataJointError( + f'The specified param-set already exists' + f' - with paramset_idx: {existing_paramset_idx}') + else: + if {'paramset_idx': paramset_idx} in cls.proj(): + raise dj.DataJointError( + f'The specified paramset_idx {paramset_idx} already exists,' + f' please pick a different one.') + cls.insert1(param_dict) + + +@schema +class ClusterQualityLabel(dj.Lookup): + definition = """ + # Quality + cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. 
+ --- + cluster_quality_description: varchar(4000) + """ + contents = [ + ('good', 'single unit'), + ('ok', 'probably a single unit, but could be contaminated'), + ('mua', 'multi-unit activity'), + ('noise', 'bad unit') + ] + + +@schema +class ClusteringTask(dj.Manual): + definition = """ + # Manual table for defining a clustering task ready to be run + -> EphysRecording + -> ClusteringParamSet + --- + clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory + task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation + """ + + @classmethod + def infer_output_dir(cls, key, relative=False, mkdir=False): + """ + Given a 'key' to an entry in this table + Return the expected clustering_output_dir based on the following convention: + processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 + """ + processed_dir = pathlib.Path(get_processed_root_data_dir()) + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) + root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir) + + method = (ClusteringParamSet * ClusteringMethod & key).fetch1( + 'clustering_method').replace(".", "-") + + output_dir = (processed_dir + / session_dir.relative_to(root_dir) + / f'probe_{key["insertion_number"]}' + / f'{method}_{key["paramset_idx"]}') + + if mkdir: + output_dir.mkdir(parents=True, exist_ok=True) + log.info(f'{output_dir} created!') + + return output_dir.relative_to(processed_dir) if relative else output_dir + + @classmethod + def auto_generate_entries(cls, ephys_recording_key, paramset_idx=0): + """ + Method to auto-generate ClusteringTask entries for a particular ephys recording + Output directory is auto-generated based on the convention + defined in `ClusteringTask.infer_output_dir()` + Default parameter set used: paramset_idx = 0 + """ + key = {**ephys_recording_key, 'paramset_idx': paramset_idx} + + processed_dir = get_processed_root_data_dir() + output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True) + + try: + kilosort.Kilosort(output_dir) # check if the directory is a valid Kilosort output + except FileNotFoundError: + task_mode = 'trigger' + else: + task_mode = 'load' + + cls.insert1({ + **key, + 'clustering_output_dir': output_dir.relative_to(processed_dir).as_posix(), + 'task_mode': task_mode}) + + +@schema +class Clustering(dj.Imported): + """ + A processing table to handle each ClusteringTask: + + If `task_mode == "trigger"`: trigger clustering analysis + according to the ClusteringParamSet (e.g. 
launch a kilosort job) + + If `task_mode == "load"`: verify output + """ + definition = """ + # Clustering Procedure + -> ClusteringTask + --- + clustering_time: datetime # time of generation of this set of clustering results + package_version='': varchar(16) + """ + + def make(self, key): + task_mode, output_dir = (ClusteringTask & key).fetch1( + 'task_mode', 'clustering_output_dir') + + if not output_dir: + output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + # update clustering_output_dir + ClusteringTask.update1({**key, 'clustering_output_dir': output_dir.as_posix()}) + + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + + if task_mode == 'load': + kilosort.Kilosort(kilosort_dir) # check if the directory is a valid Kilosort output + elif task_mode == 'trigger': + acq_software, clustering_method, params = (ClusteringTask * EphysRecording + * ClusteringParamSet & key).fetch1( + 'acq_software', 'clustering_method', 'params') + + if 'kilosort' in clustering_method: + from element_array_ephys.readers import kilosort_triggering + + # add additional probe-recording and channels details into `params` + params = {**params, **get_recording_channels_details(key)} + params['fs'] = params['sample_rate'] + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file('ap') + + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=spikeglx_recording.root_dir / ( + spikeglx_recording.root_name + '.ap.bin'), + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True) + run_kilosort.run_modules() + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(key) + + assert len(oe_probe.recording_info['recording_files']) == 1 + + # run kilosort + if clustering_method.startswith('pykilosort'): + kilosort_triggering.run_pykilosort( + continuous_file=pathlib.Path(oe_probe.recording_info['recording_files'][0]) / 'continuous.dat', + kilosort_output_directory=kilosort_dir, + channel_ind=params.pop('channel_ind'), + x_coords=params.pop('x_coords'), + y_coords=params.pop('y_coords'), + shank_ind=params.pop('shank_ind'), + connected=params.pop('connected'), + sample_rate=params.pop('sample_rate'), + params=params) + else: + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info['recording_files'][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}') + run_kilosort.run_modules() + else: + raise NotImplementedError(f'Automatic triggering of {clustering_method}' + f' clustering analysis is not yet supported') + + else: + raise ValueError(f'Unknown task mode: {task_mode}') + + creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) + self.insert1({**key, 'clustering_time': creation_time}) + + +@schema +class CuratedClustering(dj.Imported): + definition = """ + # Clustering results of 
the spike sorting step. + -> Clustering + """ + + class Unit(dj.Part): + definition = """ + # Properties of a given unit from a round of clustering (and curation) + -> master + unit: int + --- + -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit + -> ClusterQualityLabel + spike_count: int # how many spikes in this recording for this unit + spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording + spike_sites : longblob # array of electrode associated with each spike + spike_depths : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe + """ + + def make(self, key): + output_dir = (ClusteringTask & key).fetch1('clustering_output_dir') + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + + kilosort_dataset = kilosort.Kilosort(kilosort_dir) + acq_software, sample_rate = (EphysRecording & key).fetch1( + 'acq_software', 'sampling_rate') + + sample_rate = kilosort_dataset.data['params'].get('sample_rate', sample_rate) + + # ---------- Unit ---------- + # -- Remove 0-spike units + withspike_idx = [i for i, u in enumerate(kilosort_dataset.data['cluster_ids']) + if (kilosort_dataset.data['spike_clusters'] == u).any()] + valid_units = kilosort_dataset.data['cluster_ids'][withspike_idx] + valid_unit_labels = kilosort_dataset.data['cluster_groups'][withspike_idx] + # -- Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) + + # -- Spike-times -- + # spike_times_sec_adj > spike_times_sec > spike_times + spike_time_key = ('spike_times_sec_adj' if 'spike_times_sec_adj' in kilosort_dataset.data + else 'spike_times_sec' if 'spike_times_sec' + in kilosort_dataset.data else 'spike_times') + spike_times = kilosort_dataset.data[spike_time_key] + kilosort_dataset.extract_spike_depths() + + # -- Spike-sites and Spike-depths -- + spike_sites = np.array([channel2electrodes[s]['electrode'] + for s in kilosort_dataset.data['spike_sites']]) + spike_depths = kilosort_dataset.data['spike_depths'] + + # -- Insert unit, label, peak-chn + units = [] + for unit, unit_lbl in zip(valid_units, valid_unit_labels): + if (kilosort_dataset.data['spike_clusters'] == unit).any(): + unit_channel, _ = kilosort_dataset.get_best_channel(unit) + unit_spike_times = (spike_times[kilosort_dataset.data['spike_clusters'] == unit] + / sample_rate) + spike_count = len(unit_spike_times) + + units.append({ + 'unit': unit, + 'cluster_quality_label': unit_lbl, + **channel2electrodes[unit_channel], + 'spike_times': unit_spike_times, + 'spike_count': spike_count, + 'spike_sites': spike_sites[kilosort_dataset.data['spike_clusters'] == unit], + 'spike_depths': spike_depths[kilosort_dataset.data['spike_clusters'] == unit]}) + + self.insert1(key) + self.Unit.insert([{**key, **u} for u in units]) + + +@schema +class WaveformSet(dj.Imported): + definition = """ + # A set of spike waveforms for units out of a given CuratedClustering + -> CuratedClustering + """ + + class PeakWaveform(dj.Part): + definition = """ + # Mean waveform across spikes for a given unit at its representative electrode + -> master + -> CuratedClustering.Unit + --- + peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode + """ + + class Waveform(dj.Part): + definition = """ + # Spike waveforms and their mean across spikes for the given unit + -> master + -> CuratedClustering.Unit + -> probe.ElectrodeConfig.Electrode + --- + 
waveform_mean: longblob # (uV) mean waveform across spikes of the given unit + waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit + """ + + def make(self, key): + output_dir = (ClusteringTask & key).fetch1('clustering_output_dir') + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + + kilosort_dataset = kilosort.Kilosort(kilosort_dir) + + acq_software, probe_serial_number = (EphysRecording * ProbeInsertion & key).fetch1( + 'acq_software', 'probe') + + # -- Get channel and electrode-site mapping + recording_key = (EphysRecording & key).fetch1('KEY') + channel2electrodes = get_neuropixels_channel2electrode_map(recording_key, acq_software) + + # Get all units + units = {u['unit']: u for u in (CuratedClustering.Unit & key).fetch( + as_dict=True, order_by='unit')} + + if (kilosort_dir / 'mean_waveforms.npy').exists(): + unit_waveforms = np.load(kilosort_dir / 'mean_waveforms.npy') # unit x channel x sample + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip(kilosort_dataset.data['cluster_ids'], + unit_waveforms): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data['channel_map'], + unit_waveform): + unit_electrode_waveforms.append({ + **units[unit_no], **channel2electrodes[channel], + 'waveform_mean': channel_waveform}) + if channel2electrodes[channel]['electrode'] == units[unit_no]['electrode']: + unit_peak_waveform = { + **units[unit_no], + 'peak_electrode_waveform': channel_waveform} + yield unit_peak_waveform, unit_electrode_waveforms + else: + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + elif acq_software == 'Open Ephys': + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(key)) + openephys_dataset = openephys.OpenEphys(session_dir) + neuropixels_recording = openephys_dataset.probes[probe_serial_number] + + def yield_unit_waveforms(): + for unit_dict in units.values(): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + spikes = unit_dict['spike_times'] + waveforms = neuropixels_recording.extract_spike_waveforms( + spikes, kilosort_dataset.data['channel_map']) # (sample x channel x spike) + waveforms = waveforms.transpose((1, 2, 0)) # (channel x spike x sample) + for channel, channel_waveform in zip( + kilosort_dataset.data['channel_map'], waveforms): + unit_electrode_waveforms.append({ + **unit_dict, **channel2electrodes[channel], + 'waveform_mean': channel_waveform.mean(axis=0), + 'waveforms': channel_waveform}) + if channel2electrodes[channel]['electrode'] == unit_dict['electrode']: + unit_peak_waveform = { + **unit_dict, + 'peak_electrode_waveform': channel_waveform.mean(axis=0)} + + yield unit_peak_waveform, unit_electrode_waveforms + + # insert waveform on a per-unit basis to mitigate potential memory issue + self.insert1(key) + for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): + if unit_peak_waveform: + self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) + if unit_electrode_waveforms: + self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) + + +# ---------------- HELPER FUNCTIONS ---------------- + +def get_spikeglx_meta_filepath(ephys_recording_key): + # attempt to retrieve from EphysRecording.EphysFile + spikeglx_meta_filepath = (EphysRecording.EphysFile & 
ephys_recording_key + & 'file_path LIKE "%.ap.meta"').fetch1('file_path') + + try: + spikeglx_meta_filepath = find_full_path(get_ephys_root_data_dir(), + spikeglx_meta_filepath) + except FileNotFoundError: + # if not found, search in session_dir again + if not spikeglx_meta_filepath.exists(): + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(ephys_recording_key)) + inserted_probe_serial_number = (ProbeInsertion * probe.Probe + & ephys_recording_key).fetch1('probe') + + spikeglx_meta_filepaths = [fp for fp in session_dir.rglob('*.ap.meta')] + for meta_filepath in spikeglx_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + spikeglx_meta_filepath = meta_filepath + break + else: + raise FileNotFoundError( + 'No SpikeGLX data found for probe insertion: {}'.format(ephys_recording_key)) + + return spikeglx_meta_filepath + + +def get_openephys_probe_data(ephys_recording_key): + inserted_probe_serial_number = (ProbeInsertion * probe.Probe + & ephys_recording_key).fetch1('probe') + session_dir = find_full_path(get_ephys_root_data_dir(), + get_session_directory(ephys_recording_key)) + loaded_oe = openephys.OpenEphys(session_dir) + probe_data = loaded_oe.probes[inserted_probe_serial_number] + + # explicitly garbage collect "loaded_oe" + # as these may have large memory footprint and may not be cleared fast enough + del loaded_oe + gc.collect() + + return probe_data + + +def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software): + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) + spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) + electrode_config_key = (EphysRecording * probe.ElectrodeConfig + & ephys_recording_key).fetch1('KEY') + + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode & electrode_config_key) + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip(*electrode_query.fetch( + 'KEY', 'shank', 'shank_col', 'shank_row'))} + + channel2electrode_map = { + recorded_site: probe_electrodes[(shank, shank_col, shank_row)] + for recorded_site, (shank, shank_col, shank_row, _) in enumerate( + spikeglx_meta.shankmap['data'])} + elif acq_software == 'Open Ephys': + probe_dataset = get_openephys_probe_data(ephys_recording_key) + + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * EphysRecording & ephys_recording_key) + + probe_electrodes = {key['electrode']: key + for key in electrode_query.fetch('KEY')} + + channel2electrode_map = { + channel_idx: probe_electrodes[channel_idx] + for channel_idx in probe_dataset.ap_meta['channels_indices']} + + return channel2electrode_map + + +def generate_electrode_config(probe_type: str, electrodes: list): + """ + Generate and insert new ElectrodeConfig + :param probe_type: probe type (e.g. 
neuropixels 2.0 - SS) + :param electrodes: list of the electrode dict (keys of the probe.ProbeType.Electrode table) + :return: a dict representing a key of the probe.ElectrodeConfig table + """ + # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid({k['electrode']: k for k in electrodes}) + + electrode_list = sorted([k['electrode'] for k in electrodes]) + electrode_gaps = ([-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1]) + electrode_config_name = '; '.join([ + f'{electrode_list[start + 1]}-{electrode_list[end]}' + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])]) + + electrode_config_key = {'electrode_config_hash': electrode_config_hash} + + # ---- make new ElectrodeConfig if needed ---- + if not probe.ElectrodeConfig & electrode_config_key: + probe.ElectrodeConfig.insert1({**electrode_config_key, 'probe_type': probe_type, + 'electrode_config_name': electrode_config_name}) + probe.ElectrodeConfig.Electrode.insert({**electrode_config_key, **electrode} + for electrode in electrodes) + + return electrode_config_key + + +def get_recording_channels_details(ephys_recording_key): + channels_details = {} + + acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1('acq_software', + 'sampling_rate') + + probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1('probe_type') + channels_details['probe_type'] = {'neuropixels 1.0 - 3A': '3A', + 'neuropixels 1.0 - 3B': 'NP1', + 'neuropixels UHD': 'NP1100', + 'neuropixels 2.0 - SS': 'NP21', + 'neuropixels 2.0 - MS': 'NP24'}[probe_type] + + electrode_config_key = (probe.ElectrodeConfig * EphysRecording & ephys_recording_key).fetch1('KEY') + channels_details['channel_ind'], channels_details['x_coords'], channels_details[ + 'y_coords'], channels_details['shank_ind'] = ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key).fetch('electrode', 'x_coord', 'y_coord', 'shank') + channels_details['sample_rate'] = sample_rate + channels_details['num_channels'] = len(channels_details['channel_ind']) + + if acq_software == 'SpikeGLX': + spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + channels_details['uVPerBit'] = spikeglx_recording.get_channel_bit_volts('ap')[0] + channels_details['connected'] = np.array( + [v for *_, v in spikeglx_recording.apmeta.shankmap['data']]) + elif acq_software == 'Open Ephys': + oe_probe = get_openephys_probe_data(ephys_recording_key) + channels_details['uVPerBit'] = oe_probe.ap_meta['channels_gains'][0] + channels_details['connected'] = np.array([ + int(v == 1) for c, v in oe_probe.channels_connected.items() + if c in channels_details['channel_ind']]) + + return channels_details diff --git a/element_array_ephys/probe.py b/element_array_ephys/probe.py index ca281bfe..8c26b8d1 100644 --- a/element_array_ephys/probe.py +++ b/element_array_ephys/probe.py @@ -18,7 +18,7 @@ def activate(schema_name, *, create_schema=True, create_tables=True): schema.activate(schema_name, create_schema=create_schema, create_tables=create_tables) # Add neuropixels probes - for probe_type in ('neuropixels 1.0 - 3A', 'neuropixels 1.0 - 3B', + for probe_type in ('neuropixels 1.0 - 3A', 'neuropixels 1.0 - 3B', 'neuropixels UHD', 'neuropixels 2.0 - SS', 'neuropixels 2.0 - MS'): ProbeType.create_neuropixels_probe(probe_type) @@ -46,15 +46,38 @@ class Electrode(dj.Part): def 
create_neuropixels_probe(probe_type='neuropixels 1.0 - 3A'): """ Create `ProbeType` and `Electrode` for neuropixels probes: - 1.0 (3A and 3B), 2.0 (SS and MS) + + neuropixels 1.0 - 3A + + neuropixels 1.0 - 3B + + neuropixels UHD + + neuropixels 2.0 - SS + + neuropixels 2.0 - MS + For electrode location, the (0, 0) is the bottom left corner of the probe (ignore the tip portion) Electrode numbering is 1-indexing """ + neuropixels_probes_config = { + 'neuropixels 1.0 - 3A': dict(site_count=960, col_spacing=32, row_spacing=20, + white_spacing=16, col_count=2, + shank_count=1, shank_spacing=0), + 'neuropixels 1.0 - 3B': dict(site_count=960, col_spacing=32, row_spacing=20, + white_spacing=16, col_count=2, + shank_count=1, shank_spacing=0), + 'neuropixels UHD': dict(site_count=384, col_spacing=6, row_spacing=6, + white_spacing=0, col_count=8, + shank_count=1, shank_spacing=0), + 'neuropixels 2.0 - SS': dict(site_count=1280, col_spacing=32, row_spacing=15, + white_spacing=0, col_count=2, + shank_count=1, shank_spacing=250), + 'neuropixels 2.0 - MS': dict(site_count=1280, col_spacing=32, row_spacing=15, + white_spacing=0, col_count=2, + shank_count=4, shank_spacing=250) + } + def build_electrodes(site_count, col_spacing, row_spacing, - white_spacing, col_count=2, - shank_count=1, shank_spacing=250): + white_spacing, col_count, + shank_count, shank_spacing): """ :param site_count: site count per shank :param col_spacing: (um) horrizontal spacing between sites @@ -66,14 +89,15 @@ def build_electrodes(site_count, col_spacing, row_spacing, :return: """ row_count = int(site_count / col_count) - x_coords = np.tile([0, 0 + col_spacing], row_count) - x_white_spaces = np.tile([white_spacing, white_spacing, 0, 0], int(row_count / 2)) + x_coords = np.tile(np.arange(0, col_spacing * col_count, col_spacing), row_count) + y_coords = np.repeat(np.arange(row_count) * row_spacing, col_count) - x_coords = x_coords + x_white_spaces - y_coords = np.repeat(np.arange(row_count) * row_spacing, 2) + if white_spacing: + x_white_spaces = np.tile([white_spacing, white_spacing, 0, 0], int(row_count / 2)) + x_coords = x_coords + x_white_spaces - shank_cols = np.tile([0, 1], row_count) - shank_rows = np.repeat(range(row_count), 2) + shank_cols = np.tile(range(col_count), row_count) + shank_rows = np.repeat(range(row_count), col_count) npx_electrodes = [] for shank_no in range(shank_count): @@ -88,51 +112,12 @@ def build_electrodes(site_count, col_spacing, row_spacing, return npx_electrodes - # ---- 1.0 3A ---- - if probe_type == 'neuropixels 1.0 - 3A': - electrodes = build_electrodes(site_count=960, col_spacing=32, row_spacing=20, - white_spacing=16, col_count=2) - - probe_type = {'probe_type': 'neuropixels 1.0 - 3A'} - with ProbeType.connection.transaction: - ProbeType.insert1(probe_type, skip_duplicates=True) - ProbeType.Electrode.insert([{**probe_type, **e} for e in electrodes], - skip_duplicates=True) - - # ---- 1.0 3B ---- - if probe_type == 'neuropixels 1.0 - 3B': - electrodes = build_electrodes(site_count=960, col_spacing=32, row_spacing=20, - white_spacing=16, col_count=2) - - probe_type = {'probe_type': 'neuropixels 1.0 - 3B'} - with ProbeType.connection.transaction: - ProbeType.insert1(probe_type, skip_duplicates=True) - ProbeType.Electrode.insert([{**probe_type, **e} for e in electrodes], - skip_duplicates=True) - - # ---- 2.0 Single shank ---- - if probe_type == 'neuropixels 2.0 - SS': - electrodes = build_electrodes(site_count=1280, col_spacing=32, row_spacing=15, - white_spacing=0, col_count=2, - shank_count=1, 
shank_spacing=250) - - probe_type = {'probe_type': 'neuropixels 2.0 - SS'} - with ProbeType.connection.transaction: - ProbeType.insert1(probe_type, skip_duplicates=True) - ProbeType.Electrode.insert([{**probe_type, **e} for e in electrodes], - skip_duplicates=True) - - # ---- 2.0 Multi shank ---- - if probe_type == 'neuropixels 2.0 - MS': - electrodes = build_electrodes(site_count=1280, col_spacing=32, row_spacing=15, - white_spacing=0, col_count=2, - shank_count=4, shank_spacing=250) - - probe_type = {'probe_type': 'neuropixels 2.0 - MS'} - with ProbeType.connection.transaction: - ProbeType.insert1(probe_type, skip_duplicates=True) - ProbeType.Electrode.insert([{**probe_type, **e} for e in electrodes], - skip_duplicates=True) + electrodes = build_electrodes(**neuropixels_probes_config[probe_type]) + probe_type = {'probe_type': probe_type} + with ProbeType.connection.transaction: + ProbeType.insert1(probe_type, skip_duplicates=True) + ProbeType.Electrode.insert([{**probe_type, **e} for e in electrodes], + skip_duplicates=True) @schema diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py index a4e257f9..cddd7656 100644 --- a/element_array_ephys/readers/kilosort.py +++ b/element_array_ephys/readers/kilosort.py @@ -12,7 +12,7 @@ class Kilosort: - kilosort_files = [ + _kilosort_core_files = [ 'params.py', 'amplitudes.npy', 'channel_map.npy', @@ -22,21 +22,23 @@ class Kilosort: 'similar_templates.npy', 'spike_templates.npy', 'spike_times.npy', - 'spike_times_sec.npy', - 'spike_times_sec_adj.npy', 'template_features.npy', 'template_feature_ind.npy', 'templates.npy', 'templates_ind.npy', 'whitening_mat.npy', 'whitening_mat_inv.npy', - 'spike_clusters.npy', + 'spike_clusters.npy' + ] + + _kilosort_additional_files = [ + 'spike_times_sec.npy', + 'spike_times_sec_adj.npy', 'cluster_groups.csv', 'cluster_KSLabel.tsv' ] - # keys to self.files, .data are file name e.g. self.data['params'], etc. - kilosort_keys = [path.splitext(kilosort_file)[0] for kilosort_file in kilosort_files] + kilosort_files = _kilosort_core_files + _kilosort_additional_files def __init__(self, kilosort_dir): self._kilosort_dir = pathlib.Path(kilosort_dir) @@ -44,25 +46,36 @@ def __init__(self, kilosort_dir): self._data = None self._clusters = None - params_filepath = kilosort_dir / 'params.py' - - if not params_filepath.exists(): - raise FileNotFoundError(f'No Kilosort output found in: {kilosort_dir}') + self.validate() + params_filepath = kilosort_dir / 'params.py' self._info = {'time_created': datetime.fromtimestamp(params_filepath.stat().st_ctime), 'time_modified': datetime.fromtimestamp(params_filepath.stat().st_mtime)} @property def data(self): if self._data is None: - self._stat() + self._load() return self._data @property def info(self): return self._info - def _stat(self): + def validate(self): + """ + Check if this is a valid set of kilosort outputs - i.e. 
all crucial files exist + """ + missing_files = [] + for f in Kilosort._kilosort_core_files: + full_path = self._kilosort_dir / f + if not full_path.exists(): + missing_files.append(f) + if missing_files: + raise FileNotFoundError(f'Kilosort files missing in ({self._kilosort_dir}):' + f' {missing_files}') + + def _load(self): self._data = {} for kilosort_filename in Kilosort.kilosort_files: kilosort_filepath = self._kilosort_dir / kilosort_filename @@ -91,8 +104,10 @@ def _stat(self): self._data[base] = (np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d) + self._data['channel_map'] = self._data['channel_map'].flatten() + # Read the Cluster Groups - for cluster_pattern, cluster_col_name in zip(['cluster_groups.*', 'cluster_KSLabel.*'], + for cluster_pattern, cluster_col_name in zip(['cluster_group.*', 'cluster_KSLabel.*'], ['group', 'KSLabel']): try: cluster_file = next(self._kilosort_dir.glob(cluster_pattern)) @@ -100,7 +115,7 @@ def _stat(self): pass else: cluster_file_suffix = cluster_file.suffix - assert cluster_file_suffix in ('.csv', '.tsv', '.xlsx') + assert cluster_file_suffix in ('.tsv', '.xlsx') break else: raise FileNotFoundError( diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py new file mode 100644 index 00000000..2d988ba1 --- /dev/null +++ b/element_array_ephys/readers/kilosort_triggering.py @@ -0,0 +1,577 @@ +import subprocess +import sys +import pathlib +import json +import re +import inspect +import os +import scipy.io +import numpy as np +from datetime import datetime, timedelta + +from element_interface.utils import dict_to_uuid + + +# import the spike sorting packages +try: + from ecephys_spike_sorting.scripts.create_input_json import createInputJson + from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils, log_from_json + from ecephys_spike_sorting.modules.kilosort_helper.__main__ import get_noise_channels +except Exception as e: + print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}') + +# import pykilosort package +try: + import pykilosort +except Exception as e: + print(f'Error in loading "pykilosort" package - {str(e)}') + + +class SGLXKilosortPipeline: + """ + An object of SGLXKilosortPipeline manages the state of the Kilosort data processing pipeline + for one Neuropixels probe in one recording session using the Spike GLX acquisition software. 
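+
+ Illustrative usage - a sketch only; the directory paths are hypothetical and `params` is the
+ parameter dict assembled in Clustering.make() (ClusteringParamSet params plus recording channel details):
+ >>> pipeline = SGLXKilosortPipeline(npx_input_dir='/raw/subject1/session1_g0/session1_g0_imec0',
+ ...                                 ks_output_dir='/processed/subject1/session1/probe_0/kilosort2-5_0',
+ ...                                 params=params, KS2ver='2.5', run_CatGT=True)
+ >>> pipeline.run_modules()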
+ + Primarily calling routines specified from: + https://github.com/jenniferColonell/ecephys_spike_sorting + """ + + _modules = ['kilosort_helper', + 'kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + + _default_catgt_params = { + 'catGT_car_mode': 'gblcar', + 'catGT_loccar_min_um': 40, + 'catGT_loccar_max_um': 160, + 'catGT_cmd_string': '-prb_fld -out_prb_fld -gfix=0.4,0.10,0.02', + 'ni_present': False, + 'ni_extract_string': '-XA=0,1,3,500 -iXA=1,3,3,0 -XD=-1,1,50 -XD=-1,2,1.7 -XD=-1,3,5 -iXD=-1,3,5' + } + + _input_json_args = list(inspect.signature(createInputJson).parameters) + + def __init__(self, npx_input_dir: str, ks_output_dir: str, + params: dict, KS2ver: str, + run_CatGT=False, + ni_present=False, + ni_extract_string=None): + + self._npx_input_dir = pathlib.Path(npx_input_dir) + + self._ks_output_dir = pathlib.Path(ks_output_dir) + self._ks_output_dir.mkdir(parents=True, exist_ok=True) + + self._params = params + self._KS2ver = KS2ver + self._run_CatGT = run_CatGT + self._run_CatGT = run_CatGT + self._default_catgt_params['ni_present'] = ni_present + self._default_catgt_params['ni_extract_string'] = ni_extract_string or self._default_catgt_params['ni_extract_string'] + + self._json_directory = self._ks_output_dir / 'json_configs' + self._json_directory.mkdir(parents=True, exist_ok=True) + + self._CatGT_finished = False + self.ks_input_params = None + self._modules_input_hash = None + self._modules_input_hash_fp = None + + def parse_input_filename(self): + meta_filename = next(self._npx_input_dir.glob('*.ap.meta')).name + match = re.search('(.*)_g(\d)_t(\d+|cat)\.imec(\d?)\.ap\.meta', meta_filename) + session_str, gate_str, trigger_str, probe_str = match.groups() + return session_str, gate_str, trigger_str, probe_str or '0' + + def generate_CatGT_input_json(self): + if not self._run_CatGT: + print('run_CatGT is set to False, skipping...') + return + + session_str, gate_str, trig_str, probe_str = self.parse_input_filename() + + first_trig, last_trig = SpikeGLX_utils.ParseTrigStr( + 'start,end', probe_str, gate_str, self._npx_input_dir.as_posix()) + trigger_str = repr(first_trig) + ',' + repr(last_trig) + + self._catGT_input_json = self._json_directory / f'{session_str}{probe_str}_CatGT-input.json' + + catgt_params = {k: self._params.get(k, v) + for k, v in self._default_catgt_params.items()} + + ni_present = catgt_params.pop('ni_present') + ni_extract_string = catgt_params.pop('ni_extract_string') + + catgt_params['catGT_stream_string'] = '-ap -ni' if ni_present else '-ap' + sync_extract = '-SY=' + probe_str + ',-1,6,500' + extract_string = sync_extract + (f' {ni_extract_string}' if ni_present else '') + catgt_params['catGT_cmd_string'] += f' {extract_string}' + + input_meta_fullpath, continuous_file = self._get_raw_data_filepaths() + + # create symbolic link to the actual data files - as CatGT expects files to follow a certain naming convention + continuous_file_symlink = (continuous_file.parent / f'{session_str}_g{gate_str}' + / f'{session_str}_g{gate_str}_imec{probe_str}' + / f'{session_str}_g{gate_str}_t{trig_str}.imec{probe_str}.ap.bin') + continuous_file_symlink.parent.mkdir(parents=True, exist_ok=True) + if continuous_file_symlink.exists(): + continuous_file_symlink.unlink() + continuous_file_symlink.symlink_to(continuous_file) + input_meta_fullpath_symlink = (input_meta_fullpath.parent / f'{session_str}_g{gate_str}' + / f'{session_str}_g{gate_str}_imec{probe_str}' + / f'{session_str}_g{gate_str}_t{trig_str}.imec{probe_str}.ap.meta') 
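+ # likewise symlink the .ap.meta file so CatGT finds the metadata next to the renamed .ap.bin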
+ input_meta_fullpath_symlink.parent.mkdir(parents=True, exist_ok=True) + if input_meta_fullpath_symlink.exists(): + input_meta_fullpath_symlink.unlink() + input_meta_fullpath_symlink.symlink_to(input_meta_fullpath) + + createInputJson(self._catGT_input_json.as_posix(), + KS2ver=self._KS2ver, + npx_directory=self._npx_input_dir.as_posix(), + spikeGLX_data=True, + catGT_run_name=session_str, + gate_string=gate_str, + trigger_string=trigger_str, + probe_string=probe_str, + continuous_file=continuous_file.as_posix(), + input_meta_path=input_meta_fullpath.as_posix(), + extracted_data_directory=self._ks_output_dir.as_posix(), + kilosort_output_directory=self._ks_output_dir.as_posix(), + kilosort_output_tmp=self._ks_output_dir.as_posix(), + kilosort_repository=_get_kilosort_repository(self._KS2ver), + **{k: v for k, v in catgt_params.items() if k in self._input_json_args} + ) + + def run_CatGT(self, force_rerun=False): + if self._run_CatGT and (not self._CatGT_finished or force_rerun): + self.generate_CatGT_input_json() + + print('---- Running CatGT ----') + catGT_input_json = self._catGT_input_json.as_posix() + catGT_output_json = catGT_input_json.replace('CatGT-input.json', 'CatGT-output.json') + + command = (sys.executable + + " -W ignore -m ecephys_spike_sorting.modules." + + 'catGT_helper' + " --input_json " + catGT_input_json + + " --output_json " + catGT_output_json) + subprocess.check_call(command.split(' ')) + + self._CatGT_finished = True + + def generate_modules_input_json(self): + session_str, _, _, probe_str = self.parse_input_filename() + self._module_input_json = self._json_directory / f'{session_str}_imec{probe_str}-input.json' + + input_meta_fullpath, continuous_file = self._get_raw_data_filepaths() + + params = {} + for k, v in self._params.items(): + value = str(v) if isinstance(v, list) else v + if f'ks_{k}' in self._input_json_args: + params[f'ks_{k}'] = value + if k in self._input_json_args: + params[k] = value + + self.ks_input_params = createInputJson( + self._module_input_json.as_posix(), + KS2ver=self._KS2ver, + npx_directory=self._npx_input_dir.as_posix(), + spikeGLX_data=True, + continuous_file=continuous_file.as_posix(), + input_meta_path=input_meta_fullpath.as_posix(), + extracted_data_directory=self._ks_output_dir.parent.as_posix(), + kilosort_output_directory=self._ks_output_dir.as_posix(), + kilosort_output_tmp=self._ks_output_dir.as_posix(), + ks_make_copy=True, + noise_template_use_rf=self._params.get('noise_template_use_rf', False), + c_Waves_snr_um=self._params.get('c_Waves_snr_um', 160), + qm_isi_thresh=self._params.get('refPerMS', 2.0) / 1000, + kilosort_repository=_get_kilosort_repository(self._KS2ver), + **params + ) + + self._modules_input_hash = dict_to_uuid(self.ks_input_params) + + def run_modules(self): + if self._run_CatGT and not self._CatGT_finished: + self.run_CatGT() + + print('---- Running Modules ----') + self.generate_modules_input_json() + module_input_json = self._module_input_json.as_posix() + module_logfile = module_input_json.replace('-input.json', '-run_modules-log.txt') + + for module in self._modules: + module_status = self._get_module_status(module) + if module_status['completion_time'] is not None: + continue + + module_output_json = self._get_module_output_json_filename(module) + command = (sys.executable + + " -W ignore -m ecephys_spike_sorting.modules." 
+ module + + " --input_json " + module_input_json + + " --output_json " + module_output_json) + + start_time = datetime.utcnow() + self._update_module_status( + {module: {'start_time': start_time, + 'completion_time': None, + 'duration': None}}) + with open(module_logfile, "a") as f: + subprocess.check_call(command.split(' '), stdout=f) + completion_time = datetime.utcnow() + self._update_module_status( + {module: {'start_time': start_time, + 'completion_time': completion_time, + 'duration': (completion_time - start_time).total_seconds()}}) + + self._update_total_duration() + + def _get_raw_data_filepaths(self): + session_str, gate_str, _, probe_str = self.parse_input_filename() + + if self._CatGT_finished: + catGT_dest = self._ks_output_dir + run_str = session_str + '_g' + gate_str + run_folder = 'catgt_' + run_str + prb_folder = run_str + '_imec' + probe_str + data_directory = catGT_dest / run_folder / prb_folder + else: + data_directory = self._npx_input_dir + try: + meta_fp = next(data_directory.glob(f'{session_str}*.ap.meta')) + bin_fp = next(data_directory.glob(f'{session_str}*.ap.bin')) + except StopIteration: + raise RuntimeError(f'No ap meta/bin files found in {data_directory} - CatGT error?') + + return meta_fp, bin_fp + + def _update_module_status(self, updated_module_status={}): + if self._modules_input_hash is None: + raise RuntimeError('"generate_modules_input_json()" not yet performed!') + + self._modules_input_hash_fp = self._json_directory / f'.{self._modules_input_hash}.json' + if self._modules_input_hash_fp.exists(): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + modules_status = {**modules_status, **updated_module_status} + else: + modules_status = {module: {'start_time': None, + 'completion_time': None, + 'duration': None} + for module in self._modules} + with open(self._modules_input_hash_fp, 'w') as f: + json.dump(modules_status, f, default=str) + + def _get_module_status(self, module): + if self._modules_input_hash_fp is None: + self._update_module_status() + + if self._modules_input_hash_fp.exists(): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + if modules_status[module]['completion_time'] is None: + # additional logic to read from the "-output.json" file for this module as well + # handle cases where the module has finished successfully, + # but the "_modules_input_hash_fp" is not updated (for whatever reason), + # resulting in this module not registered as completed in the "_modules_input_hash_fp" + module_output_json_fp = pathlib.Path(self._get_module_output_json_filename(module)) + if module_output_json_fp.exists(): + with open(module_output_json_fp) as f: + module_run_output = json.load(f) + modules_status[module]['duration'] = module_run_output['execution_time'] + modules_status[module]['completion_time'] = ( + datetime.strptime(modules_status[module]['start_time'], '%Y-%m-%d %H:%M:%S.%f') + + timedelta(seconds=module_run_output['execution_time'])) + return modules_status[module] + + return {'start_time': None, 'completion_time': None, 'duration': None} + + def _get_module_output_json_filename(self, module): + module_input_json = self._module_input_json.as_posix() + module_output_json = module_input_json.replace( + '-input.json', + '-' + module + '-' + str(self._modules_input_hash) + '-output.json') + return module_output_json + + def _update_total_duration(self): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + cumulative_execution_duration = sum( + 
v['duration'] or 0 for k, v in modules_status.items() + if k not in ('cumulative_execution_duration', 'total_duration')) + total_duration = ( + datetime.strptime(modules_status[self._modules[-1]]['completion_time'], '%Y-%m-%d %H:%M:%S.%f') + - datetime.strptime(modules_status[self._modules[0]]['start_time'], '%Y-%m-%d %H:%M:%S.%f') + ).total_seconds() + self._update_module_status( + {'cumulative_execution_duration': cumulative_execution_duration, + 'total_duration': total_duration}) + + +class OpenEphysKilosortPipeline: + """ + An object of OpenEphysKilosortPipeline manages the state of the Kilosort data processing pipeline + for one Neuropixels probe in one recording session using the Open Ephys acquisition software. + + Primarily calling routines specified from: + https://github.com/jenniferColonell/ecephys_spike_sorting + Which is based on `ecephys_spike_sorting` routines from Allen Institute + https://github.com/AllenInstitute/ecephys_spike_sorting + """ + + _modules = ['depth_estimation', + 'median_subtraction', + 'kilosort_helper', + 'kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + + _input_json_args = list(inspect.signature(createInputJson).parameters) + + def __init__(self, npx_input_dir: str, ks_output_dir: str, + params: dict, KS2ver: str): + + self._npx_input_dir = pathlib.Path(npx_input_dir) + + self._ks_output_dir = pathlib.Path(ks_output_dir) + self._ks_output_dir.mkdir(parents=True, exist_ok=True) + + self._params = params + self._KS2ver = KS2ver + + self._json_directory = self._ks_output_dir / 'json_configs' + self._json_directory.mkdir(parents=True, exist_ok=True) + + self.ks_input_params = None + self._modules_input_hash = None + self._modules_input_hash_fp = None + + def make_chanmap_file(self): + continuous_file = self._npx_input_dir / 'continuous.dat' + self._chanmap_filepath = self._ks_output_dir / 'chanMap.mat' + + _write_channel_map_file(channel_ind=self._params['channel_ind'], + x_coords=self._params['x_coords'], + y_coords=self._params['y_coords'], + shank_ind=self._params['shank_ind'], + connected=self._params['connected'], + probe_name=self._params['probe_type'], + ap_band_file=continuous_file.as_posix(), + bit_volts=self._params['uVPerBit'], + sample_rate=self._params['sample_rate'], + save_path=self._chanmap_filepath.as_posix(), + is_0_based=True) + + def generate_modules_input_json(self): + self.make_chanmap_file() + + self._module_input_json = self._json_directory / f'{self._npx_input_dir.name}-input.json' + + continuous_file = self._npx_input_dir / 'continuous.dat' + + params = {} + for k, v in self._params.items(): + value = str(v) if isinstance(v, list) else v + if f'ks_{k}' in self._input_json_args: + params[f'ks_{k}'] = value + if k in self._input_json_args: + params[k] = value + + self.ks_input_params = createInputJson( + self._module_input_json.as_posix(), + KS2ver=self._KS2ver, + npx_directory=self._npx_input_dir.as_posix(), + spikeGLX_data=False, + continuous_file=continuous_file.as_posix(), + extracted_data_directory=self._ks_output_dir.parent.as_posix(), + kilosort_output_directory=self._ks_output_dir.as_posix(), + kilosort_output_tmp=self._ks_output_dir.as_posix(), + ks_make_copy=True, + noise_template_use_rf=self._params.get('noise_template_use_rf', False), + use_C_Waves=False, + c_Waves_snr_um=self._params.get('c_Waves_snr_um', 160), + qm_isi_thresh=self._params.get('refPerMS', 2.0) / 1000, + kilosort_repository=_get_kilosort_repository(self._KS2ver), + chanMap_path=self._chanmap_filepath.as_posix(), + 
**params + ) + + self._modules_input_hash = dict_to_uuid(self.ks_input_params) + + def run_modules(self): + print('---- Running Modules ----') + self.generate_modules_input_json() + module_input_json = self._module_input_json.as_posix() + module_logfile = module_input_json.replace('-input.json', '-run_modules-log.txt') + + for module in self._modules: + module_status = self._get_module_status(module) + if module_status['completion_time'] is not None: + continue + + module_output_json = self._get_module_output_json_filename(module) + command = [sys.executable, + '-W', 'ignore', '-m', 'ecephys_spike_sorting.modules.' + module, + '--input_json', module_input_json, + '--output_json', module_output_json] + + start_time = datetime.utcnow() + self._update_module_status( + {module: {'start_time': start_time, + 'completion_time': None, + 'duration': None}}) + with open(module_logfile, "a") as f: + subprocess.check_call(command, stdout=f) + completion_time = datetime.utcnow() + self._update_module_status( + {module: {'start_time': start_time, + 'completion_time': completion_time, + 'duration': (completion_time - start_time).total_seconds()}}) + + self._update_total_duration() + + def _update_module_status(self, updated_module_status={}): + if self._modules_input_hash is None: + raise RuntimeError('"generate_modules_input_json()" not yet performed!') + + self._modules_input_hash_fp = self._json_directory / f'.{self._modules_input_hash}.json' + if self._modules_input_hash_fp.exists(): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + modules_status = {**modules_status, **updated_module_status} + else: + modules_status = {module: {'start_time': None, + 'completion_time': None, + 'duration': None} + for module in self._modules} + with open(self._modules_input_hash_fp, 'w') as f: + json.dump(modules_status, f, default=str) + + def _get_module_status(self, module): + if self._modules_input_hash_fp is None: + self._update_module_status() + + if self._modules_input_hash_fp.exists(): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + if modules_status[module]['completion_time'] is None: + # additional logic to read from the "-output.json" file for this module as well + # handle cases where the module has finished successfully, + # but the "_modules_input_hash_fp" is not updated (for whatever reason), + # resulting in this module not registered as completed in the "_modules_input_hash_fp" + module_output_json_fp = pathlib.Path(self._get_module_output_json_filename(module)) + if module_output_json_fp.exists(): + with open(module_output_json_fp) as f: + module_run_output = json.load(f) + modules_status[module]['duration'] = module_run_output['execution_time'] + modules_status[module]['completion_time'] = ( + datetime.strptime(modules_status[module]['start_time'], '%Y-%m-%d %H:%M:%S.%f') + + timedelta(seconds=module_run_output['execution_time'])) + return modules_status[module] + + return {'start_time': None, 'completion_time': None, 'duration': None} + + def _get_module_output_json_filename(self, module): + module_input_json = self._module_input_json.as_posix() + module_output_json = module_input_json.replace( + '-input.json', + '-' + module + '-' + str(self._modules_input_hash) + '-output.json') + return module_output_json + + def _update_total_duration(self): + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + cumulative_execution_duration = sum( + v['duration'] or 0 for k, v in modules_status.items() + if k not in 
('cumulative_execution_duration', 'total_duration')) + total_duration = ( + datetime.strptime(modules_status[self._modules[-1]]['completion_time'], '%Y-%m-%d %H:%M:%S.%f') + - datetime.strptime(modules_status[self._modules[0]]['start_time'], '%Y-%m-%d %H:%M:%S.%f') + ).total_seconds() + self._update_module_status( + {'cumulative_execution_duration': cumulative_execution_duration, + 'total_duration': total_duration}) + + +def run_pykilosort(continuous_file, kilosort_output_directory, params, + channel_ind, x_coords, y_coords, shank_ind, connected, sample_rate): + dat_path = pathlib.Path(continuous_file) + + probe = pykilosort.Bunch() + channel_count = len(channel_ind) + probe.Nchan = channel_count + probe.NchanTOT = channel_count + probe.chanMap = np.arange(0, channel_count, dtype='int') + probe.xc = x_coords + probe.yc = y_coords + probe.kcoords = shank_ind + + pykilosort.run(dat_path=continuous_file, + dir_path=dat_path.parent, + output_dir=kilosort_output_directory, + probe=probe, + params=params, + n_channels=probe.Nchan, + dtype=np.int16, + sample_rate=sample_rate) + + +def _get_kilosort_repository(KS2ver): + """ + Get the path to where the kilosort package is installed at, assuming it can be found + as environment variable named "kilosort_repository" + Modify this path according to the KSVer used + """ + ks_repo = pathlib.Path(os.getenv('kilosort_repository')) + assert ks_repo.exists() + assert ks_repo.stem.startswith('Kilosort') + + ks_repo = ks_repo.parent / f'Kilosort-{KS2ver}' + assert ks_repo.exists() + + return ks_repo.as_posix() + + +def _write_channel_map_file(*, channel_ind, x_coords, y_coords, shank_ind, connected, + probe_name, ap_band_file, bit_volts, sample_rate, + save_path, is_0_based=True): + """ + Write channel map into .mat file in 1-based indexing format (MATLAB style) + """ + + assert len(channel_ind) == len(x_coords) == len(y_coords) == len(shank_ind) == len(connected) + + if is_0_based: + channel_ind += 1 + shank_ind += 1 + + channel_count = len(channel_ind) + chanMap0ind = np.arange(0, channel_count, dtype='float64') + chanMap0ind = chanMap0ind.reshape((channel_count, 1)) + chanMap = chanMap0ind + 1 + + # channels to exclude + mask = get_noise_channels(ap_band_file, + channel_count, + sample_rate, + bit_volts) + bad_channel_ind = np.where(mask is False)[0] + connected[bad_channel_ind] = 0 + + mdict = { + 'chanMap': chanMap, + 'chanMap0ind': chanMap0ind, + 'connected': connected, + 'name': probe_name, + 'xcoords': x_coords, + 'ycoords': y_coords, + 'shankInd': shank_ind, + 'kcoords': shank_ind, + 'fs': sample_rate + } + + scipy.io.savemat(save_path, mdict) diff --git a/element_array_ephys/readers/openephys.py b/element_array_ephys/readers/openephys.py index 7e8b240d..00718cf0 100644 --- a/element_array_ephys/readers/openephys.py +++ b/element_array_ephys/readers/openephys.py @@ -1,6 +1,11 @@ import pathlib import pyopenephys import numpy as np +import re +import datetime +import logging + +logger = logging.getLogger(__name__) """ @@ -8,7 +13,7 @@ (https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/Binary-format.html) Record Node 102 --- experiment1 (equivalent to a Session) +-- experiment1 (equivalent to one experimental session - multi probes, multi recordings per probe) -- recording1 -- recording2 -- continuous @@ -33,20 +38,37 @@ class OpenEphys: def __init__(self, experiment_dir): self.session_dir = pathlib.Path(experiment_dir) - openephys_file = pyopenephys.File(self.session_dir.parent) # this is on the Record Node level + if 
self.session_dir.name.startswith('recording'): + openephys_file = pyopenephys.File(self.session_dir.parent.parent) # this is on the Record Node level + self._is_recording_folder = True + else: + openephys_file = pyopenephys.File(self.session_dir.parent) # this is on the Record Node level + self._is_recording_folder = False # extract the "recordings" for this session self.experiment = next(experiment for experiment in openephys_file.experiments - if pathlib.Path(experiment.absolute_foldername) == self.session_dir) - - self.recording_time = self.experiment.datetime + if pathlib.Path(experiment.absolute_foldername) == ( + self.session_dir.parent if self._is_recording_folder else self.session_dir) + ) # extract probe data self.probes = self.load_probe_data() + # + self._recording_time = None + + @property + def recording_time(self): + if self._recording_time is None: + recording_datetimes = [] + for probe in self.probes.values(): + recording_datetimes.extend(probe.recording_info['recording_datetimes']) + self._recording_time = sorted(recording_datetimes)[0] + return self._recording_time + def load_probe_data(self): """ - Loop through all Open Ephys "processors", identify the processor for + Loop through all Open Ephys "signalchains/processors", identify the processor for the Neuropixels probe(s), extract probe info Loop through all recordings, associate recordings to the matching probes, extract recording info @@ -56,46 +78,98 @@ def load_probe_data(self): """ probes = {} - for processor in self.experiment.settings['SIGNALCHAIN']['PROCESSOR']: - if processor['@pluginName'] in ('Neuropix-PXI', 'Neuropix-3a'): - if (processor['@pluginName'] == 'Neuropix-3a' - or 'NP_PROBE' not in processor['EDITOR']): - probe = Probe(processor) + sigchain_iter = (self.experiment.settings['SIGNALCHAIN'] + if isinstance(self.experiment.settings['SIGNALCHAIN'], list) + else [self.experiment.settings['SIGNALCHAIN']]) + for sigchain in sigchain_iter: + processor_iter = (sigchain['PROCESSOR'] + if isinstance(sigchain['PROCESSOR'], list) + else [sigchain['PROCESSOR']]) + for processor in processor_iter: + if processor['@pluginName'] in ('Neuropix-3a', 'Neuropix-PXI'): + if 'STREAM' in processor: # only on version >= 0.6.0 + ap_streams = [stream for stream in processor['STREAM'] if not stream['@name'].endswith('LFP')] + else: + ap_streams = None + + if (processor['@pluginName'] == 'Neuropix-3a' + or 'NP_PROBE' not in processor['EDITOR']): + if isinstance(processor['EDITOR']['PROBE'], dict): + probe_indices = (0,) + else: + probe_indices = range(len(processor['EDITOR']['PROBE'])) + elif processor['@pluginName'] == 'Neuropix-PXI': + probe_indices = range(len(processor['EDITOR']['NP_PROBE'])) + else: + raise NotImplementedError + else: # not a processor for Neuropixels probe + continue + + for probe_index in probe_indices: + probe = Probe(processor, probe_index) + if ap_streams: + probe.probe_info['ap_stream'] = ap_streams[probe_index] probes[probe.probe_SN] = probe - else: - for probe_index in range(len(processor['EDITOR']['NP_PROBE'])): - probe = Probe(processor, probe_index) - probes[probe.probe_SN] = probe - + for probe_index, probe_SN in enumerate(probes): probe = probes[probe_SN] for rec in self.experiment.recordings: + if self._is_recording_folder and rec.absolute_foldername != self.session_dir: + continue + + assert len(rec._oebin['continuous']) == len(rec.analog_signals), \ + f'Mismatch in the number of continuous data' \ + f' - expecting {len(rec._oebin["continuous"])} (from structure.oebin file),' \ + f' found 
{len(rec.analog_signals)} (in continuous folder)' + for continuous_info, analog_signal in zip(rec._oebin['continuous'], rec.analog_signals): if continuous_info['source_processor_id'] != probe.processor_id: continue - if continuous_info['source_processor_sub_idx'] == probe_index * 2: # ap data - assert continuous_info['sample_rate'] == analog_signal.sample_rate == 30000 - continuous_type = 'ap' - + # determine if this is continuous data for AP or LFP for the current probe + if 'ap_stream' in probe.probe_info: + if probe.probe_info['ap_stream']['@name'].split('-')[0] != continuous_info['stream_name'].split('-')[0]: + continue # not continuous data for the current probe + match = re.search('-(AP|LFP)$', continuous_info['stream_name']) + if match: + continuous_type = match.groups()[0].lower() + else: + continuous_type = 'ap' + elif 'source_processor_sub_idx' in continuous_info: + if continuous_info['source_processor_sub_idx'] == probe_index * 2: # ap data + assert continuous_info['sample_rate'] == analog_signal.sample_rate == 30000 + continuous_type = 'ap' + elif continuous_info['source_processor_sub_idx'] == probe_index * 2 + 1: # lfp data + assert continuous_info['sample_rate'] == analog_signal.sample_rate == 2500 + continuous_type = 'lfp' + else: + continue # not continuous data for the current probe + else: + raise ValueError(f'Unable to infer type (AP or LFP) for the continuous data from:\n\t{continuous_info["folder_name"]}') + + if continuous_type == 'ap': probe.recording_info['recording_count'] += 1 probe.recording_info['recording_datetimes'].append( - rec.datetime) + rec.datetime + datetime.timedelta(seconds=float(rec.start_time))) probe.recording_info['recording_durations'].append( float(rec.duration)) probe.recording_info['recording_files'].append( rec.absolute_foldername / 'continuous' / continuous_info['folder_name']) - - elif continuous_info['source_processor_sub_idx'] == probe_index * 2 + 1: # lfp data - assert continuous_info['sample_rate'] == analog_signal.sample_rate == 2500 - continuous_type = 'lfp' + elif continuous_type == 'lfp': + probe.recording_info['recording_lfp_files'].append( + rec.absolute_foldername / 'continuous' / continuous_info['folder_name']) meta = getattr(probe, continuous_type + '_meta') if not meta: + # channel indices - 0-based indexing + channels_indices = [int(re.search(r'\d+$', chn_name).group()) - 1 + for chn_name in analog_signal.channel_names] + meta.update(**continuous_info, + channels_indices=channels_indices, channels_ids=analog_signal.channel_ids, channels_names=analog_signal.channel_names, channels_gains=analog_signal.gains) @@ -106,21 +180,51 @@ def load_probe_data(self): return probes +# For more details on supported probes, +# see: https://open-ephys.github.io/gui-docs/User-Manual/Plugins/Neuropixels-PXI.html +_probe_model_name_mapping = { + "Neuropix-PXI": "neuropixels 1.0 - 3B", + "Neuropix-3a": "neuropixels 1.0 - 3A", + "Neuropixels 1.0": "neuropixels 1.0 - 3B", + "Neuropixels Ultra": "neuropixels UHD", + "Neuropixels Ultra (Switchable)": "neuropixels UHD", + "Neuropixels 21": "neuropixels 2.0 - SS", + "Neuropixels 24": "neuropixels 2.0 - MS", + "Neuropixels 2.0 - Single Shank": "neuropixels 2.0 - SS", + "Neuropixels 2.0 - Four Shank": "neuropixels 2.0 - MS" +} + + class Probe: def __init__(self, processor, probe_index=0): - self.processor_id = int(processor['@NodeId']) + processor_node_id = processor.get("@nodeId", processor.get("@NodeId")) + if processor_node_id is None: + raise KeyError('Neither "@nodeId" nor "@NodeId" key found') + + 
self.processor_id = int(processor_node_id) if processor['@pluginName'] == 'Neuropix-3a' or 'NP_PROBE' not in processor['EDITOR']: - self.probe_info = processor['EDITOR']['PROBE'] + self.probe_info = processor['EDITOR']['PROBE'] if isinstance(processor['EDITOR']['PROBE'], dict) else processor['EDITOR']['PROBE'][probe_index] self.probe_SN = self.probe_info['@probe_serial_number'] - self.probe_model = { - "Neuropix-PXI": "neuropixels 1.0 - 3B", - "Neuropix-3a": "neuropixels 1.0 - 3A"}[processor['@pluginName']] - else: + self.probe_model = _probe_model_name_mapping[processor['@pluginName']] + self._channels_connected = {int(re.search(r'\d+$', k).group()): int(v) + for k, v in self.probe_info.pop('CHANNELSTATUS').items()} + else: # Neuropix-PXI self.probe_info = processor['EDITOR']['NP_PROBE'][probe_index] self.probe_SN = self.probe_info['@probe_serial_number'] - self.probe_model = self.probe_info['@probe_name'] + self.probe_model = _probe_model_name_mapping[self.probe_info['@probe_name']] + + if 'ELECTRODE_XPOS' in self.probe_info: + self.probe_info['ELECTRODE_XPOS'] = {int(re.search(r'\d+$', k).group()): int(v) + for k, v in self.probe_info.pop('ELECTRODE_XPOS').items()} + self.probe_info['ELECTRODE_YPOS'] = {int(re.search(r'\d+$', k).group()): int(v) + for k, v in self.probe_info.pop('ELECTRODE_YPOS').items()} + self.probe_info['ELECTRODE_SHANK'] = {int(re.search(r'\d+$', k).group()): int(v) + for k, v in self.probe_info['CHANNELS'].items()} + + self._channels_connected = {int(re.search(r'\d+$', k).group()): 1 + for k in self.probe_info.pop('CHANNELS')} self.ap_meta = {} self.lfp_meta = {} @@ -131,13 +235,19 @@ def __init__(self, processor, probe_index=0): self.recording_info = {'recording_count': 0, 'recording_datetimes': [], 'recording_durations': [], - 'recording_files': []} + 'recording_files': [], + 'recording_lfp_files': []} self._ap_timeseries = None self._ap_timestamps = None self._lfp_timeseries = None self._lfp_timestamps = None + @property + def channels_connected(self): + return {chn_idx: self._channels_connected.get(chn_idx, 0) + for chn_idx in self.ap_meta['channels_indices']} + @property def ap_timeseries(self): """ @@ -200,3 +310,73 @@ def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32 return spike_wfs else: # if no spike found, return NaN of size (sample x channel x 1) return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan) + + def compress(self): + from mtscomp import compress as mts_compress + + ap_dirs = self.recording_info['recording_files'] + lfp_dirs = self.recording_info['recording_lfp_files'] + + meta_mapping = {'ap': self.ap_meta, 'lfp': self.lfp_meta} + + compressed_files = [] + for continuous_dir, continuous_type in zip( + ap_dirs + lfp_dirs, + ['ap'] * len(ap_dirs) + ['lfp'] * len(lfp_dirs)): + dat_fp = continuous_dir / 'continuous.dat' + if not dat_fp.exists(): + raise FileNotFoundError(f'Compression error - "{dat_fp}" does not exist') + cdat_fp = continuous_dir / 'continuous.cdat' + ch_fp = continuous_dir / 'continuous.ch' + + if cdat_fp.exists(): + assert ch_fp.exists() + logger.info(f'Compressed file exists ({cdat_fp}), skipping...') + continue + + try: + mts_compress(dat_fp, cdat_fp, ch_fp, + sample_rate=meta_mapping[continuous_type]['sample_rate'], + n_channels=meta_mapping[continuous_type]['num_channels'], + dtype=np.memmap(dat_fp).dtype) + except Exception as e: + cdat_fp.unlink(missing_ok=True) + ch_fp.unlink(missing_ok=True) + raise e + else: + compressed_files.append((cdat_fp, ch_fp)) + + return compressed_files 
+ + def decompress(self): + from mtscomp import decompress as mts_decompress + + ap_dirs = self.recording_info['recording_files'] + lfp_dirs = self.recording_info['recording_lfp_files'] + + decompressed_files = [] + for continuous_dir, continuous_type in zip( + ap_dirs + lfp_dirs, + ['ap'] * len(ap_dirs) + ['lfp'] * len(lfp_dirs)): + dat_fp = continuous_dir / 'continuous.dat' + + if dat_fp.exists(): + logger.info(f'Decompressed file exists ({dat_fp}), skipping...') + continue + + cdat_fp = continuous_dir / 'continuous.cdat' + ch_fp = continuous_dir / 'continuous.ch' + + if not cdat_fp.exists(): + raise FileNotFoundError(f'Decompression error - "{cdat_fp}" does not exist') + + try: + decomp_arr = mts_decompress(cdat_fp, ch_fp) + decomp_arr.tofile(dat_fp) + except Exception as e: + dat_fp.unlink(missing_ok=True) + raise e + else: + decompressed_files.append(dat_fp) + + return decompressed_files diff --git a/element_array_ephys/readers/spikeglx.py b/element_array_ephys/readers/spikeglx.py index 67569989..d877f52c 100644 --- a/element_array_ephys/readers/spikeglx.py +++ b/element_array_ephys/readers/spikeglx.py @@ -1,8 +1,10 @@ from datetime import datetime import numpy as np import pathlib +import logging from .utils import convert_to_number +logger = logging.getLogger(__name__) AP_GAIN = 80 # For NP 2.0 probes; APGain = 80 for all AP (LF is computed from AP) @@ -58,6 +60,7 @@ def ap_timeseries(self): - to convert to microvolts, multiply with self.get_channel_bit_volts('ap') """ if self._ap_timeseries is None: + self.validate_file('ap') self._ap_timeseries = self._read_bin(self.root_dir / (self.root_name + '.ap.bin')) return self._ap_timeseries @@ -75,6 +78,7 @@ def lf_timeseries(self): - to convert to microvolts, multiply with self.get_channel_bit_volts('lf') """ if self._lf_timeseries is None: + self.validate_file('lf') self._lf_timeseries = self._read_bin(self.root_dir / (self.root_name + '.lf.bin')) return self._lf_timeseries @@ -145,6 +149,81 @@ def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32 else: # if no spike found, return NaN of size (sample x channel x 1) return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan) + def validate_file(self, file_type='ap'): + file_path = self.root_dir / (self.root_name + f'.{file_type}.bin') + file_size = file_path.stat().st_size + + meta_mapping = { + 'ap': self.apmeta, + 'lf': self.lfmeta} + meta = meta_mapping[file_type] + + if file_size != meta.meta['fileSizeBytes']: + raise IOError(f'File size error! 
{file_path} may be corrupted or in transfer?') + + def compress(self): + from mtscomp import compress as mts_compress + + ap_file = self.root_dir / (self.root_name + '.ap.bin') + lfp_file = self.root_dir / (self.root_name + '.lf.bin') + + meta_mapping = {'ap': self.apmeta, 'lfp': self.lfmeta} + + compressed_files = [] + for bin_fp, band_type in zip([ap_file, lfp_file], ['ap', 'lfp']): + if not bin_fp.exists(): + raise FileNotFoundError(f'Compression error - "{bin_fp}" does not exist') + cbin_fp = bin_fp.parent / f'{bin_fp.stem}.cbin' + ch_fp = bin_fp.parent / f'{bin_fp.stem}.ch' + + if cbin_fp.exists(): + assert ch_fp.exists() + logger.info(f'Compressed file exists ({cbin_fp}), skipping...') + continue + + try: + mts_compress(bin_fp, cbin_fp, ch_fp, + sample_rate=meta_mapping[band_type]['sample_rate'], + n_channels=meta_mapping[band_type]['num_channels'], + dtype=np.memmap(bin_fp).dtype) + except Exception as e: + cbin_fp.unlink(missing_ok=True) + ch_fp.unlink(missing_ok=True) + raise e + else: + compressed_files.append((cbin_fp, ch_fp)) + + return compressed_files + + def decompress(self): + from mtscomp import decompress as mts_decompress + + ap_file = self.root_dir / (self.root_name + '.ap.bin') + lfp_file = self.root_dir / (self.root_name + '.lf.bin') + + decompressed_files = [] + for bin_fp, band_type in zip([ap_file, lfp_file], ['ap', 'lfp']): + if bin_fp.exists(): + logger.info(f'Decompressed file exists ({bin_fp}), skipping...') + continue + + cbin_fp = bin_fp.parent / f'{bin_fp.stem}.cbin' + ch_fp = bin_fp.parent / f'{bin_fp.stem}.ch' + + if not cbin_fp.exists(): + raise FileNotFoundError(f'Decompression error - "{cbin_fp}" does not exist') + + try: + decomp_arr = mts_decompress(cbin_fp, ch_fp) + decomp_arr.tofile(bin_fp) + except Exception as e: + bin_fp.unlink(missing_ok=True) + raise e + else: + decompressed_files.append(bin_fp) + + return decompressed_files + class SpikeGLXMeta: @@ -165,6 +244,8 @@ def __init__(self, meta_filepath): self.probe_model = 'neuropixels 1.0 - 3A' elif 'typeImEnabled' in self.meta: self.probe_model = 'neuropixels 1.0 - 3B' + elif probe_model == 1100: + self.probe_model = 'neuropixels UHD' elif probe_model == 21: self.probe_model = 'neuropixels 2.0 - SS' elif probe_model == 24: @@ -339,3 +420,10 @@ def _read_meta(meta_filepath): except ValueError: pass return res + + +def retrieve_recording_duration(meta_filepath): + root_dir = pathlib.Path(meta_filepath).parent + spike_glx = SpikeGLX(root_dir) + return (spike_glx.apmeta.recording_duration + or spike_glx.ap_timeseries.shape[0] / spike_glx.apmeta.meta['imSampRate']) diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py index 5d6e909e..60162e4c 100644 --- a/element_array_ephys/version.py +++ b/element_array_ephys/version.py @@ -1,2 +1,2 @@ """Package metadata.""" -__version__ = '0.1.0b0' \ No newline at end of file +__version__ = '0.1.0b4' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6ed6fb1b..d484bae7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ datajoint>=0.13 -pyopenephys -openpyxl \ No newline at end of file +pyopenephys @ git+https://github.com/datajoint-company/pyopenephys.git +openpyxl +pynwb==1.4.0 +element-interface @ git+https://github.com/datajoint/element-interface.git
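
The OpenEphys reader changes above collect per-probe metadata and add mtscomp-based compression helpers. The following is a minimal usage sketch, not part of the patch; the experiment directory path is hypothetical and mtscomp must be installed for the compress/decompress calls to work.

import pathlib
from element_array_ephys.readers import openephys

# Hypothetical Open Ephys experiment directory - adjust to a real session
experiment_dir = pathlib.Path('/data/subject1/Record Node 101/experiment1')

loader = openephys.OpenEphys(experiment_dir)
print(loader.recording_time)  # earliest recording start across all probes

for probe_SN, probe in loader.probes.items():
    print(probe_SN, probe.probe_model,
          probe.recording_info['recording_count'],
          probe.recording_info['recording_durations'])
    # 0-based channel index -> connected (1) / disabled (0)
    print(probe.channels_connected)
    # mtscomp round-trip on the AP/LFP continuous.dat files;
    # decompress() skips any folder whose continuous.dat already exists
    compressed = probe.compress()
    restored = probe.decompress()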
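
The SpikeGLX reader gains validate_file() and a module-level retrieve_recording_duration(). A short sketch, again with a hypothetical run directory that is assumed to contain the .ap.bin/.ap.meta pair:

import pathlib
from element_array_ephys.readers import spikeglx

# Hypothetical SpikeGLX run directory - illustrative only
sglx_dir = pathlib.Path('/data/subject1/npx_g0/npx_g0_imec0')

sglx = spikeglx.SpikeGLX(sglx_dir)
sglx.validate_file('ap')  # raises IOError if the .ap.bin size disagrees with fileSizeBytes in the meta

ap_meta_file = next(sglx_dir.glob('*.ap.meta'))
print(spikeglx.retrieve_recording_duration(ap_meta_file))  # AP-band recording duration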
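
_write_channel_map_file() stores the channel map with the MATLAB-style 1-based convention alongside its 0-based counterpart. A sketch of inspecting such a file with scipy; the path is hypothetical, and only the keys written by the mdict above (chanMap, chanMap0ind, connected, name, xcoords, ycoords, shankInd, kcoords, fs) are assumed.

import numpy as np
import scipy.io

# Hypothetical path to a channel map produced by _write_channel_map_file - illustrative only
chanmap = scipy.io.loadmat('/processed/subject1/npx_g0_imec0/chanMap.mat')

# chanMap is 1-based (MATLAB style); chanMap0ind holds the matching 0-based indices
assert np.array_equal(chanmap['chanMap'], chanmap['chanMap0ind'] + 1)
print(chanmap['name'], chanmap['fs'].item())              # probe name and sampling rate
print(int(chanmap['connected'].sum()), 'channels kept')   # noise channels are marked 0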