Merge pull request datajoint#142 from ttngu207/main
Update kilosort_triggering.py
kabilar authored Jun 26, 2023
2 parents 47dea95 + 5e1f055 · commit 1d30cb8
Showing 2 changed files with 108 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,10 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [Unreleased] - 2023-06-23

+ Update - Improve kilosort triggering routine - better logging, remove temporary files, robust resumable processing

## [0.2.10] - 2023-05-26

+ Add - Kilosort, NWB, and DANDI citations
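As a point of reference for reviewers, here is a minimal usage sketch (not part of this commit) of how the updated pipeline class might be driven after this change, including the new optional `modules_to_run` argument. The import path, class name, and constructor signature come from the diff below; the directories, parameter dict, version string, and module name are hypothetical placeholders.

```python
# Hedged usage sketch - paths and values are placeholders, not from this commit.
from element_array_ephys.readers.kilosort_triggering import OpenEphysKilosortPipeline

pipeline = OpenEphysKilosortPipeline(
    npx_input_dir="/data/raw/probe0",       # hypothetical raw recording directory
    ks_output_dir="/data/kilosort/probe0",  # hypothetical output directory
    params={},                              # spike-sorting parameters (left empty here)
    KS2ver="2.5",                           # kilosort version string (assumed value)
)

# Run all configured modules; modules whose status record already has a
# completion_time are skipped, so an interrupted run can simply be re-triggered.
pipeline.run_modules()

# The new optional argument re-runs only a subset of modules (valid names
# depend on the pipeline's configured module list).
pipeline.run_modules(modules_to_run=["median_subtraction"])
```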
173 changes: 104 additions & 69 deletions element_array_ephys/readers/kilosort_triggering.py
@@ -20,13 +20,13 @@
from ecephys_spike_sorting.scripts.create_input_json import createInputJson
from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils
except Exception as e:
print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}')
print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}')

# import pykilosort package
try:
import pykilosort
except Exception as e:
print(f'Error in loading "pykilosort" package - {str(e)}')
print(f'Warning: Failed loading "pykilosort" package - {str(e)}')


class SGLXKilosortPipeline:
@@ -67,7 +67,6 @@ def __init__(
ni_present=False,
ni_extract_string=None,
):

self._npx_input_dir = pathlib.Path(npx_input_dir)

self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -85,6 +84,13 @@ def __init__(
self._json_directory = self._ks_output_dir / "json_configs"
self._json_directory.mkdir(parents=True, exist_ok=True)

self._module_input_json = (
self._json_directory / f"{self._npx_input_dir.name}-input.json"
)
self._module_logfile = (
self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
)

self._CatGT_finished = False
self.ks_input_params = None
self._modules_input_hash = None
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
**params,
)

self._modules_input_hash = dict_to_uuid(self.ks_input_params)
self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))

def run_modules(self):
def run_modules(self, modules_to_run=None):
if self._run_CatGT and not self._CatGT_finished:
self.run_CatGT()

print("---- Running Modules ----")
self.generate_modules_input_json()
module_input_json = self._module_input_json.as_posix()
module_logfile = module_input_json.replace(
"-input.json", "-run_modules-log.txt"
)
module_logfile = self._module_logfile.as_posix()

modules = modules_to_run or self._modules

for module in self._modules:
for module in modules:
module_status = self._get_module_status(module)
if module_status["completion_time"] is not None:
continue
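The hunk above changes the inputs to the parameter hash: it now appears to cover only the user-supplied sorting parameters plus the kilosort version, rather than the full module-input kwargs. A hedged sketch of that idea follows; `dict_to_uuid` below is a local stand-in written for illustration (the element's own helper may differ in detail), and the parameter values are made up.

```python
# Hedged sketch: deterministic hash of the sorting parameters plus kilosort
# version. If either changes, the hash changes, and stale per-module status
# files can be discarded before reprocessing.
import hashlib
import json
import uuid

def dict_to_uuid(d: dict) -> uuid.UUID:
    # Stand-in: stable UUID derived from the sorted key/value pairs (assumed behavior).
    digest = hashlib.md5(json.dumps(d, sort_keys=True, default=str).encode()).hexdigest()
    return uuid.UUID(digest)

params = {"Th": [10, 4], "CAR": True}  # hypothetical kilosort parameters
print(dict_to_uuid(dict(params, KS2ver="2.5")))  # same inputs always yield the same hash
```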
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
else:
# handle cases of processing rerun on different parameters (the hash changes)
# delete outdated files
outdated_files = [
f
[
f.unlink()
for f in self._json_directory.glob("*")
if f.is_file() and f.name != self._module_input_json.name
]
for f in outdated_files:
f.unlink()

modules_status = {
module: {"start_time": None, "completion_time": None, "duration": None}
@@ -371,14 +375,26 @@ def _update_total_duration(self):
for k, v in modules_status.items()
if k not in ("cumulative_execution_duration", "total_duration")
)

for m in self._modules:
first_start_time = modules_status[m]["start_time"]
if first_start_time is not None:
break

for m in self._modules[::-1]:
last_completion_time = modules_status[m]["completion_time"]
if last_completion_time is not None:
break

if first_start_time is None or last_completion_time is None:
return

total_duration = (
datetime.strptime(
modules_status[self._modules[-1]]["completion_time"],
last_completion_time,
"%Y-%m-%d %H:%M:%S.%f",
)
- datetime.strptime(
modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
)
- datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
).total_seconds()
self._update_module_status(
{
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
def __init__(
self, npx_input_dir: str, ks_output_dir: str, params: dict, KS2ver: str
):

self._npx_input_dir = pathlib.Path(npx_input_dir)

self._ks_output_dir = pathlib.Path(ks_output_dir)
@@ -426,7 +441,13 @@ def __init__(
self._json_directory = self._ks_output_dir / "json_configs"
self._json_directory.mkdir(parents=True, exist_ok=True)

self._median_subtraction_status = {}
self._module_input_json = (
self._json_directory / f"{self._npx_input_dir.name}-input.json"
)
self._module_logfile = (
self._json_directory / f"{self._npx_input_dir.name}-run_modules-log.txt"
)

self.ks_input_params = None
self._modules_input_hash = None
self._modules_input_hash_fp = None
@@ -451,9 +472,6 @@ def make_chanmap_file(self):

def generate_modules_input_json(self):
self.make_chanmap_file()
self._module_input_json = (
self._json_directory / f"{self._npx_input_dir.name}-input.json"
)

continuous_file = self._get_raw_data_filepaths()

@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
**params,
)

self._modules_input_hash = dict_to_uuid(self.ks_input_params)
self._modules_input_hash = dict_to_uuid(dict(self._params, KS2ver=self._KS2ver))

def run_modules(self):
def run_modules(self, modules_to_run=None):
print("---- Running Modules ----")
self.generate_modules_input_json()
module_input_json = self._module_input_json.as_posix()
module_logfile = module_input_json.replace(
"-input.json", "-run_modules-log.txt"
)
module_logfile = self._module_logfile.as_posix()

for module in self._modules:
modules = modules_to_run or self._modules

for module in modules:
module_status = self._get_module_status(module)
if module_status["completion_time"] is not None:
continue

if module == "median_subtraction" and self._median_subtraction_status:
median_subtraction_status = self._get_module_status(
"median_subtraction"
)
median_subtraction_status["duration"] = self._median_subtraction_status[
"duration"
]
median_subtraction_status["completion_time"] = datetime.strptime(
median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
) + timedelta(seconds=median_subtraction_status["duration"])
self._update_module_status(
{"median_subtraction": median_subtraction_status}
if module == "median_subtraction":
median_subtraction_duration = (
self._get_median_subtraction_duration_from_log()
)
continue
if median_subtraction_duration is not None:
median_subtraction_status = self._get_module_status(
"median_subtraction"
)
median_subtraction_status["duration"] = median_subtraction_duration
median_subtraction_status["completion_time"] = datetime.strptime(
median_subtraction_status["start_time"], "%Y-%m-%d %H:%M:%S.%f"
) + timedelta(seconds=median_subtraction_status["duration"])
self._update_module_status(
{"median_subtraction": median_subtraction_status}
)
continue

module_output_json = self._get_module_output_json_filename(module)
command = [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
assert "depth_estimation" in self._modules
continuous_file = self._ks_output_dir / "continuous.dat"
if continuous_file.exists():
if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
# if the copied continuous.dat was actually modified,
# median_subtraction may have been completed - let's check
module_input_json = self._module_input_json.as_posix()
module_logfile = module_input_json.replace(
"-input.json", "-run_modules-log.txt"
)
with open(module_logfile, "r") as f:
previous_line = ""
for line in f.readlines():
if line.startswith(
"ecephys spike sorting: median subtraction module"
) and previous_line.startswith("Total processing time:"):
# regex to search for the processing duration - a float value
duration = int(
re.search("\d+\.?\d+", previous_line).group()
)
self._median_subtraction_status["duration"] = duration
return continuous_file
previous_line = line
if raw_ap_fp.stat().st_mtime == continuous_file.stat().st_mtime:
return continuous_file
else:
if self._module_logfile.exists():
return continuous_file

shutil.copy2(raw_ap_fp, continuous_file)
return continuous_file
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
else:
# handle cases of processing rerun on different parameters (the hash changes)
# delete outdated files
outdated_files = [
f
[
f.unlink()
for f in self._json_directory.glob("*")
if f.is_file() and f.name != self._module_input_json.name
]
for f in outdated_files:
f.unlink()

modules_status = {
module: {"start_time": None, "completion_time": None, "duration": None}
@@ -673,14 +676,26 @@ def _update_total_duration(self):
for k, v in modules_status.items()
if k not in ("cumulative_execution_duration", "total_duration")
)

for m in self._modules:
first_start_time = modules_status[m]["start_time"]
if first_start_time is not None:
break

for m in self._modules[::-1]:
last_completion_time = modules_status[m]["completion_time"]
if last_completion_time is not None:
break

if first_start_time is None or last_completion_time is None:
return

total_duration = (
datetime.strptime(
modules_status[self._modules[-1]]["completion_time"],
last_completion_time,
"%Y-%m-%d %H:%M:%S.%f",
)
- datetime.strptime(
modules_status[self._modules[0]]["start_time"], "%Y-%m-%d %H:%M:%S.%f"
)
- datetime.strptime(first_start_time, "%Y-%m-%d %H:%M:%S.%f")
).total_seconds()
self._update_module_status(
{
@@ -689,6 +704,26 @@ def _update_total_duration(self):
}
)

def _get_median_subtraction_duration_from_log(self):
raw_ap_fp = self._npx_input_dir / "continuous.dat"
continuous_file = self._ks_output_dir / "continuous.dat"
if raw_ap_fp.stat().st_mtime < continuous_file.stat().st_mtime:
# if the copied continuous.dat was actually modified,
# median_subtraction may have been completed - let's check
if self._module_logfile.exists():
with open(self._module_logfile, "r") as f:
previous_line = ""
for line in f.readlines():
if line.startswith(
"ecephys spike sorting: median subtraction module"
) and previous_line.startswith("Total processing time:"):
# regex to search for the processing duration - a float value
duration = int(
re.search("\d+\.?\d+", previous_line).group()
)
return duration
previous_line = line


def run_pykilosort(
continuous_file,
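Much of this diff concerns the per-module status bookkeeping that makes processing resumable. The snippet below is a hedged illustration of the shape of that record as kept in the `json_configs` directory; the field names mirror the `modules_status` dict built in `_update_module_status`, while the module ordering, timestamps, and durations are invented.

```python
# Illustrative only - shape of the per-module status record used to decide
# which modules to skip on a resumed run (values are made up).
import json

example_modules_status = {
    "depth_estimation": {
        "start_time": "2023-06-23 10:00:00.000000",
        "completion_time": "2023-06-23 10:05:12.000000",
        "duration": 312.0,
    },
    "median_subtraction": {
        "start_time": None,        # not started yet - a re-triggered run resumes here
        "completion_time": None,
        "duration": None,
    },
}
print(json.dumps(example_modules_status, indent=2))
```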
