Merge pull request #88 from ttngu207/no-curation
Overall code cleanup/improvement for a more robust and optimal Kilosort run
kabilar authored Jul 25, 2022
2 parents d9c3887 + fd331bd commit ad8436e
Showing 5 changed files with 98 additions and 11 deletions.
3 changes: 3 additions & 0 deletions element_array_ephys/ephys_acute.py
@@ -295,6 +295,9 @@ def make(self, key):
raise FileNotFoundError(
'No Open Ephys data found for probe insertion: {}'.format(key))

if not probe_data.ap_meta:
raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory')

if probe_data.probe_model in supported_probe_types:
probe_type = probe_data.probe_model
electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type}
3 changes: 3 additions & 0 deletions element_array_ephys/ephys_chronic.py
@@ -242,6 +242,9 @@ def make(self, key):
raise FileNotFoundError(
'No Open Ephys data found for probe insertion: {}'.format(key))

if not probe_data.ap_meta:
raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory')

if probe_data.probe_model in supported_probe_types:
probe_type = probe_data.probe_model
electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type}
3 changes: 3 additions & 0 deletions element_array_ephys/ephys_no_curation.py
@@ -293,6 +293,9 @@ def make(self, key):
raise FileNotFoundError(
'No Open Ephys data found for probe insertion: {}'.format(key))

if not probe_data.ap_meta:
raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory')

if probe_data.probe_model in supported_probe_types:
probe_type = probe_data.probe_model
electrode_query = probe.ProbeType.Electrode & {'probe_type': probe_type}
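The same analog-signal guard is added identically in all three ephys modules above. Below is a minimal sketch of the idea (not the Element's own loader), assuming the Open Ephys binary-format layout in which "structure.oebin" lists each continuous (analog) data stream under a "continuous" key; the recording path is hypothetical.

import json
import pathlib

def has_analog_signals(recording_dir):
    # "structure.oebin" holds one entry per continuous (analog) stream, e.g. the AP and LFP bands
    oebin_fp = pathlib.Path(recording_dir) / 'structure.oebin'
    with open(oebin_fp) as f:
        oebin = json.load(f)
    return bool(oebin.get('continuous'))

if not has_analog_signals('/data/subject1/recording1'):  # hypothetical recording directory
    raise IOError('No analog signals found - check "structure.oebin" file or "continuous" directory')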
98 changes: 88 additions & 10 deletions element_array_ephys/readers/kilosort_triggering.py
@@ -7,7 +7,7 @@
import os
import scipy.io
import numpy as np
-from datetime import datetime
+from datetime import datetime, timedelta

from element_interface.utils import dict_to_uuid

@@ -191,14 +191,17 @@ def run_modules(self):
if module_status['completion_time'] is not None:
continue

-module_output_json = module_input_json.replace('-input.json',
-                                               '-' + module + '-output.json')
+module_output_json = self._get_module_output_json_filename(module)
command = (sys.executable
+ " -W ignore -m ecephys_spike_sorting.modules." + module
+ " --input_json " + module_input_json
+ " --output_json " + module_output_json)

start_time = datetime.utcnow()
self._update_module_status(
{module: {'start_time': start_time,
'completion_time': None,
'duration': None}})
with open(module_logfile, "a") as f:
subprocess.check_call(command.split(' '), stdout=f)
completion_time = datetime.utcnow()
@@ -207,6 +210,8 @@
'completion_time': completion_time,
'duration': (completion_time - start_time).total_seconds()}})

self._update_total_duration()

def _get_raw_data_filepaths(self):
session_str, gate_str, _, probe_str = self.parse_input_filename()

@@ -248,10 +253,44 @@ def _get_module_status(self, module):
if self._modules_input_hash_fp.exists():
with open(self._modules_input_hash_fp) as f:
modules_status = json.load(f)
if modules_status[module]['completion_time'] is None:
# Fallback: also read this module's "-output.json" file, to handle cases where
# the module finished successfully but "_modules_input_hash_fp" was not updated
# (for whatever reason), leaving the module unregistered as completed there.
module_output_json_fp = pathlib.Path(self._get_module_output_json_filename(module))
if module_output_json_fp.exists():
with open(module_output_json_fp) as f:
module_run_output = json.load(f)
modules_status[module]['duration'] = module_run_output['execution_time']
modules_status[module]['completion_time'] = (
datetime.strptime(modules_status[module]['start_time'], '%Y-%m-%d %H:%M:%S.%f')
+ timedelta(seconds=module_run_output['execution_time']))
return modules_status[module]

return {'start_time': None, 'completion_time': None, 'duration': None}
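A worked sketch of the fallback above, assuming (as the code does) that the module's "-output.json" written by ecephys_spike_sorting reports 'execution_time' in seconds; the timestamp and duration here are illustrative.

from datetime import datetime, timedelta

start_time_str = '2022-07-25 18:02:11.402133'  # 'start_time' recorded in the status file
execution_time = 2323.716087                   # seconds, read from the module's "-output.json"

completion_time = (datetime.strptime(start_time_str, '%Y-%m-%d %H:%M:%S.%f')
                   + timedelta(seconds=execution_time))
print(completion_time)  # 2022-07-25 18:40:55.118220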

def _get_module_output_json_filename(self, module):
module_input_json = self._module_input_json.as_posix()
module_output_json = module_input_json.replace(
'-input.json',
'-' + module + '-' + str(self._modules_input_hash) + '-output.json')
return module_output_json
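An illustration of the new naming scheme, with a made-up path, module name, and hash: the per-module output JSON now embeds both the module name and the hash of the modules' input parameters, so results from runs with different parameters do not collide.

module_input_json = '/scratch/sess1_g0_imec0-input.json'   # hypothetical path
modules_input_hash = '8e2d4f0c93f34e7b9a1db5c2a7f6e301'    # hypothetical uuid hex
module = 'kilosort_helper'

module_output_json = module_input_json.replace(
    '-input.json', '-' + module + '-' + modules_input_hash + '-output.json')
# -> '/scratch/sess1_g0_imec0-kilosort_helper-8e2d4f0c93f34e7b9a1db5c2a7f6e301-output.json'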

def _update_total_duration(self):
with open(self._modules_input_hash_fp) as f:
modules_status = json.load(f)
cumulative_execution_duration = sum(
v['duration'] or 0 for k, v in modules_status.items()
if k not in ('cumulative_execution_duration', 'total_duration'))
total_duration = (
datetime.strptime(modules_status[self._modules[-1]]['completion_time'], '%Y-%m-%d %H:%M:%S.%f')
- datetime.strptime(modules_status[self._modules[0]]['start_time'], '%Y-%m-%d %H:%M:%S.%f')
).total_seconds()
self._update_module_status(
{'cumulative_execution_duration': cumulative_execution_duration,
'total_duration': total_duration})
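For reference, an approximate shape of the hash-named status file that _update_module_status() maintains once run_modules() and _update_total_duration() have finished; the module names and timestamps below are illustrative only.

modules_status = {
    'kilosort_helper': {
        'start_time': '2022-07-25 18:02:11.402133',
        'completion_time': '2022-07-25 18:40:55.118220',
        'duration': 2323.716087},
    'kilosort_postprocessing': {
        'start_time': '2022-07-25 18:40:55.201554',
        'completion_time': '2022-07-25 18:52:03.977410',
        'duration': 668.775856},
    # ... one entry per configured module, plus the two summary fields added at the end:
    'cumulative_execution_duration': 2992.491943,  # sum of the per-module durations
    'total_duration': 2992.575277,                 # last completion_time minus first start_time
}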


class OpenEphysKilosortPipeline:
"""
@@ -353,22 +392,27 @@ def run_modules(self):
if module_status['completion_time'] is not None:
continue

-module_output_json = module_input_json.replace('-input.json',
-                                               '-' + module + '-output.json')
-command = (sys.executable
-           + " -W ignore -m ecephys_spike_sorting.modules." + module
-           + " --input_json " + module_input_json
-           + " --output_json " + module_output_json)
+module_output_json = self._get_module_output_json_filename(module)
+command = [sys.executable,
+           '-W', 'ignore', '-m', 'ecephys_spike_sorting.modules.' + module,
+           '--input_json', module_input_json,
+           '--output_json', module_output_json]

start_time = datetime.utcnow()
self._update_module_status(
{module: {'start_time': start_time,
'completion_time': None,
'duration': None}})
with open(module_logfile, "a") as f:
-subprocess.check_call(command.split(' '), stdout=f)
+subprocess.check_call(command, stdout=f)
completion_time = datetime.utcnow()
self._update_module_status(
{module: {'start_time': start_time,
'completion_time': completion_time,
'duration': (completion_time - start_time).total_seconds()}})

self._update_total_duration()
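The switch from a space-split command string to an argument list matters when any path contains a space; a small illustration with a hypothetical path:

import subprocess
import sys

module_input_json = '/data/My Recordings/sess1_g0_imec0-input.json'  # note the space
command = [sys.executable, '-W', 'ignore',
           '-m', 'ecephys_spike_sorting.modules.kilosort_helper',
           '--input_json', module_input_json,
           '--output_json', module_input_json.replace('-input.json', '-output.json')]

# subprocess.check_call(command)   # each list element arrives as a single argv entry
# ' '.join(command).split(' ')     # would break the path into two invalid arguments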

def _update_module_status(self, updated_module_status={}):
if self._modules_input_hash is None:
raise RuntimeError('"generate_modules_input_json()" not yet performed!')
@@ -393,10 +437,44 @@ def _get_module_status(self, module):
if self._modules_input_hash_fp.exists():
with open(self._modules_input_hash_fp) as f:
modules_status = json.load(f)
if modules_status[module]['completion_time'] is None:
# Fallback: also read this module's "-output.json" file, to handle cases where
# the module finished successfully but "_modules_input_hash_fp" was not updated
# (for whatever reason), leaving the module unregistered as completed there.
module_output_json_fp = pathlib.Path(self._get_module_output_json_filename(module))
if module_output_json_fp.exists():
with open(module_output_json_fp) as f:
module_run_output = json.load(f)
modules_status[module]['duration'] = module_run_output['execution_time']
modules_status[module]['completion_time'] = (
datetime.strptime(modules_status[module]['start_time'], '%Y-%m-%d %H:%M:%S.%f')
+ timedelta(seconds=module_run_output['execution_time']))
return modules_status[module]

return {'start_time': None, 'completion_time': None, 'duration': None}

def _get_module_output_json_filename(self, module):
module_input_json = self._module_input_json.as_posix()
module_output_json = module_input_json.replace(
'-input.json',
'-' + module + '-' + str(self._modules_input_hash) + '-output.json')
return module_output_json

def _update_total_duration(self):
with open(self._modules_input_hash_fp) as f:
modules_status = json.load(f)
cumulative_execution_duration = sum(
v['duration'] or 0 for k, v in modules_status.items()
if k not in ('cumulative_execution_duration', 'total_duration'))
total_duration = (
datetime.strptime(modules_status[self._modules[-1]]['completion_time'], '%Y-%m-%d %H:%M:%S.%f')
- datetime.strptime(modules_status[self._modules[0]]['start_time'], '%Y-%m-%d %H:%M:%S.%f')
).total_seconds()
self._update_module_status(
{'cumulative_execution_duration': cumulative_execution_duration,
'total_duration': total_duration})


def run_pykilosort(continuous_file, kilosort_output_directory, params,
channel_ind, x_coords, y_coords, shank_ind, connected, sample_rate):
2 changes: 1 addition & 1 deletion element_array_ephys/readers/openephys.py
@@ -145,7 +145,7 @@ def load_probe_data(self):
else:
continue # not continuous data for the current probe
else:
-raise ValueError(f'Unable to infer type (AP or LFP) for the continuous data from:\n\t{continuous_info}')
+raise ValueError(f'Unable to infer type (AP or LFP) for the continuous data from:\n\t{continuous_info["folder_name"]}')

if continuous_type == 'ap':
probe.recording_info['recording_count'] += 1
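A quick illustration of the narrower error message, using a hypothetical continuous-stream entry: reporting only the stream's folder name keeps the error readable instead of dumping the whole metadata dict.

continuous_info = {'folder_name': 'Neuropix-PXI-100.0/', 'sample_rate': 30000, 'num_channels': 384}
print(f'Unable to infer type (AP or LFP) for the continuous data from:\n\t{continuous_info["folder_name"]}')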
