diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1bed6341..11579ee1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,10 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [Unreleased] - 2023-06-23
+
++ Update - Improve kilosort triggering routine - better logging, remove temporary files, robust resumable processing
+
 ## [0.2.10] - 2023-05-26
 
 + Add - Kilosort, NWB, and DANDI citations
diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py
index b78af832..8e180be9 100644
--- a/element_array_ephys/readers/kilosort_triggering.py
+++ b/element_array_ephys/readers/kilosort_triggering.py
@@ -20,13 +20,13 @@
     from ecephys_spike_sorting.scripts.create_input_json import createInputJson
     from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils
 except Exception as e:
-    print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}')
+    print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}')
 
 # import pykilosort package
 try:
     import pykilosort
 except Exception as e:
-    print(f'Error in loading "pykilosort" package - {str(e)}')
+    print(f'Warning: Failed loading "pykilosort" package - {str(e)}')
 
 
 class SGLXKilosortPipeline:
@@ -67,7 +67,6 @@ def __init__(
         ni_present=False,
         ni_extract_string=None,
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)
         self._ks_output_dir = pathlib.Path(ks_output_dir)
 
@@ -85,6 +84,13 @@
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)
 
+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self._CatGT_finished = False
         self.ks_input_params = None
         self._modules_input_hash = None
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
             **params,
         )
 
-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))
 
-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         if self._run_CatGT and not self._CatGT_finished:
             self.run_CatGT()
 
         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()
+
+        modules = modules_to_run or self._modules
 
-        for module in self._modules:
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()
 
         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
@@ -371,14 +375,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
             )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
     def __init__(
         self, npx_input_dir: str, ks_output_dir: str, params: dict, KS2ver: str
     ):
-
         self._npx_input_dir = pathlib.Path(npx_input_dir)
         self._ks_output_dir = pathlib.Path(ks_output_dir)
 
@@ -426,7 +441,13 @@ def __init__(
         self._json_directory = self._ks_output_dir / "json_configs"
         self._json_directory.mkdir(parents=True, exist_ok=True)
 
-        self._median_subtraction_status = {}
+        self._module_input_json = (
+            self._json_directory / f"{self._npx_input_dir.name}-input.json"
+        )
+        self._module_logfile = (
+            self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
+        )
+
         self.ks_input_params = None
         self._modules_input_hash = None
         self._modules_input_hash_fp = None
@@ -451,9 +472,6 @@ def make_chanmap_file(self):
 
     def generate_modules_input_json(self):
         self.make_chanmap_file()
-        self._module_input_json = (
-            self._json_directory / f"{self._npx_input_dir.name}-input.json"
-        )
 
         continuous_file = self._get_raw_data_filepaths()
 
@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
             **params,
         )
 
-        self._modules_input_hash = dict_to_uuid(self.ks_input_params)
+        self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))
 
-    def run_modules(self):
+    def run_modules(self, modules_to_run=None):
         print("---- Running Modules ----")
         self.generate_modules_input_json()
         module_input_json = self._module_input_json.as_posix()
-        module_logfile = module_input_json.replace(
-            "-input.json", "-run_modules-log.txt"
-        )
+        module_logfile = self._module_logfile.as_posix()
 
-        for module in self._modules:
+        modules = modules_to_run or self._modules
+
+        for module in modules:
             module_status = self._get_module_status(module)
             if module_status["completion_time"] is not None:
                 continue
 
-            if module == "median_subtraction" and self._median_subtraction_status:
-                median_subtraction_status = self._get_module_status(
-                    "median_subtraction"
-                )
-                median_subtraction_status["duration"] = self._median_subtraction_status[
-                    "duration"
-                ]
-                median_subtraction_status["completion_time"] = datetime.strptime(
-                    median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-                ) + timedelta(seconds=median_subtraction_status["duration"])
-                self._update_module_status(
-                    {"median_subtraction": median_subtraction_status}
+            if module == "median_subtraction":
+                median_subtraction_duration = (
+                    self._get_median_subtraction_duration_from_log()
                 )
-                continue
+                if median_subtraction_duration is not None:
+                    median_subtraction_status = self._get_module_status(
+                        "median_subtraction"
+                    )
+                    median_subtraction_status["duration"] = median_subtraction_duration
+                    median_subtraction_status["completion_time"] = datetime.strptime(
+                        median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
+                    ) + timedelta(seconds=median_subtraction_status["duration"])
+                    self._update_module_status(
+                        {"median_subtraction": median_subtraction_status}
+                    )
+                continue
 
             module_output_json = self._get_module_output_json_filename(module)
             command = [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
         assert "depth_estimation" in self._modules
         continuous_file = self._ks_output_dir / "continuous.dat"
         if continuous_file.exists():
-            if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
-                # if the copied continuous.dat was actually modified,
-                # median_subtraction may have been completed - let's check
-                module_input_json = self._module_input_json.as_posix()
-                module_logfile = module_input_json.replace(
-                    "-input.json", "-run_modules-log.txt"
-                )
-                with open(module_logfile, "r") as f:
-                    previous_line = ""
-                    for line in f.readlines():
-                        if line.startswith(
-                            "ecephys spike sorting: median subtraction module"
-                        ) and previous_line.startswith("Total processing time:"):
-                            # regex to search for the processing duration - a float value
-                            duration = int(
-                                re.search("\d+\.?\d+", previous_line).group()
-                            )
-                            self._median_subtraction_status["duration"] = duration
-                            return continuous_file
-                        previous_line = line
+            if raw_ap_fp.stat().st_mtime == continuous_file.stat().st_mtime:
+                return continuous_file
+            else:
+                if self._module_logfile.exists():
+                    return continuous_file
         shutil.copy2(raw_ap_fp, continuous_file)
         return continuous_file
 
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
         else:
             # handle cases of processing rerun on different parameters (the hash changes)
             # delete outdated files
-            outdated_files = [
-                f
+            [
+                f.unlink()
                 for f in self._json_directory.glob("*")
                 if f.is_file() and f.name != self._module_input_json.name
             ]
-            for f in outdated_files:
-                f.unlink()
 
         modules_status = {
             module: {"start_time": None, "completion_time": None, "duration": None}
@@ -673,14 +676,26 @@ def _update_total_duration(self):
             for k, v in modules_status.items()
             if k not in ("cumulative_execution_duration", "total_duration")
         )
+
+        for m in self._modules:
+            first_start_time = modules_status[m]["start_time"]
+            if first_start_time is not None:
+                break
+
+        for m in self._modules[::-1]:
+            last_completion_time = modules_status[m]["completion_time"]
+            if last_completion_time is not None:
+                break
+
+        if first_start_time is None or last_completion_time is None:
+            return
+
         total_duration = (
             datetime.strptime(
-                modules_status[self._modules[-1]]["completion_time"],
+                last_completion_time,
                 "%Y-%m-%d %H:%M:%S.%f",
            )
-            - datetime.strptime(
-                modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
-            )
+            - datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
         ).total_seconds()
         self._update_module_status(
             {
@@ -689,6 +704,26 @@ def _update_total_duration(self):
             }
         )
 
+    def _get_median_subtraction_duration_from_log(self):
+        raw_ap_fp = self._npx_input_dir / "continuous.dat"
+        continuous_file = self._ks_output_dir / "continuous.dat"
+        if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
+            # if the copied continuous.dat was actually modified,
+            # median_subtraction may have been completed - let's check
+            if self._module_logfile.exists():
+                with open(self._module_logfile, "r") as f:
+                    previous_line = ""
+                    for line in f.readlines():
+                        if line.startswith(
+                            "ecephys spike sorting: median subtraction module"
+                        ) and previous_line.startswith("Total processing time:"):
+                            # regex to search for the processing duration - a float value
+                            duration = int(
+                                re.search("\d+\.?\d+", previous_line).group()
+                            )
+                            return duration
+                        previous_line = line
+
 
 def run_pykilosort(
     continuous_file,
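
Note (not part of the diff): a minimal usage sketch of the resumable run_modules(modules_to_run=...) interface introduced above. The constructor signature and the modules_to_run keyword come from the diff; the paths, the empty params dict, the KS2ver value, and the module names passed to modules_to_run are hypothetical placeholders.

from element_array_ephys.readers.kilosort_triggering import OpenEphysKilosortPipeline

# Hypothetical session paths and parameters - substitute real data folders and a
# full Kilosort parameter dict in practice.
pipeline = OpenEphysKilosortPipeline(
    npx_input_dir="/data/subject1/session0/probe0",
    ks_output_dir="/results/subject1/session0/probe0/kilosort",
    params={},
    KS2ver="2.5",
)

# Reruns are resumable: modules whose completion is already recorded under
# <ks_output_dir>/json_configs are skipped. A subset of modules can also be
# requested explicitly (module names here are illustrative).
pipeline.run_modules(modules_to_run=["kilosort_helper", "kilosort_postprocessing"])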