diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py
index 320db517..568d6323 100644
--- a/element_array_ephys/ephys_acute.py
+++ b/element_array_ephys/ephys_acute.py
@@ -4,6 +4,7 @@
 import numpy as np
 import inspect
 import importlib
+import gc
 from decimal import Decimal
 
 from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid
@@ -326,6 +327,10 @@ def make(self, key):
             self.EphysFile.insert([{**key,
                                     'file_path': fp.relative_to(root_dir).as_posix()}
                                    for fp in probe_data.recording_info['recording_files']])
+            # explicitly garbage collect "dataset"
+            # as these may have large memory footprint and may not be cleared fast enough
+            del probe_data, dataset
+            gc.collect()
         else:
             raise NotImplementedError(f'Processing ephys files from'
                                       f' acquisition software of type {acq_software} is'
@@ -919,7 +924,14 @@ def get_openephys_probe_data(ephys_recording_key):
     session_dir = find_full_path(get_ephys_root_data_dir(),
                                  get_session_directory(ephys_recording_key))
     loaded_oe = openephys.OpenEphys(session_dir)
-    return loaded_oe.probes[inserted_probe_serial_number]
+    probe_data = loaded_oe.probes[inserted_probe_serial_number]
+
+    # explicitly garbage collect "loaded_oe"
+    # as these may have large memory footprint and may not be cleared fast enough
+    del loaded_oe
+    gc.collect()
+
+    return probe_data
 
 
 def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software):
diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py
index d8162126..cd599f89 100644
--- a/element_array_ephys/ephys_chronic.py
+++ b/element_array_ephys/ephys_chronic.py
@@ -4,6 +4,7 @@
 import numpy as np
 import inspect
 import importlib
+import gc
 from decimal import Decimal
 
 from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid
@@ -273,6 +274,10 @@ def make(self, key):
             self.EphysFile.insert([{**key,
                                     'file_path': fp.relative_to(root_dir).as_posix()}
                                    for fp in probe_data.recording_info['recording_files']])
+            # explicitly garbage collect "dataset"
+            # as these may have large memory footprint and may not be cleared fast enough
+            del probe_data, dataset
+            gc.collect()
         else:
             raise NotImplementedError(f'Processing ephys files from'
                                       f' acquisition software of type {acq_software} is'
@@ -862,10 +867,17 @@ def get_spikeglx_meta_filepath(ephys_recording_key):
 def get_openephys_probe_data(ephys_recording_key):
     inserted_probe_serial_number = (ProbeInsertion * probe.Probe
                                     & ephys_recording_key).fetch1('probe')
-    sess_dir = find_full_path(get_ephys_root_data_dir(),
-                              get_session_directory(ephys_recording_key))
-    loaded_oe = openephys.OpenEphys(sess_dir)
-    return loaded_oe.probes[inserted_probe_serial_number]
+    session_dir = find_full_path(get_ephys_root_data_dir(),
+                                 get_session_directory(ephys_recording_key))
+    loaded_oe = openephys.OpenEphys(session_dir)
+    probe_data = loaded_oe.probes[inserted_probe_serial_number]
+
+    # explicitly garbage collect "loaded_oe"
+    # as these may have large memory footprint and may not be cleared fast enough
+    del loaded_oe
+    gc.collect()
+
+    return probe_data
 
 
 def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software):
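The same `del` + `gc.collect()` pattern recurs in `ephys_no_curation.py` below. For readers unfamiliar with it, here is a minimal self-contained sketch of the idea; the `Recording` class is a hypothetical stand-in for a large loaded dataset and is not part of this diff:

```python
import gc

class Recording:
    """Hypothetical stand-in for a large loaded object (e.g. an OpenEphys session)."""
    def __init__(self):
        self.raw = bytearray(10**8)                # large in-memory buffer
        self.recording_info = {'recording_count': 1}

def summarize_recording():
    recording = Recording()                        # large memory footprint
    info = recording.recording_info                # keep only the small part we need
    # Drop the last reference and collect immediately, instead of letting the
    # large object linger until the next automatic collection during a
    # long-running populate()/make() cycle.
    del recording
    gc.collect()
    return info

print(summarize_recording())                       # {'recording_count': 1}
```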
diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index bbd00fa1..1e4fdd5c 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -4,6 +4,7 @@
 import numpy as np
 import inspect
 import importlib
+import gc
 from decimal import Decimal
 
 from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid
@@ -325,6 +326,10 @@ def make(self, key):
             self.EphysFile.insert([{**key,
                                     'file_path': fp.relative_to(root_dir).as_posix()}
                                    for fp in probe_data.recording_info['recording_files']])
+            # explicitly garbage collect "dataset"
+            # as these may have large memory footprint and may not be cleared fast enough
+            del probe_data, dataset
+            gc.collect()
         else:
             raise NotImplementedError(f'Processing ephys files from'
                                       f' acquisition software of type {acq_software} is'
@@ -877,7 +882,14 @@ def get_openephys_probe_data(ephys_recording_key):
     session_dir = find_full_path(get_ephys_root_data_dir(),
                                  get_session_directory(ephys_recording_key))
     loaded_oe = openephys.OpenEphys(session_dir)
-    return loaded_oe.probes[inserted_probe_serial_number]
+    probe_data = loaded_oe.probes[inserted_probe_serial_number]
+
+    # explicitly garbage collect "loaded_oe"
+    # as these may have large memory footprint and may not be cleared fast enough
+    del loaded_oe
+    gc.collect()
+
+    return probe_data
 
 
 def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software):
diff --git a/element_array_ephys/readers/openephys.py b/element_array_ephys/readers/openephys.py
index 0d39dd55..00718cf0 100644
--- a/element_array_ephys/readers/openephys.py
+++ b/element_array_ephys/readers/openephys.py
@@ -3,6 +3,9 @@
 import numpy as np
 import re
 import datetime
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 """
@@ -155,6 +158,9 @@ def load_probe_data(self):
                             float(rec.duration))
                         probe.recording_info['recording_files'].append(
                             rec.absolute_foldername / 'continuous' / continuous_info['folder_name'])
+                    elif continuous_type == 'lfp':
+                        probe.recording_info['recording_lfp_files'].append(
+                            rec.absolute_foldername / 'continuous' / continuous_info['folder_name'])
 
                     meta = getattr(probe, continuous_type + '_meta')
                     if not meta:
@@ -229,7 +235,8 @@ def __init__(self, processor, probe_index=0):
         self.recording_info = {'recording_count': 0,
                                'recording_datetimes': [],
                                'recording_durations': [],
-                               'recording_files': []}
+                               'recording_files': [],
+                               'recording_lfp_files': []}
 
         self._ap_timeseries = None
         self._ap_timestamps = None
@@ -303,3 +310,73 @@ def extract_spike_waveforms(self, spikes, channel_ind, n_wf=500, wf_win=(-32, 32)):
             return spike_wfs
         else:  # if no spike found, return NaN of size (sample x channel x 1)
             return np.full((len(range(*wf_win)), len(channel_ind), 1), np.nan)
+
+    def compress(self):
+        from mtscomp import compress as mts_compress
+
+        ap_dirs = self.recording_info['recording_files']
+        lfp_dirs = self.recording_info['recording_lfp_files']
+
+        meta_mapping = {'ap': self.ap_meta, 'lfp': self.lfp_meta}
+
+        compressed_files = []
+        for continuous_dir, continuous_type in zip(
+                ap_dirs + lfp_dirs,
+                ['ap'] * len(ap_dirs) + ['lfp'] * len(lfp_dirs)):
+            dat_fp = continuous_dir / 'continuous.dat'
+            if not dat_fp.exists():
+                raise FileNotFoundError(f'Compression error - "{dat_fp}" does not exist')
+            cdat_fp = continuous_dir / 'continuous.cdat'
+            ch_fp = continuous_dir / 'continuous.ch'
+
+            if cdat_fp.exists():
+                assert ch_fp.exists()
+                logger.info(f'Compressed file exists ({cdat_fp}), skipping...')
+                continue
+
+            try:
+                mts_compress(dat_fp, cdat_fp, ch_fp,
+                             sample_rate=meta_mapping[continuous_type]['sample_rate'],
+                             n_channels=meta_mapping[continuous_type]['num_channels'],
+                             dtype=np.memmap(dat_fp).dtype)
+            except Exception as e:
+                cdat_fp.unlink(missing_ok=True)
+                ch_fp.unlink(missing_ok=True)
+                raise e
+            else:
+                compressed_files.append((cdat_fp, ch_fp))
+
+        return compressed_files
+
+    def decompress(self):
+        from mtscomp import decompress as mts_decompress
+
+        ap_dirs = self.recording_info['recording_files']
+        lfp_dirs = self.recording_info['recording_lfp_files']
+
+        decompressed_files = []
+        for continuous_dir, continuous_type in zip(
+                ap_dirs + lfp_dirs,
+                ['ap'] * len(ap_dirs) + ['lfp'] * len(lfp_dirs)):
+            dat_fp = continuous_dir / 'continuous.dat'
+
+            if dat_fp.exists():
+                logger.info(f'Decompressed file exists ({dat_fp}), skipping...')
+                continue
+
+            cdat_fp = continuous_dir / 'continuous.cdat'
+            ch_fp = continuous_dir / 'continuous.ch'
+
+            if not cdat_fp.exists():
+                raise FileNotFoundError(f'Decompression error - "{cdat_fp}" does not exist')
+
+            try:
+                decomp_arr = mts_decompress(cdat_fp, ch_fp)
+                decomp_arr.tofile(dat_fp)
+            except Exception as e:
+                dat_fp.unlink(missing_ok=True)
+                raise e
+            else:
+                decompressed_files.append(dat_fp)
+
+        return decompressed_files
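Both readers delegate to `mtscomp`'s one-shot `compress()`/`decompress()` API, as the SpikeGLX diff below does as well. As a point of reference, a standalone sketch of that API against a synthetic file; the file names, channel count, and sample rate here are made up for illustration:

```python
import numpy as np
from mtscomp import compress, decompress

# Synthesize a tiny flat int16 binary standing in for an ephys continuous.dat.
np.random.randint(-500, 500, size=(30000, 4), dtype=np.int16).tofile('continuous.dat')

# One-shot compression into a .cbin/.ch pair -- the same call the new
# compress() methods issue per recording file.
compress('continuous.dat', 'continuous.cdat', 'continuous.ch',
         sample_rate=30000., n_channels=4, dtype=np.int16)

# decompress() returns a Reader; tofile() materializes the flat binary again,
# mirroring the decompress() methods in the diff.
reader = decompress('continuous.cdat', 'continuous.ch')
reader.tofile('continuous_restored.dat')
```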
diff --git a/element_array_ephys/readers/spikeglx.py b/element_array_ephys/readers/spikeglx.py
index a73103de..d877f52c 100644
--- a/element_array_ephys/readers/spikeglx.py
+++ b/element_array_ephys/readers/spikeglx.py
@@ -1,8 +1,10 @@
 from datetime import datetime
 import numpy as np
 import pathlib
+import logging
 
 from .utils import convert_to_number
 
+logger = logging.getLogger(__name__)
 
 AP_GAIN = 80  # For NP 2.0 probes; APGain = 80 for all AP (LF is computed from AP)
@@ -159,6 +161,69 @@ def validate_file(self, file_type='ap'):
 
         if file_size != meta.meta['fileSizeBytes']:
             raise IOError(f'File size error! {file_path} may be corrupted or in transfer?')
+
+    def compress(self):
+        from mtscomp import compress as mts_compress
+
+        ap_file = self.root_dir / (self.root_name + '.ap.bin')
+        lfp_file = self.root_dir / (self.root_name + '.lf.bin')
+
+        meta_mapping = {'ap': self.apmeta, 'lfp': self.lfmeta}
+
+        compressed_files = []
+        for bin_fp, band_type in zip([ap_file, lfp_file], ['ap', 'lfp']):
+            if not bin_fp.exists():
+                raise FileNotFoundError(f'Compression error - "{bin_fp}" does not exist')
+            cbin_fp = bin_fp.parent / f'{bin_fp.stem}.cbin'
+            ch_fp = bin_fp.parent / f'{bin_fp.stem}.ch'
+
+            if cbin_fp.exists():
+                assert ch_fp.exists()
+                logger.info(f'Compressed file exists ({cbin_fp}), skipping...')
+                continue
+
+            try:
+                mts_compress(bin_fp, cbin_fp, ch_fp,
+                             sample_rate=meta_mapping[band_type]['sample_rate'],
+                             n_channels=meta_mapping[band_type]['num_channels'],
+                             dtype=np.memmap(bin_fp).dtype)
+            except Exception as e:
+                cbin_fp.unlink(missing_ok=True)
+                ch_fp.unlink(missing_ok=True)
+                raise e
+            else:
+                compressed_files.append((cbin_fp, ch_fp))
+
+        return compressed_files
+
+    def decompress(self):
+        from mtscomp import decompress as mts_decompress
+
+        ap_file = self.root_dir / (self.root_name + '.ap.bin')
+        lfp_file = self.root_dir / (self.root_name + '.lf.bin')
+
+        decompressed_files = []
+        for bin_fp, band_type in zip([ap_file, lfp_file], ['ap', 'lfp']):
+            if bin_fp.exists():
+                logger.info(f'Decompressed file exists ({bin_fp}), skipping...')
+                continue
+
+            cbin_fp = bin_fp.parent / f'{bin_fp.stem}.cbin'
+            ch_fp = bin_fp.parent / f'{bin_fp.stem}.ch'
+
+            if not cbin_fp.exists():
+                raise FileNotFoundError(f'Decompression error - "{cbin_fp}" does not exist')
+
+            try:
+                decomp_arr = mts_decompress(cbin_fp, ch_fp)
+                decomp_arr.tofile(bin_fp)
+            except Exception as e:
+                bin_fp.unlink(missing_ok=True)
+                raise e
+            else:
+                decompressed_files.append(bin_fp)
+
+        return decompressed_files
+
 
 class SpikeGLXMeta:
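Finally, a usage sketch for the SpikeGLX side. The session directory is hypothetical; per the reader's existing constructor, `SpikeGLX` is pointed at the folder containing the `.ap.bin`/`.lf.bin` files and their `.meta` companions:

```python
from element_array_ephys.readers import spikeglx

# Hypothetical SpikeGLX probe folder: <root_name>.ap.bin/.lf.bin plus .meta files.
sglx = spikeglx.SpikeGLX('/data/subject1/session0/task_g0_imec0')

# Compress both bands; returns [(.cbin path, .ch path), ...] and skips
# (with an INFO log) any file that is already compressed.
pairs = sglx.compress()

# Restore the flat .bin files from their .cbin/.ch pairs; existing .bin
# files are likewise skipped.
bins = sglx.decompress()
```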