From 2fd57b712a0ec4bdf6e2d680a243c4723ea64aaf Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Tue, 11 Jul 2023 13:07:50 +0000 Subject: [PATCH 01/84] Use spikeinterface functions for DeepInteprolation training and inference and cleanup full analysis --- scripts/run_full_analysis.py | 505 +++++++++++++++++++---------- scripts/utils.py | 181 ----------- src/deepinterpolation_recording.py | 150 +++++---- src/spikeinterface_generator.py | 117 +++---- 4 files changed, 496 insertions(+), 457 deletions(-) delete mode 100644 scripts/utils.py diff --git a/scripts/run_full_analysis.py b/scripts/run_full_analysis.py index ceaab85..905775c 100644 --- a/scripts/run_full_analysis.py +++ b/scripts/run_full_analysis.py @@ -1,9 +1,17 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + #### IMPORTS ####### import os +import sys import numpy as np from pathlib import Path -from numba import cuda +from numba import cuda import pandas as pd +import time # SpikeInterface @@ -15,13 +23,11 @@ import spikeinterface.comparison as sc import spikeinterface.qualitymetrics as sqm -from utils import train_di_model +base_path = Path("../../..") ##### DEFINE DATASETS AND FOLDERS ####### -DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" - sessions = [ "595262_2022-02-21_15-18-07_ProbeA", "602454_2022-03-22_16-30-03_ProbeB", @@ -32,39 +38,38 @@ "618384_2022-04-14_15-11-00_ProbeB", "621362_2022-07-14_11-19-36_ProbeA", ] -sessions = sessions[:1] n_jobs = 16 -data_folder = Path("../data") -results_folder = Path("../results") +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + -DEBUG = True +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 DEBUG_DURATION = 20 ##### DEFINE PARAMS ##### OVERWRITE = False +USE_GPU = True FULL_INFERENCE = True # Define training and testing constants (@Jad you can gradually increase this) -if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.5 - TESTING_START_S = 10 - TESTING_END_S = 10.05 -else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.5 - -FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" # DI params pre_frame = 30 post_frame = 30 pre_post_omission = 1 desired_shape = (192, 2) -inference_n_jobs = 1 # TODO: Jad - try more jobs +inference_n_jobs = 8 inference_chunk_duration = "50ms" di_kwargs = dict( @@ -72,157 +77,329 @@ post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, - inference_n_jobs=inference_n_jobs, - inference_chunk_duration=inference_chunk_duration ) -sorter_name = "kilosort2_5" +sorter_name = "pykilosort" singularity_image = False match_score = 0.7 -#### START #### -raw_data_folder = data_folder / "raw" -raw_data_folder.mkdir(exist_ok=True) - -session_level_results = pd.DataFrame(columns=['session', 'filter_option', 'di', "num_units", "sorting_path"]) - -unit_level_results = pd.DataFrame(columns=['session', 'filter_option', 'di', "unit_index", - "unit_id_no_di", "unit_id_di"]) -for session in sessions: - print(f"Analyzing session {session}") - - # download dataset - dst_folder = (raw_data_folder / session) - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) - - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) - if DEBUG: - recording = recording.frame_slice(start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency)) - - results_dict = {} - for filter_option in FILTER_OPTIONS: - print(f"\tFilter option: {filter_option}") - results_dict[filter_option] = {} - # train DI models - print(f"\t\tTraning DI") - training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) - testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) - model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" - recording_no_di, recording_di = train_di_model(recording, session, filter_option, - TRAINING_START_S, TRAINING_END_S, - TESTING_START_S, TESTING_END_S, - data_folder, FULL_INFERENCE, model_name, - di_kwargs, overwrite=OVERWRITE) - results_dict[filter_option]["recording_no_di"] = recording_no_di - results_dict[filter_option]["recording_di"] = recording_di - - # release GPU memory - device = cuda.get_current_device() - device.reset() - - # run spike sorting - sorting_output_folder = data_folder / "sortings" / session - sorting_output_folder.mkdir(parents=True, exist_ok=True) - - recording_no_di = results_dict[filter_option]["recording_no_di"] - if (sorting_output_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoading NO DI sorting") - sorting_no_di = si.load_extractor(sorting_output_folder / f"no_di_{model_name}") - else: - print(f"\t\tSpike sorting NO DI with {sorter_name}") - sorting_no_di = ss.run_sorter(sorter_name, recording=recording_no_di, - n_jobs=n_jobs, verbose=True, singularity_image=singularity_image) - sorting_no_di = sorting_no_di.save(folder=sorting_output_folder / f"no_di_{model_name}") - results_dict[filter_option]["sorting_no_di"] = sorting_no_di - - recording_di = results_dict[filter_option]["recording_di"] - if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoading DI sorting") - sorting_di = si.load_extractor(sorting_output_folder / f"di_{model_name}") +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True else: - print(f"\t\tSpike sorting DI with {sorter_name}") - sorting_di = ss.run_sorter(sorter_name, recording=recording_di, - n_jobs=n_jobs, verbose=True, singularity_image=singularity_image) - sorting_di = sorting_di.save(folder=sorting_output_folder / f"di_{model_name}") - results_dict[filter_option]["sorting_di"] = sorting_di - - # TODO: Jad - compute waveforms and quality metrics (https://spikeinterface.readthedocs.io/en/latest/how_to/get_started.html) - - ## add entries to session-level results - session_level_results.append({"session": session, "filter_option": filter_option, - "di": False, "num_units": len(sorting_no_di.unit_ids), - "sorting_path": str((sorting_output_folder / f"no_di_{model_name}").absolute())}, - ignore_index=True) - session_level_results.append({"session": session, "filter_option": filter_option, - "di": True, "num_units": len(sorting_di.unit_ids), - "sorting_path": str((sorting_output_folder / f"di_{model_name}").absolute())}, - ignore_index=True) - - # compare outputs - print("\t\tComparing sortings") - cmp = sc.compare_two_sorters(sorting1=sorting_no_di, sorting2=sorting_di, - sorting1_name="no_di", sorting2_name="di", - match_score=match_score) - matched_units = cmp.get_matching()[0] - matched_unit_ids_no_di = matched_units.index.values.astype(int) - matched_unit_ids_di = matched_units.values.astype(int) - matched_units_valid = matched_unit_ids_di != -1 - matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] - matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] - sorting_no_di_matched = sorting_no_di.select_units(unit_ids=matched_unit_ids_no_di) - sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) - - waveforms_folder = data_folder / "waveforms" / session - waveforms_folder.mkdir(exist_ok=True, parents=True) - - if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad NO DI waveforms") - we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") - else: - print("\t\tCompute NO DI waveforms") - we_no_di = si.extract_waveforms(recording_no_di, sorting_no_di_matched, - folder=waveforms_folder / f"no_di_{model_name}", - n_jobs=n_jobs, overwrite=True) - results_dict[filter_option]["we_no_di"] = we_no_di - - if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad DI waveforms") - we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") - else: - print("\t\tCompute DI waveforms") - we_di = si.extract_waveforms(recording_di, sorting_di_matched, - folder=waveforms_folder / f"di_{model_name}", - n_jobs=n_jobs, overwrite=True) - results_dict[filter_option]["we_di"] = we_di - - # compute metrics - if we_no_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad NO DI metrics") - qm_no_di = we_no_di.load_extension("quality_metrics").get_data() + DEBUG = False + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + sessions = sessions[:NUM_DEBUG_SESSIONS] + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + #### START #### + session_level_results = pd.DataFrame( + columns=[ + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] + ) + + unit_level_results_columns = [ + "session", + "probe", + "filter_option", + "unit_index", + "unit_id", + "unit_id_di", + ] + unit_level_results = None + + for session in sessions: + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + print(f"Analyzing session {session}") + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) else: - print("\t\tCompute NO DI metrics") - qm_no_di = sqm.compute_quality_metrics(we_no_di) - results_dict[filter_option]["qm_no_di"] = qm_no_di + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session - if we_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad DI metrics") - qm_di = we_di.load_extension("quality_metrics").get_data() + if "np1" in dst_folder.name: + probe = "NP1" else: - print("\t\tCompute DI metrics") - qm_di = sqm.compute_quality_metrics(we_di) - results_dict[filter_option]["qm_di"] = qm_di - - ## add entries to unit-level results - -results_folder.mkdir(exist_ok=True) -session_level_results.to_csv(results_folder / "session-results.csv") -unit_level_results.to_csv(results_folder / "unit-results.csv") + probe = "NP2" + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = results_folder / "models" / session / filter_option + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") + # full inference + output_folder = ( + results_folder / "deepinterpolated" / session / filter_option + ) + if OVERWRITE and output_folder.is_dir(): + shutil.rmtree(output_folder) + + if not output_folder.is_dir(): + t_start_inference = time.perf_counter() + output_folder.parent.mkdir(exist_ok=True, parents=True) + recording_di = spre.deepinterpolate( + recording_zscore, + model_path=model_path, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + use_gpu=USE_GPU, + ) + recording_di = recording_di.save( + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, + ) + t_stop_inference = time.perf_counter() + elapsed_time_inference = np.round( + t_stop_inference - t_start_inference, 2 + ) + print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") + else: + print("\t\tLoading existing folder") + recording_di = si.load_extractor(output_folder) + # apply inverse z-scoring + inverse_gains = 1 / recording_zscore.gain + inverse_offset = -recording_zscore.offset * inverse_gains + recording_di_inverse_zscore = spre.scale( + recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" + ) + + results_dict[filter_option]["recording_no_di"] = recording_processed + results_dict[filter_option]["recording_di"] = recording_di_inverse_zscore + + # run spike sorting + sorting_output_folder = ( + results_folder / "sortings" / session / filter_option + ) + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + recording_no_di = results_dict[filter_option]["recording_no_di"] + if ( + sorting_output_folder / f"no_di_{model_name}" + ).is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting_no_di = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + sorting_no_di = ss.run_sorter( + sorter_name, + recording=recording_no_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_no_di = sorting_no_di.save( + folder=sorting_output_folder / "sorting" + ) + results_dict[filter_option]["sorting_no_di"] = sorting_no_di + + recording_di = results_dict[filter_option]["recording_di"] + if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = sorting_di.save( + folder=sorting_output_folder / "sorting_di" + ) + results_dict[filter_option]["sorting_di"] = sorting_di + + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting_no_di, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, + ) + matched_units = comp.get_matching()[0] + matched_unit_ids_no_di = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_no_di_matched = sorting_no_di.select_units( + unit_ids=matched_unit_ids_no_di + ) + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + + ## add entries to session-level results + new_row = { + "session": session, + "filter_option": filter_option, + "probe": probe, + "num_units": len(sorting_no_di.unit_ids), + "num_units_di": len(sorting_di.unit_ids), + "num_match": len(sorting_no_di_matched.unit_ids), + "sorting_path": str( + (sorting_output_folder / "sorting").relative_to(results_folder) + ), + "sorting_path_di": str( + (sorting_output_folder / "sorting_di_{model_name}").relative_to( + results_folder + ) + ), + } + session_level_results = pd.concat( + [session_level_results, pd.DataFrame([new_row])], ignore_index=True + ) + + print( + f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" + ) + + # waveforms + waveforms_folder = results_folder / "waveforms" / session / filter_option + waveforms_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms") + we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") + else: + print("\t\tCompute NO DI waveforms") + we_no_di = si.extract_waveforms( + recording_no_di, + sorting_no_di_matched, + folder=waveforms_folder / f"no_di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_no_di"] = we_no_di + + if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms") + we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") + else: + print("\t\tCompute DI waveforms") + we_di = si.extract_waveforms( + recording_di, + sorting_di_matched, + folder=waveforms_folder / f"di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_di"] = we_di + + # compute metrics + if we_no_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad NO DI metrics") + qm_no_di = we_no_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute NO DI metrics") + qm_no_di = sqm.compute_quality_metrics(we_no_di) + results_dict[filter_option]["qm_no_di"] = qm_no_di + + if we_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad DI metrics") + qm_di = we_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute DI metrics") + qm_di = sqm.compute_quality_metrics(we_di) + results_dict[filter_option]["qm_di"] = qm_di + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm_no_di.columns: + unit_level_results_columns.append(metric) + unit_level_results_columns.append(f"{metric}_di") + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "session": [session] * len(qm_no_di), + "probe": [probe] * len(qm_no_di), + "filter_option": [filter_option] * len(qm_no_di), + "unit_id": we_no_di.unit_ids, + "unit_id_di": we_di.unit_ids, + } + for metric in qm_no_di.columns: + new_rows[metric] = qm_no_di[metric].values + new_rows[f"{metric}_di"] = qm_di[metric].values + # append new entries + unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True + ) + + results_folder.mkdir(exist_ok=True) + session_level_results.to_csv(results_folder / "session-results.csv") + unit_level_results.to_csv(results_folder / "unit-results.csv") diff --git a/scripts/utils.py b/scripts/utils.py deleted file mode 100644 index 4e97a04..0000000 --- a/scripts/utils.py +++ /dev/null @@ -1,181 +0,0 @@ - -import json -import sys -import shutil - -import spikeinterface as si -import spikeinterface.preprocessing as spre - -# DeepInterpolation -from deepinterpolation.trainor_collection import core_trainer -from deepinterpolation.network_collection import unet_single_ephys_1024 -from deepinterpolation.generic import ClassLoader - -# Import local classes for DI+SI -sys.path.append("../src") - -# the generator is in the "spikeinterface_generator.py" -from spikeinterface_generator import SpikeInterfaceGenerator -from deepinterpolation_recording import DeepInterpolatedRecording - - - - -def train_di_model(recording, session, filter_option, train_start_s, train_end_s, - test_start_s, test_end_s, data_folder, full_inference, model_name, di_kwargs, - overwrite=False, use_gpu=True): - """_summary_ - - Parameters - ---------- - recording : _type_ - _description_ - session : _type_ - _description_ - filter_option : _type_ - _description_ - train_start_s : _type_ - _description_ - train_end_s : _type_ - _description_ - test_start_s : _type_ - _description_ - test_end_s : _type_ - _description_ - data_folder : _type_ - _description_ - full_inference : _type_ - _description_ - di_kwargs : _type_ - _description_ - """ - model_folder = data_folder / "models" / session - model_folder.mkdir(exist_ok=True, parents=True) - trained_model_folder = model_folder / model_name - - pre_frame = di_kwargs["pre_frame"] - post_frame = di_kwargs["post_frame"] - pre_post_omission = di_kwargs["pre_post_omission"] - desired_shape = di_kwargs["desired_shape"] - inference_n_jobs = di_kwargs["inference_n_jobs"] - inference_chunk_duration = di_kwargs["inference_chunk_duration"] - - # pre-process - assert filter_option in ("no", "hp", "bp"), "Wrong filter option!" - if filter_option == "hp": - rec_f = spre.highpass_filter(recording) - elif filter_option == "bp": - rec_f = spre.bandpass_filter(recording) - else: - rec_f = recording - - rec_processed = rec_f - rec_norm = spre.zscore(rec_processed) - - ### Define params - start_frame_training = int(train_start_s * rec_norm.sampling_frequency) - end_frame_training = int(train_end_s * rec_norm.sampling_frequency) - start_frame_test = int(test_start_s * rec_norm.sampling_frequency) - end_frame_test = int(test_end_s * rec_norm.sampling_frequency) - - # Those are parameters used for the network topology - network_params = dict() - network_params["type"] = "network" - # Name of network topology in the collection - network_params["name"] = "unet_single_ephys_1024" - training_params = dict() - training_params["output_dir"] = str(trained_model_folder) - # We pass on the uid - training_params["run_uid"] = "first_test" - - # We convert to old schema - training_params["nb_gpus"] = 1 - training_params["type"] = "trainer" - training_params["steps_per_epoch"] = 10 - training_params["period_save"] = 100 - training_params["apply_learning_decay"] = 0 - training_params["nb_times_through_data"] = 1 - training_params["learning_rate"] = 0.0001 - training_params["pre_post_frame"] = 1 - training_params["loss"] = "mean_absolute_error" - training_params["nb_workers"] = 2 - training_params["caching_validation"] = False - training_params["model_string"] = f"{network_params['name']}_{training_params['loss']}" - - - if not trained_model_folder.is_dir() or overwrite: - trained_model_folder.mkdir(exist_ok=True) - - # Training (from core_trainor class) - training_data_generator = SpikeInterfaceGenerator(rec_norm, zscore=False, - pre_frame=pre_frame, post_frame=post_frame, - pre_post_omission=pre_post_omission, - start_frame=start_frame_training, - end_frame=end_frame_training, - desired_shape=desired_shape) - test_data_generator = SpikeInterfaceGenerator(rec_norm, zscore=False, - pre_frame=pre_frame, post_frame=post_frame, - pre_post_omission=pre_post_omission, - start_frame=start_frame_test, - end_frame=end_frame_test, - steps_per_epoch=-1, - desired_shape=desired_shape) - - - network_json_path = trained_model_folder / "network_params.json" - with open(network_json_path, "w") as f: - json.dump(network_params, f) - - network_obj = ClassLoader(network_json_path) - data_network = network_obj.find_and_build()(network_json_path) - - training_json_path = trained_model_folder / "training_params.json" - with open(training_json_path, "w") as f: - json.dump(training_params, f) - - - training_class = core_trainer( - training_data_generator, test_data_generator, data_network, - training_json_path - ) - - print("created objects for training") - training_class.run() - - print("training job finished - finalizing output model") - training_class.finalize() - else: - print("Loading pre-trained model") - - ### Test inference - - # Re-load model from output folder - model_path = trained_model_folder / f"{training_params['run_uid']}_{training_params['model_string']}_model.h5" - - rec_di = DeepInterpolatedRecording(rec_norm, model_path=model_path, pre_frames=pre_frame, - post_frames=post_frame, pre_post_omission=pre_post_omission, - disable_tf_logger=True, - use_gpu=use_gpu) - - if full_inference: - deepinterpolated_folder = data_folder / "deepinterpolated" / session - deepinterpolated_folder.mkdir(exist_ok=True, parents=True) - output_folder = deepinterpolated_folder / model_name - if output_folder.is_dir() and overwrite: - shutil.rmtree(output_folder) - rec_di = DeepInterpolatedRecording(rec_norm, model_path=model_path, pre_frames=pre_frame, - post_frames=post_frame, pre_post_omission=pre_post_omission, - disable_tf_logger=True, - use_gpu=use_gpu) - rec_di_saved = rec_di.save(folder=output_folder, n_jobs=inference_n_jobs, - chunk_duration=inference_chunk_duration) - else: - rec_di_saved = si.load_extractor(output_folder) - else: - rec_di_saved = rec_di - - # apply inverse z-scoring - inverse_gains = 1 / rec_norm.gain - inverse_offset = - rec_norm.offset * inverse_gains - rec_di_inverse_zscore = spre.scale(rec_di_saved, gain=inverse_gains, offset=inverse_offset, dtype='float') - return rec_processed, rec_di_inverse_zscore diff --git a/src/deepinterpolation_recording.py b/src/deepinterpolation_recording.py index 3d85158..8657bea 100644 --- a/src/deepinterpolation_recording.py +++ b/src/deepinterpolation_recording.py @@ -21,19 +21,20 @@ def has_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): return True except ImportError: return False - + + def import_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): import tensorflow as tf if not use_gpu: - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if disable_tf_logger: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - tf.get_logger().setLevel('ERROR') + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + tf.get_logger().setLevel("ERROR") tf.compat.v1.disable_eager_execution() - gpus = tf.config.list_physical_devices('GPU') + gpus = tf.config.list_physical_devices("GPU") if gpus: if memory_gpu is None: try: @@ -47,24 +48,32 @@ def import_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): else: for gpu in gpus: tf.config.set_logical_device_configuration( - gpus[0], - [tf.config.LogicalDeviceConfiguration(memory_limit=memory_gpu)]) + gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=memory_gpu)] + ) return tf - class DeepInterpolatedRecording(BasePreprocessor): - name = 'deepinterpolate' - - def __init__(self, recording, model_path: str, - pre_frames: int = 30, post_frames: int = 30, pre_post_omission: int = 1, - batch_size=128, use_gpu: bool = True, disable_tf_logger: bool = True, - memory_gpu=None): + name = "deepinterpolate" + + def __init__( + self, + recording, + model_path: str, + pre_frames: int = 30, + post_frames: int = 30, + pre_post_omission: int = 1, + batch_size=128, + use_gpu: bool = True, + disable_tf_logger: bool = True, + memory_gpu=None, + ): assert has_tf( - use_gpu, disable_tf_logger, memory_gpu), "To use DeepInterpolation, you first need to install `tensorflow`." - + use_gpu, disable_tf_logger, memory_gpu + ), "To use DeepInterpolation, you first need to install `tensorflow`." + self.tf = import_tf(use_gpu, disable_tf_logger, memory_gpu=memory_gpu) - + # try move model load here with spawn BasePreprocessor.__init__(self, recording) @@ -75,46 +84,71 @@ def __init__(self, recording, model_path: str, # check shape (this will need to be done at inference) network_input_shape = model.get_config()["layers"][0]["config"]["batch_input_shape"] desired_shape = network_input_shape[1:3] - assert desired_shape[0]*desired_shape[1] == recording.get_num_channels(), "text" + assert desired_shape[0] * desired_shape[1] == recording.get_num_channels(), "text" assert network_input_shape[-1] == pre_frames + post_frames - self.model = model # add segment for segment in recording._recording_segments: - recording_segment = DeepInterpolatedRecordingSegment(segment, self.model, - pre_frames, post_frames, pre_post_omission, - desired_shape, batch_size, use_gpu, - disable_tf_logger, memory_gpu) + recording_segment = DeepInterpolatedRecordingSegment( + segment, + self.model, + pre_frames, + post_frames, + pre_post_omission, + desired_shape, + batch_size, + use_gpu, + disable_tf_logger, + memory_gpu, + ) self.add_recording_segment(recording_segment) self._preferred_mp_context = "spawn" - self._kwargs = dict(recording=recording.to_dict(), model_path=str(model_path), - pre_frames=pre_frames, post_frames=post_frames, pre_post_omission=pre_post_omission, - batch_size=batch_size, use_gpu=use_gpu, disable_tf_logger=disable_tf_logger, - memory_gpu=memory_gpu) - self.extra_requirements.extend(['tensorflow']) + self._kwargs = dict( + recording=recording.to_dict(), + model_path=str(model_path), + pre_frames=pre_frames, + post_frames=post_frames, + pre_post_omission=pre_post_omission, + batch_size=batch_size, + use_gpu=use_gpu, + disable_tf_logger=disable_tf_logger, + memory_gpu=memory_gpu, + ) + self.extra_requirements.extend(["tensorflow"]) class DeepInterpolatedRecordingSegment(BasePreprocessorSegment): - - def __init__(self, recording_segment, model, - pre_frames, post_frames, pre_post_omission, - desired_shape, batch_size, use_gpu, disable_tf_logger, memory_gpu): + def __init__( + self, + recording_segment, + model, + pre_frames, + post_frames, + pre_post_omission, + desired_shape, + batch_size, + use_gpu, + disable_tf_logger, + memory_gpu, + ): from spikeinterface_generator import SpikeInterfaceRecordingSegmentGenerator BasePreprocessorSegment.__init__(self, recording_segment) - + self.model = model self.pre_frames = pre_frames self.post_frames = post_frames self.pre_post_omission = pre_post_omission self.batch_size = batch_size self.use_gpu = use_gpu - self.desired_shape=desired_shape + self.desired_shape = desired_shape # creating class dynamically to use the imported TF with GPU enabled/disabled based on the use_gpu flag - self.SpikeInterfaceGenerator = SpikeInterfaceRecordingSegmentGenerator #define_input_generator_class( use_gpu, disable_tf_logger) + self.SpikeInterfaceGenerator = ( + SpikeInterfaceRecordingSegmentGenerator # define_input_generator_class( use_gpu, disable_tf_logger) + ) def get_traces(self, start_frame, end_frame, channel_indices): n_frames = self.parent_recording_segment.get_num_samples() @@ -127,42 +161,44 @@ def get_traces(self, start_frame, end_frame, channel_indices): # for frames that lack full training data (i.e. pre and post frames including omissinos), # just return uninterpolated - if start_frame < self.pre_frames+self.pre_post_omission: - true_start_frame = self.pre_frames+self.pre_post_omission - array_to_append_front = self.parent_recording_segment.get_traces(start_frame=0, - end_frame=true_start_frame, - channel_indices=channel_indices) + if start_frame < self.pre_frames + self.pre_post_omission: + true_start_frame = self.pre_frames + self.pre_post_omission + array_to_append_front = self.parent_recording_segment.get_traces( + start_frame=0, end_frame=true_start_frame, channel_indices=channel_indices + ) else: true_start_frame = start_frame - if end_frame > n_frames-self.post_frames-self.pre_post_omission: - true_end_frame = n_frames-self.post_frames-self.pre_post_omission - array_to_append_back = self.parent_recording_segment.get_traces(start_frame=true_end_frame, - end_frame=n_frames, - channel_indices=channel_indices) + if end_frame > n_frames - self.post_frames - self.pre_post_omission: + true_end_frame = n_frames - self.post_frames - self.pre_post_omission + array_to_append_back = self.parent_recording_segment.get_traces( + start_frame=true_end_frame, end_frame=n_frames, channel_indices=channel_indices + ) else: true_end_frame = end_frame # instantiate an input generator that can be passed directly to model.predict - input_generator = self.SpikeInterfaceGenerator(recording_segment=self.parent_recording_segment, - start_frame=true_start_frame, - end_frame=true_end_frame, - pre_frame=self.pre_frames, - post_frame=self.post_frames, - pre_post_omission=self.pre_post_omission, - batch_size=self.batch_size) + input_generator = self.SpikeInterfaceGenerator( + recording_segment=self.parent_recording_segment, + start_frame=true_start_frame, + end_frame=true_end_frame, + pre_frame=self.pre_frames, + post_frame=self.post_frames, + pre_post_omission=self.pre_post_omission, + batch_size=self.batch_size, + ) input_generator.randomize = False input_generator._calculate_list_samples(input_generator.total_samples) di_output = self.model.predict(input_generator, verbose=2) out_traces = input_generator.reshape_output(di_output) - if true_start_frame != start_frame: # related to the restriction to be applied from the start and end frames around 0 and end - out_traces = np.concatenate( - (array_to_append_front, out_traces), axis=0) + if ( + true_start_frame != start_frame + ): # related to the restriction to be applied from the start and end frames around 0 and end + out_traces = np.concatenate((array_to_append_front, out_traces), axis=0) if true_end_frame != end_frame: - out_traces = np.concatenate( - (out_traces, array_to_append_back), axis=0) + out_traces = np.concatenate((out_traces, array_to_append_back), axis=0) return out_traces[:, channel_indices] diff --git a/src/spikeinterface_generator.py b/src/spikeinterface_generator.py index 48be09b..126e2e8 100644 --- a/src/spikeinterface_generator.py +++ b/src/spikeinterface_generator.py @@ -7,16 +7,28 @@ from deepinterpolation.generator_collection import SequentialGenerator + # TODO: rename to SpikeInterfaceRecordingGenerator class SpikeInterfaceGenerator(SequentialGenerator): """This generator is used when dealing with a SpikeInterface recording. The desired shape controls the reshaping of the input data before convolutions.""" - def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, desired_shape=(192, 2), - batch_size=100, steps_per_epoch=10, zscore=True, start_frame=None, end_frame=None): + def __init__( + self, + recording, + pre_frame=30, + post_frame=30, + pre_post_omission=1, + desired_shape=(192, 2), + batch_size=100, + steps_per_epoch=10, + zscore=True, + start_frame=None, + end_frame=None, + ): "Initialization" assert recording.get_num_segments() == 1, "Only supported for mon-segment recordings" - + if zscore: recording_z = spre.zscore(recording) else: @@ -25,15 +37,16 @@ def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, self.recording = recording_z self.total_samples = recording.get_num_samples() assert len(desired_shape) == 2, "desired_shape should be 2D" - assert desired_shape[0] * desired_shape [1] == recording.get_num_channels(), \ - f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}" + assert ( + desired_shape[0] * desired_shape[1] == recording.get_num_channels() + ), f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}" self.desired_shape = desired_shape - + start_frame = start_frame if start_frame is not None else 0 end_frame = end_frame if end_frame is not None else self.total_samples - + assert end_frame > start_frame, "end_frame must be greater than start_frame" - + sequential_generator_params = dict() sequential_generator_params["steps_per_epoch"] = steps_per_epoch sequential_generator_params["pre_frame"] = pre_frame @@ -53,44 +66,39 @@ def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, self._calculate_list_samples(self.total_samples) self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) - def __len__(self): "Denotes the total number of batches" - return int(len(self.list_samples) / self.batch_size) + 1 + return int(np.floor(len(self.list_samples) / self.batch_size) + 1) def generate_batch_indexes(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch 0: index = index + self.steps_per_epoch * self.epoch_index # Generate indexes of the batch - indexes = np.arange(index * self.batch_size, - (index + 1) * self.batch_size) - + indexes = np.arange(index * self.batch_size, (index + 1) * self.batch_size) + if max(indexes) > len(self.list_samples): + print(index, min(indexes), max(indexes), len(self.list_samples)) shuffle_indexes = self.list_samples[indexes] elif index == len(self) - 1: - shuffle_indexes = self.list_samples[-self.last_batch_size:] + shuffle_indexes = self.list_samples[-self.last_batch_size :] else: raise Exception(f"Exceeding number of available batches: {len(self)}") return shuffle_indexes - def __getitem__(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch start_frame, "end_frame must be greater than start_frame" - + sequential_generator_params = dict() sequential_generator_params["steps_per_epoch"] = steps_per_epoch sequential_generator_params["pre_frame"] = pre_frame @@ -177,7 +192,6 @@ def __init__(self, recording_segment, start_frame, end_frame, pre_frame=30, post self._calculate_list_samples(self.total_samples) self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) - def __len__(self): "Denotes the total number of batches" return int(len(self.list_samples) / self.batch_size) + 1 @@ -185,36 +199,31 @@ def __len__(self): def generate_batch_indexes(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch 0: index = index + self.steps_per_epoch * self.epoch_index # Generate indexes of the batch - indexes = np.arange(index * self.batch_size, - (index + 1) * self.batch_size) + indexes = np.arange(index * self.batch_size, (index + 1) * self.batch_size) shuffle_indexes = self.list_samples[indexes] elif index == len(self) - 1: - shuffle_indexes = self.list_samples[-self.last_batch_size:] + shuffle_indexes = self.list_samples[-self.last_batch_size :] else: raise Exception(f"Exceeding number of available batches: {len(self)}") return shuffle_indexes - def __getitem__(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch Date: Tue, 11 Jul 2023 13:17:44 +0000 Subject: [PATCH 02/84] Add minimal requirements --- generate_image_inference.py | 28 ---------------------------- requirements.txt | 6 ++++++ 2 files changed, 6 insertions(+), 28 deletions(-) delete mode 100644 generate_image_inference.py create mode 100644 requirements.txt diff --git a/generate_image_inference.py b/generate_image_inference.py deleted file mode 100644 index 2ba171a..0000000 --- a/generate_image_inference.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import matplotlib -import matplotlib.pyplot as plt -from scipy import ndimage -import scipy -import h5py -import numpy as np -os.listdir('.') -fn= 'ephys_tiny_continuous_deep_interpolation.h5' -f= h5py.File(fn) -tsurf = f['data'] -tsurf -di_traces = tsurf[:,:,0] -di_traces.shape -plot= plt.imshow(di_traces.T, - origin='lower', - vmin=-50, - vmax=50, - cmap='RdGy', - aspect='auto') -plt.xlabel('Sample Index') -plt.ylabel('Acquisition Channels') -#plt.ylim(0,386) -#plt.xlim(0,400) -#matplotlib.pyplot.annotate('alpha', xy= (1,1), xytext=(3,3), xycoords='data', textcoords=None, arrowprops=None, annotation_clip=None) -matplotlib.pyplot.colorbar(label= 'μV', shrink=0.25) -plt.title("1000samples_28Epoch") -matplotlib.pyplot.savefig('1000samples_28Epoch') diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4a63580 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy==1.24 +protobuf==3.20.* +tensorflow==2.7.0 +pip install git+https://github.com/AllenInstitute/deepinterpolation.git +# we need a specific branch for SpikeInterface +pip install git+https://github.com/alejoe91/spikeinterface.git@deepinterp#egg=spikeinterface \ No newline at end of file From 38e4128c77bf97e1f559fbf80fe9b2d1ea5ad407 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Fri, 14 Jul 2023 18:05:03 +0000 Subject: [PATCH 03/84] Add separate scripts --- scripts/run_full_analysis.py | 27 ++- scripts/run_inference.py | 409 +++++++++++++++++++++++++++++++++++ scripts/run_training.py | 188 ++++++++++++++++ 3 files changed, 617 insertions(+), 7 deletions(-) create mode 100644 scripts/run_inference.py create mode 100644 scripts/run_training.py diff --git a/scripts/run_full_analysis.py b/scripts/run_full_analysis.py index 905775c..37d7ab2 100644 --- a/scripts/run_full_analysis.py +++ b/scripts/run_full_analysis.py @@ -23,6 +23,9 @@ import spikeinterface.comparison as sc import spikeinterface.qualitymetrics as sqm +# Tensorflow +import tensorflow as tf + base_path = Path("../../..") @@ -69,8 +72,10 @@ post_frame = 30 pre_post_omission = 1 desired_shape = (192, 2) -inference_n_jobs = 8 -inference_chunk_duration = "50ms" +# play around with these +inference_n_jobs = 4 +inference_chunk_duration = "100ms" +inference_memory_gpu = 2000 #MB di_kwargs = dict( pre_frame=pre_frame, @@ -86,17 +91,21 @@ if __name__ == "__main__": - if len(sys.argv) == 2: + if len(sys.argv) == 3: if sys.argv[1] == "true": DEBUG = True else: DEBUG = False + if sys.argv[2] != "all": + sessions = [sys.argv[2]] + if DEBUG: TRAINING_START_S = 0 TRAINING_END_S = 0.2 TESTING_START_S = 10 TESTING_END_S = 10.05 - sessions = sessions[:NUM_DEBUG_SESSIONS] + if len(sessions) > NUM_DEBUG_SESSIONS: + sessions = sessions[:NUM_DEBUG_SESSIONS] OVERWRITE = True else: TRAINING_START_S = 0 @@ -107,6 +116,8 @@ si.set_global_job_kwargs(**job_kwargs) + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + #### START #### session_level_results = pd.DataFrame( columns=[ @@ -125,17 +136,18 @@ "session", "probe", "filter_option", - "unit_index", "unit_id", "unit_id_di", ] unit_level_results = None + sessions = sessions[:2] + for session in sessions: + print(f"\nAnalyzing session {session}\n") if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) - print(f"Analyzing session {session}") # download dataset dst_folder.mkdir(exist_ok=True) @@ -215,7 +227,8 @@ pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, - use_gpu=USE_GPU, + memory_gpu=inference_memory_gpu, + use_gpu=USE_GPU ) recording_di = recording_di.save( folder=output_folder, diff --git a/scripts/run_inference.py b/scripts/run_inference.py new file mode 100644 index 0000000..1bc182e --- /dev/null +++ b/scripts/run_inference.py @@ -0,0 +1,409 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import numpy as np +from pathlib import Path +from numba import cuda +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +base_path = Path("../../..") + +##### DEFINE DATASETS AND FOLDERS ####### + +sessions = [ + "595262_2022-02-21_15-18-07_ProbeA", + "602454_2022-03-22_16-30-03_ProbeB", + "612962_2022-04-13_19-18-04_ProbeB", + "612962_2022-04-14_17-17-10_ProbeC", + "618197_2022-06-21_14-08-06_ProbeC", + "618318_2022-04-13_14-59-07_ProbeB", + "618384_2022-04-14_15-11-00_ProbeB", + "621362_2022-07-14_11-19-36_ProbeA", +] +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +FULL_INFERENCE = True + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) +# play around with these +inference_n_jobs = 4 +inference_chunk_duration = "100ms" +inference_memory_gpu = 2000 #MB + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + +sorter_name = "pykilosort" +singularity_image = False + +match_score = 0.7 + + +if __name__ == "__main__": + if len(sys.argv) == 3: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + if sys.argv[2] != "all": + sessions = [sys.argv[2]] + + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + if len(sessions) > NUM_DEBUG_SESSIONS: + sessions = sessions[:NUM_DEBUG_SESSIONS] + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + #### START #### + session_level_results = pd.DataFrame( + columns=[ + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] + ) + + unit_level_results_columns = [ + "session", + "probe", + "filter_option", + "unit_id", + "unit_id_di", + ] + unit_level_results = None + + if (data_folder / "models").is_dir(): + data_model_folder = data_folder + else: + data_subfolders = [p for p in data_folder.iterdir()] + assert len(data_subfolders) == 1 + data_model_folder = data_subfolders[0] + + for session in sessions: + print(f"\nAnalyzing session {session}\n") + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + dst_folder = raw_data_folder / session + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync --no-sign-request {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) + else: + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session + + if "np1" in dst_folder.name: + probe = "NP1" + else: + probe = "NP2" + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = data_model_folder / "models" / session / filter_option + model_path = [p for p in model_folder if p.name.endswith("model.h5") and filter_option in p.name][0] + # full inference + output_folder = ( + results_folder / "deepinterpolated" / session / filter_option + ) + if OVERWRITE and output_folder.is_dir(): + shutil.rmtree(output_folder) + + if not output_folder.is_dir(): + t_start_inference = time.perf_counter() + output_folder.parent.mkdir(exist_ok=True, parents=True) + recording_di = spre.deepinterpolate( + recording_zscore, + model_path=model_path, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + memory_gpu=inference_memory_gpu, + use_gpu=USE_GPU + ) + recording_di = recording_di.save( + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, + ) + t_stop_inference = time.perf_counter() + elapsed_time_inference = np.round( + t_stop_inference - t_start_inference, 2 + ) + print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") + else: + print("\t\tLoading existing folder") + recording_di = si.load_extractor(output_folder) + # apply inverse z-scoring + inverse_gains = 1 / recording_zscore.gain + inverse_offset = -recording_zscore.offset * inverse_gains + recording_di_inverse_zscore = spre.scale( + recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" + ) + + results_dict[filter_option]["recording_no_di"] = recording_processed + results_dict[filter_option]["recording_di"] = recording_di_inverse_zscore + + # run spike sorting + sorting_output_folder = ( + results_folder / "sortings" / session / filter_option + ) + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + recording_no_di = results_dict[filter_option]["recording_no_di"] + if ( + sorting_output_folder / f"no_di_{model_name}" + ).is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting_no_di = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + sorting_no_di = ss.run_sorter( + sorter_name, + recording=recording_no_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_no_di = sorting_no_di.save( + folder=sorting_output_folder / "sorting" + ) + results_dict[filter_option]["sorting_no_di"] = sorting_no_di + + recording_di = results_dict[filter_option]["recording_di"] + if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = sorting_di.save( + folder=sorting_output_folder / "sorting_di" + ) + results_dict[filter_option]["sorting_di"] = sorting_di + + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting_no_di, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, + ) + matched_units = comp.get_matching()[0] + matched_unit_ids_no_di = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_no_di_matched = sorting_no_di.select_units( + unit_ids=matched_unit_ids_no_di + ) + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + + ## add entries to session-level results + new_row = { + "session": session, + "filter_option": filter_option, + "probe": probe, + "num_units": len(sorting_no_di.unit_ids), + "num_units_di": len(sorting_di.unit_ids), + "num_match": len(sorting_no_di_matched.unit_ids), + "sorting_path": str( + (sorting_output_folder / "sorting").relative_to(results_folder) + ), + "sorting_path_di": str( + (sorting_output_folder / "sorting_di_{model_name}").relative_to( + results_folder + ) + ), + } + session_level_results = pd.concat( + [session_level_results, pd.DataFrame([new_row])], ignore_index=True + ) + + print( + f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" + ) + + # waveforms + waveforms_folder = results_folder / "waveforms" / session / filter_option + waveforms_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms") + we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") + else: + print("\t\tCompute NO DI waveforms") + we_no_di = si.extract_waveforms( + recording_no_di, + sorting_no_di_matched, + folder=waveforms_folder / f"no_di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_no_di"] = we_no_di + + if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms") + we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") + else: + print("\t\tCompute DI waveforms") + we_di = si.extract_waveforms( + recording_di, + sorting_di_matched, + folder=waveforms_folder / f"di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_di"] = we_di + + # compute metrics + if we_no_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad NO DI metrics") + qm_no_di = we_no_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute NO DI metrics") + qm_no_di = sqm.compute_quality_metrics(we_no_di) + results_dict[filter_option]["qm_no_di"] = qm_no_di + + if we_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad DI metrics") + qm_di = we_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute DI metrics") + qm_di = sqm.compute_quality_metrics(we_di) + results_dict[filter_option]["qm_di"] = qm_di + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm_no_di.columns: + unit_level_results_columns.append(metric) + unit_level_results_columns.append(f"{metric}_di") + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "session": [session] * len(qm_no_di), + "probe": [probe] * len(qm_no_di), + "filter_option": [filter_option] * len(qm_no_di), + "unit_id": we_no_di.unit_ids, + "unit_id_di": we_di.unit_ids, + } + for metric in qm_no_di.columns: + new_rows[metric] = qm_no_di[metric].values + new_rows[f"{metric}_di"] = qm_di[metric].values + # append new entries + unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True + ) + + results_folder.mkdir(exist_ok=True) + session_level_results.to_csv(results_folder / "session-results.csv") + unit_level_results.to_csv(results_folder / "unit-results.csv") diff --git a/scripts/run_training.py b/scripts/run_training.py new file mode 100644 index 0000000..15f5275 --- /dev/null +++ b/scripts/run_training.py @@ -0,0 +1,188 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import numpy as np +from pathlib import Path +from numba import cuda +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +base_path = Path("../../..") + +##### DEFINE DATASETS AND FOLDERS ####### + +sessions = [ + "595262_2022-02-21_15-18-07_ProbeA", + "602454_2022-03-22_16-30-03_ProbeB", + "612962_2022-04-13_19-18-04_ProbeB", + "612962_2022-04-14_17-17-10_ProbeC", + "618197_2022-06-21_14-08-06_ProbeC", + "618318_2022-04-13_14-59-07_ProbeB", + "618384_2022-04-14_15-11-00_ProbeB", + "621362_2022-07-14_11-19-36_ProbeA", +] +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +FULL_INFERENCE = True + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) +# play around with these +inference_n_jobs = 4 +inference_chunk_duration = "100ms" +inference_memory_gpu = 2000 #MB + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + +sorter_name = "pykilosort" +singularity_image = False + +match_score = 0.7 + + +if __name__ == "__main__": + if len(sys.argv) == 3: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + if sys.argv[2] != "all": + sessions = [sys.argv[2]] + + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + if len(sessions) > NUM_DEBUG_SESSIONS: + sessions = sessions[:NUM_DEBUG_SESSIONS] + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + for session in sessions: + print(f"\nAnalyzing session {session}\n") + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) + else: + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session + + if "np1" in dst_folder.name: + probe = "NP1" + else: + probe = "NP2" + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = results_folder / "models" / session / filter_option + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") \ No newline at end of file From 976a73ccb5d16fd8fba25782fadd3852ac9bb633 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Tue, 18 Jul 2023 09:06:03 +0000 Subject: [PATCH 04/84] Create pipeline folder --- pipeline/run_inference.py | 235 ++++++++++++++++++ .../run_spike_sorting.py | 29 ++- pipeline/run_training.py | 192 ++++++++++++++ pipeline/sessions.py | 26 ++ scripts/run_training.py | 188 -------------- 5 files changed, 472 insertions(+), 198 deletions(-) create mode 100644 pipeline/run_inference.py rename scripts/run_inference.py => pipeline/run_spike_sorting.py (94%) create mode 100644 pipeline/run_training.py create mode 100644 pipeline/sessions.py delete mode 100644 scripts/run_training.py diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py new file mode 100644 index 0000000..c56318d --- /dev/null +++ b/pipeline/run_inference.py @@ -0,0 +1,235 @@ +import warnings + + + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import json +import numpy as np +from pathlib import Path +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +os.environ['OPENBLAS_NUM_THREADS'] = '1' + + +base_path = Path("../../..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions + +n_jobs = 16 +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +FULL_INFERENCE = True + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) +# play around with these +inference_n_jobs = 16 +inference_chunk_duration = "500ms" +inference_memory_gpu = 2000 #MB + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + + if len(json_files) > 0: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + for json_file in json_files: + with open(json_file, "r") as f: + d = json.load(f) + probe = d["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = d["session"] + assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + else: + session_dict = all_sessions + + print(session_dict) + + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + #### START #### + + if (data_folder / "models").is_dir(): + data_model_folder = data_folder / "models" + else: + data_subfolders = [p for p in data_folder.iterdir() if (p / "models").is_dir()] + assert len(data_subfolders) == 1 + data_model_folder = data_subfolders[0] / "models" + + for probe, sessions in session_dict.items(): + if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: + sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] + else: + sessions_to_run = sessions + print(f"Dataset {probe}") + for session in sessions_to_run: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + dst_folder = raw_data_folder / session + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync --no-sign-request {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) + else: + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = data_model_folder / session / filter_option + model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5") and filter_option in p.name][0] + # full inference + output_folder = ( + results_folder / "deepinterpolated" / session / filter_option + ) + if OVERWRITE and output_folder.is_dir(): + shutil.rmtree(output_folder) + + if not output_folder.is_dir(): + t_start_inference = time.perf_counter() + output_folder.parent.mkdir(exist_ok=True, parents=True) + recording_di = spre.deepinterpolate( + recording_zscore, + model_path=model_path, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + memory_gpu=inference_memory_gpu, + use_gpu=USE_GPU + ) + recording_di = recording_di.save( + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, + ) + t_stop_inference = time.perf_counter() + elapsed_time_inference = np.round( + t_stop_inference - t_start_inference, 2 + ) + print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") + else: + print("\t\tLoading existing folder") + recording_di = si.load_extractor(output_folder) + # apply inverse z-scoring + inverse_gains = 1 / recording_zscore.gain + inverse_offset = -recording_zscore.offset * inverse_gains + recording_di = spre.scale( + recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" + ) + + # save processed json + processed_folder = results_folder / "processed" / session / filter_option + processed_folder.mkdir(exist_ok=True, parents=True) + recording_processed.dump_to_json(processed_folder / "processed.json", relative_to=results_folder) + recording_di.dump_to_json(processed_folder / f"deepinterpolated.json", relative_to=results_folder) diff --git a/scripts/run_inference.py b/pipeline/run_spike_sorting.py similarity index 94% rename from scripts/run_inference.py rename to pipeline/run_spike_sorting.py index 1bc182e..37d7ab2 100644 --- a/scripts/run_inference.py +++ b/pipeline/run_spike_sorting.py @@ -141,26 +141,20 @@ ] unit_level_results = None - if (data_folder / "models").is_dir(): - data_model_folder = data_folder - else: - data_subfolders = [p for p in data_folder.iterdir()] - assert len(data_subfolders) == 1 - data_model_folder = data_subfolders[0] + sessions = sessions[:2] for session in sessions: print(f"\nAnalyzing session {session}\n") if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) - dst_folder = raw_data_folder / session # download dataset dst_folder.mkdir(exist_ok=True) src_folder = f"{DATASET_BUCKET}{session}" - cmd = f"aws s3 sync --no-sign-request {src_folder} {dst_folder}" + cmd = f"aws s3 sync {src_folder} {dst_folder}" # aws command to download os.system(cmd) else: @@ -200,8 +194,23 @@ recording_zscore = spre.zscore(recording_processed) # train model - model_folder = data_model_folder / "models" / session / filter_option - model_path = [p for p in model_folder if p.name.endswith("model.h5") and filter_option in p.name][0] + model_folder = results_folder / "models" / session / filter_option + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") # full inference output_folder = ( results_folder / "deepinterpolated" / session / filter_option diff --git a/pipeline/run_training.py b/pipeline/run_training.py new file mode 100644 index 0000000..03dc066 --- /dev/null +++ b/pipeline/run_training.py @@ -0,0 +1,192 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import json +import numpy as np +from pathlib import Path +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +base_path = Path("../../..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +FULL_INFERENCE = True + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + + if len(json_files) > 0: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + for json_file in json_files: + with open(json_file, "r") as f: + d = json.load(f) + probe = d["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = d["session"] + assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + else: + session_dict = all_sessions + + print(session_dict) + + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: + sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] + else: + sessions_to_run = sessions + for session in sessions_to_run: + print(f"\nAnalyzing session {session}\n") + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) + else: + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session + + if "np1" in dst_folder.name: + probe = "NP1" + else: + probe = "NP2" + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = results_folder / "models" / session / filter_option + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") diff --git a/pipeline/sessions.py b/pipeline/sessions.py new file mode 100644 index 0000000..4373e9b --- /dev/null +++ b/pipeline/sessions.py @@ -0,0 +1,26 @@ + + +all_sessions = { + "NP1": + [ + "aind-np1/625749_2022-08-03_15-15-06_ProbeA", + "aind-np1/634568_2022-08-05_15-59-46_ProbeA", + "aind-np1/634569_2022-08-09_16-14-38_ProbeA", + "aind-np1/634571_2022-08-04_14-27-05_ProbeA", + "ibl-np1/CSHZAD026_2020-09-04_probe00", + "ibl-np1/CSHZAD029_2020-09-09_probe00", + "ibl-np1/SWC054_2020-10-05_probe00", + "ibl-np1/SWC054_2020-10-05_probe01", + ], + "NP2": + [ + "aind-np2/595262_2022-02-21_15-18-07_ProbeA", + "aind-np2/602454_2022-03-22_16-30-03_ProbeB", + "aind-np2/612962_2022-04-13_19-18-04_ProbeB", + "aind-np2/612962_2022-04-14_17-17-10_ProbeC", + "aind-np2/618197_2022-06-21_14-08-06_ProbeC", + "aind-np2/618318_2022-04-13_14-59-07_ProbeB", + "aind-np2/618384_2022-04-14_15-11-00_ProbeB", + "aind-np2/621362_2022-07-14_11-19-36_ProbeA", + ] +} \ No newline at end of file diff --git a/scripts/run_training.py b/scripts/run_training.py deleted file mode 100644 index 15f5275..0000000 --- a/scripts/run_training.py +++ /dev/null @@ -1,188 +0,0 @@ -import warnings - -warnings.filterwarnings("ignore") -warnings.filterwarnings("ignore", category=DeprecationWarning) - - -#### IMPORTS ####### -import os -import sys -import numpy as np -from pathlib import Path -from numba import cuda -import pandas as pd -import time - - -# SpikeInterface -import spikeinterface as si -import spikeinterface.extractors as se -import spikeinterface.preprocessing as spre -import spikeinterface.sorters as ss -import spikeinterface.postprocessing as spost -import spikeinterface.comparison as sc -import spikeinterface.qualitymetrics as sqm - -# Tensorflow -import tensorflow as tf - - -base_path = Path("../../..") - -##### DEFINE DATASETS AND FOLDERS ####### - -sessions = [ - "595262_2022-02-21_15-18-07_ProbeA", - "602454_2022-03-22_16-30-03_ProbeB", - "612962_2022-04-13_19-18-04_ProbeB", - "612962_2022-04-14_17-17-10_ProbeC", - "618197_2022-06-21_14-08-06_ProbeC", - "618318_2022-04-13_14-59-07_ProbeB", - "618384_2022-04-14_15-11-00_ProbeB", - "621362_2022-07-14_11-19-36_ProbeA", -] -n_jobs = 16 - -job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") - -data_folder = base_path / "data" -scratch_folder = base_path / "scratch" -results_folder = base_path / "results" - - -# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" -DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" - -DEBUG = False -NUM_DEBUG_SESSIONS = 2 -DEBUG_DURATION = 20 - -##### DEFINE PARAMS ##### -OVERWRITE = False -USE_GPU = True -FULL_INFERENCE = True - -# Define training and testing constants (@Jad you can gradually increase this) - - -FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" - -# DI params -pre_frame = 30 -post_frame = 30 -pre_post_omission = 1 -desired_shape = (192, 2) -# play around with these -inference_n_jobs = 4 -inference_chunk_duration = "100ms" -inference_memory_gpu = 2000 #MB - -di_kwargs = dict( - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - desired_shape=desired_shape, -) - -sorter_name = "pykilosort" -singularity_image = False - -match_score = 0.7 - - -if __name__ == "__main__": - if len(sys.argv) == 3: - if sys.argv[1] == "true": - DEBUG = True - else: - DEBUG = False - if sys.argv[2] != "all": - sessions = [sys.argv[2]] - - if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.2 - TESTING_START_S = 10 - TESTING_END_S = 10.05 - if len(sessions) > NUM_DEBUG_SESSIONS: - sessions = sessions[:NUM_DEBUG_SESSIONS] - OVERWRITE = True - else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.5 - OVERWRITE = False - - si.set_global_job_kwargs(**job_kwargs) - - print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") - - for session in sessions: - print(f"\nAnalyzing session {session}\n") - if str(DATASET_BUCKET).startswith("s3"): - raw_data_folder = scratch_folder / "raw" - raw_data_folder.mkdir(exist_ok=True) - - # download dataset - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) - else: - raw_data_folder = DATASET_BUCKET - dst_folder = raw_data_folder / session - - if "np1" in dst_folder.name: - probe = "NP1" - else: - probe = "NP2" - - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) - if DEBUG: - recording = recording.frame_slice( - start_frame=0, - end_frame=int(DEBUG_DURATION * recording.sampling_frequency), - ) - - results_dict = {} - for filter_option in FILTER_OPTIONS: - print(f"\tFilter option: {filter_option}") - results_dict[filter_option] = {} - # train DI models - print(f"\t\tTraning DI") - training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) - testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) - model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" - - # apply filter and zscore - if filter_option == "hp": - recording_processed = spre.highpass_filter(recording) - elif filter_option == "bp": - recording_processed = spre.bandpass_filter(recording) - else: - recording_processed = recording - recording_zscore = spre.zscore(recording_processed) - - # train model - model_folder = results_folder / "models" / session / filter_option - model_folder.parent.mkdir(parents=True, exist_ok=True) - # Use SI function - t_start_training = time.perf_counter() - model_path = spre.train_deepinterpolation( - recording_zscore, - model_folder=model_folder, - model_name=model_name, - train_start_s=TRAINING_START_S, - train_end_s=TRAINING_END_S, - test_start_s=TESTING_START_S, - test_end_s=TESTING_END_S, - **di_kwargs, - ) - t_stop_training = time.perf_counter() - elapsed_time_training = np.round(t_stop_training - t_start_training, 2) - print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") \ No newline at end of file From ead234374b1caa763a2daffd0a3258d4b1177db9 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Tue, 18 Jul 2023 12:42:26 +0000 Subject: [PATCH 05/84] update run_spike_sorting --- pipeline/run_inference.py | 4 +- pipeline/run_spike_sorting.py | 556 ++++++++++++++-------------------- 2 files changed, 225 insertions(+), 335 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index c56318d..d268696 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -1,7 +1,5 @@ import warnings - - warnings.filterwarnings("ignore") warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -70,6 +68,7 @@ # play around with these inference_n_jobs = 16 inference_chunk_duration = "500ms" +inference_predict_workers = 8 inference_memory_gpu = 2000 #MB di_kwargs = dict( @@ -206,6 +205,7 @@ post_frame=post_frame, pre_post_omission=pre_post_omission, memory_gpu=inference_memory_gpu, + predict_workers=inference_predict_workers, use_gpu=USE_GPU ) recording_di = recording_di.save( diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 37d7ab2..4c8d17c 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -7,6 +7,7 @@ #### IMPORTS ####### import os import sys +import json import numpy as np from pathlib import Path from numba import cuda @@ -23,24 +24,11 @@ import spikeinterface.comparison as sc import spikeinterface.qualitymetrics as sqm -# Tensorflow -import tensorflow as tf - base_path = Path("../../..") ##### DEFINE DATASETS AND FOLDERS ####### - -sessions = [ - "595262_2022-02-21_15-18-07_ProbeA", - "602454_2022-03-22_16-30-03_ProbeB", - "612962_2022-04-13_19-18-04_ProbeB", - "612962_2022-04-14_17-17-10_ProbeC", - "618197_2022-06-21_14-08-06_ProbeC", - "618318_2022-04-13_14-59-07_ProbeB", - "618384_2022-04-14_15-11-00_ProbeB", - "621362_2022-07-14_11-19-36_ProbeA", -] +from sessions import all_sessions n_jobs = 16 job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") @@ -57,362 +45,264 @@ NUM_DEBUG_SESSIONS = 2 DEBUG_DURATION = 20 -##### DEFINE PARAMS ##### -OVERWRITE = False -USE_GPU = True -FULL_INFERENCE = True - -# Define training and testing constants (@Jad you can gradually increase this) - +# Define training and testing constants FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" -# DI params -pre_frame = 30 -post_frame = 30 -pre_post_omission = 1 -desired_shape = (192, 2) -# play around with these -inference_n_jobs = 4 -inference_chunk_duration = "100ms" -inference_memory_gpu = 2000 #MB - -di_kwargs = dict( - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - desired_shape=desired_shape, -) sorter_name = "pykilosort" singularity_image = False - match_score = 0.7 if __name__ == "__main__": - if len(sys.argv) == 3: + if len(sys.argv) == 2: if sys.argv[1] == "true": DEBUG = True else: DEBUG = False - if sys.argv[2] != "all": - sessions = [sys.argv[2]] + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + + if len(json_files) > 0: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + for json_file in json_files: + with open(json_file, "r") as f: + d = json.load(f) + probe = d["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = d["session"] + assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + else: + session_dict = all_sessions + + print(session_dict) if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.2 - TESTING_START_S = 10 - TESTING_END_S = 10.05 if len(sessions) > NUM_DEBUG_SESSIONS: sessions = sessions[:NUM_DEBUG_SESSIONS] OVERWRITE = True else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.5 OVERWRITE = False si.set_global_job_kwargs(**job_kwargs) - print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") - #### START #### - session_level_results = pd.DataFrame( - columns=[ - "session", - "probe", - "filter_option", - "num_units", - "num_units_di", - "sorting_path", - "sorting_path_di", - "num_match", - ] - ) - - unit_level_results_columns = [ - "session", - "probe", - "filter_option", - "unit_id", - "unit_id_di", - ] - unit_level_results = None - - sessions = sessions[:2] - - for session in sessions: - print(f"\nAnalyzing session {session}\n") - if str(DATASET_BUCKET).startswith("s3"): - raw_data_folder = scratch_folder / "raw" - raw_data_folder.mkdir(exist_ok=True) - - # download dataset - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) - else: - raw_data_folder = DATASET_BUCKET - dst_folder = raw_data_folder / session - - if "np1" in dst_folder.name: - probe = "NP1" + if (data_folder / "processed").is_dir(): + processed_folder = data_folder / "processed" + deepinterpolated_folder = data_folder / "deepinterpolated" + base_folder = data_folder + else: + data_subfolders = [p for p in data_folder.iterdir() if (p / "processed").is_dir()] + assert len(data_subfolders) == 1 + processed_folder = data_subfolders[0] / "processed" + deepinterpolated_folder = data_subfolders[0] / "deepinterpolated" + base_folder = data_subfolders[0] + + for probe, sessions in session_dict.items(): + if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: + sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] else: - probe = "NP2" - - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) - if DEBUG: - recording = recording.frame_slice( - start_frame=0, - end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + sessions_to_run = sessions + + print(f"Dataset {probe}") + for session in sessions_to_run: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + session_level_results = pd.DataFrame( + columns=[ + "dataset", + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] ) - results_dict = {} - for filter_option in FILTER_OPTIONS: - print(f"\tFilter option: {filter_option}") - results_dict[filter_option] = {} - # train DI models - print(f"\t\tTraning DI") - training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) - testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) - model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" - - # apply filter and zscore - if filter_option == "hp": - recording_processed = spre.highpass_filter(recording) - elif filter_option == "bp": - recording_processed = spre.bandpass_filter(recording) - else: - recording_processed = recording - recording_zscore = spre.zscore(recording_processed) - - # train model - model_folder = results_folder / "models" / session / filter_option - model_folder.parent.mkdir(parents=True, exist_ok=True) - # Use SI function - t_start_training = time.perf_counter() - model_path = spre.train_deepinterpolation( - recording_zscore, - model_folder=model_folder, - model_name=model_name, - train_start_s=TRAINING_START_S, - train_end_s=TRAINING_END_S, - test_start_s=TESTING_START_S, - test_end_s=TESTING_END_S, - **di_kwargs, - ) - t_stop_training = time.perf_counter() - elapsed_time_training = np.round(t_stop_training - t_start_training, 2) - print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") - # full inference - output_folder = ( - results_folder / "deepinterpolated" / session / filter_option - ) - if OVERWRITE and output_folder.is_dir(): - shutil.rmtree(output_folder) - - if not output_folder.is_dir(): - t_start_inference = time.perf_counter() - output_folder.parent.mkdir(exist_ok=True, parents=True) - recording_di = spre.deepinterpolate( - recording_zscore, - model_path=model_path, - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - memory_gpu=inference_memory_gpu, - use_gpu=USE_GPU - ) - recording_di = recording_di.save( - folder=output_folder, - n_jobs=inference_n_jobs, - chunk_duration=inference_chunk_duration, - ) - t_stop_inference = time.perf_counter() - elapsed_time_inference = np.round( - t_stop_inference - t_start_inference, 2 + unit_level_results_columns = [ + "dataset", + "session", + "probe", + "filter_option", + "unit_id", + "unit_id_di", + "agreement_score" + ] + unit_level_results = None + + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + + # load recordings + # save processed json + processed_json_folder = processed_folder / session / filter_option + recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) + recording_di = si.load_extractor(processed_json_folder / "deepinterpolated.json", base_folder=base_folder) + + # run spike sorting + sorting_output_folder = ( + results_folder / "sortings" / session / filter_option ) - print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") - else: - print("\t\tLoading existing folder") - recording_di = si.load_extractor(output_folder) - # apply inverse z-scoring - inverse_gains = 1 / recording_zscore.gain - inverse_offset = -recording_zscore.offset * inverse_gains - recording_di_inverse_zscore = spre.scale( - recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" - ) + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = sorting.save( + folder=sorting_output_folder / "sorting" + ) - results_dict[filter_option]["recording_no_di"] = recording_processed - results_dict[filter_option]["recording_di"] = recording_di_inverse_zscore + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = sorting_di.save( + folder=sorting_output_folder / "sorting_di" + ) - # run spike sorting - sorting_output_folder = ( - results_folder / "sortings" / session / filter_option - ) - sorting_output_folder.mkdir(parents=True, exist_ok=True) - - recording_no_di = results_dict[filter_option]["recording_no_di"] - if ( - sorting_output_folder / f"no_di_{model_name}" - ).is_dir() and not OVERWRITE: - print("\t\tLoading NO DI sorting") - sorting_no_di = si.load_extractor(sorting_output_folder / "sorting") - else: - print(f"\t\tSpike sorting NO DI with {sorter_name}") - sorting_no_di = ss.run_sorter( - sorter_name, - recording=recording_no_di, - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, ) - sorting_no_di = sorting_no_di.save( - folder=sorting_output_folder / "sorting" + matched_units = comp.get_matching()[0] + matched_unit_ids = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids = matched_unit_ids[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_matched = sorting.select_units( + unit_ids=matched_unit_ids ) - results_dict[filter_option]["sorting_no_di"] = sorting_no_di - - recording_di = results_dict[filter_option]["recording_di"] - if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoading DI sorting") - sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") - else: - print(f"\t\tSpike sorting DI with {sorter_name}") - sorting_di = ss.run_sorter( - sorter_name, - recording=recording_di, - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + + ## add entries to session-level results + new_row = { + "dataset": dataset_name, + "session": session_name, + "filter_option": filter_option, + "probe": probe, + "num_units": len(sorting.unit_ids), + "num_units_di": len(sorting_di.unit_ids), + "num_match": len(sorting_matched.unit_ids), + "sorting_path": str( + (sorting_output_folder / "sorting").relative_to(results_folder) + ), + "sorting_path_di": str( + (sorting_output_folder / "sorting_di_").relative_to(results_folder) + ), + } + session_level_results = pd.concat( + [session_level_results, pd.DataFrame([new_row])], ignore_index=True ) - sorting_di = sorting_di.save( - folder=sorting_output_folder / "sorting_di" + + print( + f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" ) - results_dict[filter_option]["sorting_di"] = sorting_di - - # compare outputs - print("\t\tComparing sortings") - comp = sc.compare_two_sorters( - sorting1=sorting_no_di, - sorting2=sorting_di, - sorting1_name="no_di", - sorting2_name="di", - match_score=match_score, - ) - matched_units = comp.get_matching()[0] - matched_unit_ids_no_di = matched_units.index.values.astype(int) - matched_unit_ids_di = matched_units.values.astype(int) - matched_units_valid = matched_unit_ids_di != -1 - matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] - matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] - sorting_no_di_matched = sorting_no_di.select_units( - unit_ids=matched_unit_ids_no_di - ) - sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) - - ## add entries to session-level results - new_row = { - "session": session, - "filter_option": filter_option, - "probe": probe, - "num_units": len(sorting_no_di.unit_ids), - "num_units_di": len(sorting_di.unit_ids), - "num_match": len(sorting_no_di_matched.unit_ids), - "sorting_path": str( - (sorting_output_folder / "sorting").relative_to(results_folder) - ), - "sorting_path_di": str( - (sorting_output_folder / "sorting_di_{model_name}").relative_to( - results_folder + + # waveforms + waveforms_folder = results_folder / "waveforms" / session / filter_option + waveforms_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms") + we = si.load_waveforms(waveforms_folder / "waveforms") + else: + print("\t\tCompute NO DI waveforms") + we = si.extract_waveforms( + recording, + sorting_matched, + folder=waveforms_folder / "waveforms", + n_jobs=n_jobs, + overwrite=True, ) - ), - } - session_level_results = pd.concat( - [session_level_results, pd.DataFrame([new_row])], ignore_index=True - ) - print( - f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" - ) + if (waveforms_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms") + we_di = si.load_waveforms(waveforms_folder / "waveforms_di") + else: + print("\t\tCompute DI waveforms") + we_di = si.extract_waveforms( + recording_di, + sorting_di_matched, + folder=waveforms_folder / "waveforms_di", + n_jobs=n_jobs, + overwrite=True, + ) - # waveforms - waveforms_folder = results_folder / "waveforms" / session / filter_option - waveforms_folder.mkdir(exist_ok=True, parents=True) - - if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad NO DI waveforms") - we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") - else: - print("\t\tCompute NO DI waveforms") - we_no_di = si.extract_waveforms( - recording_no_di, - sorting_no_di_matched, - folder=waveforms_folder / f"no_di_{model_name}", - n_jobs=n_jobs, - overwrite=True, - ) - results_dict[filter_option]["we_no_di"] = we_no_di - - if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad DI waveforms") - we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") - else: - print("\t\tCompute DI waveforms") - we_di = si.extract_waveforms( - recording_di, - sorting_di_matched, - folder=waveforms_folder / f"di_{model_name}", - n_jobs=n_jobs, - overwrite=True, + # compute metrics + if we.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad NO DI metrics") + qm = we.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute NO DI metrics") + qm = sqm.compute_quality_metrics(we) + + if we_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad DI metrics") + qm_di = we_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute DI metrics") + qm_di = sqm.compute_quality_metrics(we_di) + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm.columns: + unit_level_results_columns.append(metric) + unit_level_results_columns.append(f"{metric}_di") + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "dataset": [dataset_name] * len(qm), + "session": [session_name] * len(qm), + "probe": [probe] * len(qm), + "filter_option": [filter_option] * len(qm), + "unit_id": we.unit_ids, + "unit_id_di": we_di.unit_ids, + } + agreement_scores = [] + for i in range(len(we.unit_ids)): + agreement_scores.append(comp.agreement_scores.at[we.unit_ids[i], we_di.unit_ids[i]]) + new_rows["agreement_score"] = agreement_scores + for metric in qm.columns: + new_rows[metric] = qm[metric].values + new_rows[f"{metric}_di"] = qm_di[metric].values + # append new entries + unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True ) - results_dict[filter_option]["we_di"] = we_di - - # compute metrics - if we_no_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad NO DI metrics") - qm_no_di = we_no_di.load_extension("quality_metrics").get_data() - else: - print("\t\tCompute NO DI metrics") - qm_no_di = sqm.compute_quality_metrics(we_no_di) - results_dict[filter_option]["qm_no_di"] = qm_no_di - - if we_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad DI metrics") - qm_di = we_di.load_extension("quality_metrics").get_data() - else: - print("\t\tCompute DI metrics") - qm_di = sqm.compute_quality_metrics(we_di) - results_dict[filter_option]["qm_di"] = qm_di - - ## add entries to unit-level results - if unit_level_results is None: - for metric in qm_no_di.columns: - unit_level_results_columns.append(metric) - unit_level_results_columns.append(f"{metric}_di") - unit_level_results = pd.DataFrame(columns=unit_level_results_columns) - - new_rows = { - "session": [session] * len(qm_no_di), - "probe": [probe] * len(qm_no_di), - "filter_option": [filter_option] * len(qm_no_di), - "unit_id": we_no_di.unit_ids, - "unit_id_di": we_di.unit_ids, - } - for metric in qm_no_di.columns: - new_rows[metric] = qm_no_di[metric].values - new_rows[f"{metric}_di"] = qm_di[metric].values - # append new entries - unit_level_results = pd.concat( - [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True - ) - results_folder.mkdir(exist_ok=True) - session_level_results.to_csv(results_folder / "session-results.csv") - unit_level_results.to_csv(results_folder / "unit-results.csv") + session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) + unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) From 8594f49d81d0984e4f39d5014e9f1b10732c9132 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Tue, 18 Jul 2023 12:50:09 +0000 Subject: [PATCH 06/84] Add generate_job_config function --- pipeline/run_inference.py | 26 +++++++++--------- pipeline/run_spike_sorting.py | 51 ++++++++++++++--------------------- pipeline/run_training.py | 4 ++- pipeline/sessions.py | 35 +++++++++++++++++++----- 4 files changed, 63 insertions(+), 53 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index d268696..4338f89 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -27,7 +27,7 @@ import tensorflow as tf -os.environ['OPENBLAS_NUM_THREADS'] = '1' +os.environ["OPENBLAS_NUM_THREADS"] = "1" base_path = Path("../../..") @@ -69,7 +69,7 @@ inference_n_jobs = 16 inference_chunk_duration = "500ms" inference_predict_workers = 8 -inference_memory_gpu = 2000 #MB +inference_memory_gpu = 2000 # MB di_kwargs = dict( pre_frame=pre_frame, @@ -98,7 +98,9 @@ if probe not in session_dict: session_dict[probe] = [] session = d["session"] - assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) else: session_dict = all_sessions @@ -187,11 +189,11 @@ # train model model_folder = data_model_folder / session / filter_option - model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5") and filter_option in p.name][0] + model_path = [ + p for p in model_folder.iterdir() if p.name.endswith("model.h5") and filter_option in p.name + ][0] # full inference - output_folder = ( - results_folder / "deepinterpolated" / session / filter_option - ) + output_folder = results_folder / "deepinterpolated" / session / filter_option if OVERWRITE and output_folder.is_dir(): shutil.rmtree(output_folder) @@ -206,7 +208,7 @@ pre_post_omission=pre_post_omission, memory_gpu=inference_memory_gpu, predict_workers=inference_predict_workers, - use_gpu=USE_GPU + use_gpu=USE_GPU, ) recording_di = recording_di.save( folder=output_folder, @@ -214,9 +216,7 @@ chunk_duration=inference_chunk_duration, ) t_stop_inference = time.perf_counter() - elapsed_time_inference = np.round( - t_stop_inference - t_start_inference, 2 - ) + elapsed_time_inference = np.round(t_stop_inference - t_start_inference, 2) print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") else: print("\t\tLoading existing folder") @@ -224,9 +224,7 @@ # apply inverse z-scoring inverse_gains = 1 / recording_zscore.gain inverse_offset = -recording_zscore.offset * inverse_gains - recording_di = spre.scale( - recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" - ) + recording_di = spre.scale(recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float") # save processed json processed_folder = results_folder / "processed" / session / filter_option diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 4c8d17c..ef3b261 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -29,6 +29,7 @@ ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions + n_jobs = 16 job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") @@ -75,7 +76,9 @@ if probe not in session_dict: session_dict[probe] = [] session = d["session"] - assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) else: session_dict = all_sessions @@ -135,26 +138,26 @@ "filter_option", "unit_id", "unit_id_di", - "agreement_score" + "agreement_score", ] unit_level_results = None for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") - + # load recordings # save processed json processed_json_folder = processed_folder / session / filter_option recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) - recording_di = si.load_extractor(processed_json_folder / "deepinterpolated.json", base_folder=base_folder) + recording_di = si.load_extractor( + processed_json_folder / "deepinterpolated.json", base_folder=base_folder + ) # run spike sorting - sorting_output_folder = ( - results_folder / "sortings" / session / filter_option - ) + sorting_output_folder = results_folder / "sortings" / session / filter_option sorting_output_folder.mkdir(parents=True, exist_ok=True) - if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: + if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: print("\t\tLoading NO DI sorting") sorting = si.load_extractor(sorting_output_folder / "sorting") else: @@ -167,11 +170,9 @@ verbose=True, singularity_image=singularity_image, ) - sorting = sorting.save( - folder=sorting_output_folder / "sorting" - ) + sorting = sorting.save(folder=sorting_output_folder / "sorting") - if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: print("\t\tLoading DI sorting") sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") else: @@ -184,9 +185,7 @@ verbose=True, singularity_image=singularity_image, ) - sorting_di = sorting_di.save( - folder=sorting_output_folder / "sorting_di" - ) + sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") # compare outputs print("\t\tComparing sortings") @@ -203,9 +202,7 @@ matched_units_valid = matched_unit_ids_di != -1 matched_unit_ids = matched_unit_ids[matched_units_valid] matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] - sorting_matched = sorting.select_units( - unit_ids=matched_unit_ids - ) + sorting_matched = sorting.select_units(unit_ids=matched_unit_ids) sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) ## add entries to session-level results @@ -217,16 +214,10 @@ "num_units": len(sorting.unit_ids), "num_units_di": len(sorting_di.unit_ids), "num_match": len(sorting_matched.unit_ids), - "sorting_path": str( - (sorting_output_folder / "sorting").relative_to(results_folder) - ), - "sorting_path_di": str( - (sorting_output_folder / "sorting_di_").relative_to(results_folder) - ), + "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)), + "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)), } - session_level_results = pd.concat( - [session_level_results, pd.DataFrame([new_row])], ignore_index=True - ) + session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) print( f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" @@ -293,16 +284,14 @@ "unit_id_di": we_di.unit_ids, } agreement_scores = [] - for i in range(len(we.unit_ids)): + for i in range(len(we.unit_ids)): agreement_scores.append(comp.agreement_scores.at[we.unit_ids[i], we_di.unit_ids[i]]) new_rows["agreement_score"] = agreement_scores for metric in qm.columns: new_rows[metric] = qm[metric].values new_rows[f"{metric}_di"] = qm_di[metric].values # append new entries - unit_level_results = pd.concat( - [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True - ) + unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 03dc066..00a3e4c 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -92,7 +92,9 @@ if probe not in session_dict: session_dict[probe] = [] session = d["session"] - assert session in all_sessions[probe], f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) else: session_dict = all_sessions diff --git a/pipeline/sessions.py b/pipeline/sessions.py index 4373e9b..6797a51 100644 --- a/pipeline/sessions.py +++ b/pipeline/sessions.py @@ -1,8 +1,8 @@ - +from pathlib import Path +import json all_sessions = { - "NP1": - [ + "NP1": [ "aind-np1/625749_2022-08-03_15-15-06_ProbeA", "aind-np1/634568_2022-08-05_15-59-46_ProbeA", "aind-np1/634569_2022-08-09_16-14-38_ProbeA", @@ -12,8 +12,7 @@ "ibl-np1/SWC054_2020-10-05_probe00", "ibl-np1/SWC054_2020-10-05_probe01", ], - "NP2": - [ + "NP2": [ "aind-np2/595262_2022-02-21_15-18-07_ProbeA", "aind-np2/602454_2022-03-22_16-30-03_ProbeB", "aind-np2/612962_2022-04-13_19-18-04_ProbeB", @@ -22,5 +21,27 @@ "aind-np2/618318_2022-04-13_14-59-07_ProbeB", "aind-np2/618384_2022-04-14_15-11-00_ProbeB", "aind-np2/621362_2022-07-14_11-19-36_ProbeA", - ] -} \ No newline at end of file + ], +} + + +def generate_job_config_list(output_folder, split_probes=True): + output_folder = Path(output_folder) + output_folder.mkdir(exist_ok=True, parents=True) + + i = 0 + for probe, sessions in all_sessions.items(): + if split_probes: + i = 0 + probe_folder = output_folder / probe + probe_folder.mkdir(exist_ok=True) + else: + probe_folder = output_folder + + for session in sessions: + d = dict(session=session, probe=probe) + + with open(probe_folder / f"job{i}.json", "w") as f: + json.dump(d, f) + + i += 1 From 694eaa0f1d4383508c0fbe0f7a33beed9b6fabbf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 15:33:35 +0200 Subject: [PATCH 07/84] Propagate JSON files to output --- pipeline/run_inference.py | 4 ++++ pipeline/run_training.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 4338f89..358c55d 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -7,6 +7,7 @@ #### IMPORTS ####### import os import sys +import shutil import json import numpy as np from pathlib import Path @@ -231,3 +232,6 @@ processed_folder.mkdir(exist_ok=True, parents=True) recording_processed.dump_to_json(processed_folder / "processed.json", relative_to=results_folder) recording_di.dump_to_json(processed_folder / f"deepinterpolated.json", relative_to=results_folder) + + for json_file in json_files: + shutil.copy(json_file, results_folder) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 00a3e4c..ec1d9fb 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -7,6 +7,7 @@ #### IMPORTS ####### import os import sys +import shutil import json import numpy as np from pathlib import Path @@ -192,3 +193,6 @@ t_stop_training = time.perf_counter() elapsed_time_training = np.round(t_stop_training - t_start_training, 2) print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") + + for json_file in json_files: + shutil.copy(json_file, results_folder) From 92e91f09ba020980dfd1e11c75e6371c51af8d7a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 15:53:46 +0200 Subject: [PATCH 08/84] Reduce validation interval to 100ms --- pipeline/run_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index ec1d9fb..f6c776c 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -112,7 +112,7 @@ TRAINING_START_S = 0 TRAINING_END_S = 20 TESTING_START_S = 70 - TESTING_END_S = 70.5 + TESTING_END_S = 70.1 OVERWRITE = False si.set_global_job_kwargs(**job_kwargs) From ece9cf3f33ae16d4bd33810fdfd58bbbc61f706b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 16:17:38 +0200 Subject: [PATCH 09/84] Added collect_results script --- pipeline/run_collect_results.py | 67 +++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 pipeline/run_collect_results.py diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py new file mode 100644 index 0000000..b4ff433 --- /dev/null +++ b/pipeline/run_collect_results.py @@ -0,0 +1,67 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import json +import numpy as np +from pathlib import Path +from numba import cuda +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + + +base_path = Path("../../..") + + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +if __name__ == "__main__": + + # concatenate dataframes + df_session = None + df_units = None + + if (data_folder / "sortings").is_dir(): + data_base_folder = data_folder + else: + data_subfolders = [p for p in data_folder.iterdir() if (p / "sortings").is_dir()] + data_base_folder = data_subfolders[0] + + session_csvs = [p for p in data_base_folder.iterdir() if "session" in p.name and p.suffix == ".csv"] + unit_csvs = [p for p in data_base_folder.iterdir() if "unit" in p.name and p.suffix == ".csv"] + + for session_csv in session_csvs: + if df_session is None: + df_session = pd.read_csv(session_csv) + else: + df_session = pd.concat([df_session, pd.read_csv(session_csv)]) + + for unit_csv in unit_csvs: + if df_units is None: + df_units = pd.read_csv(unit_csv) + else: + df_units = pd.concat([df_units, pd.read_csv(unit_csv)]) + + # copy sortings to results folder + + # save concatenated dataframes + df_session.to_csv(results_folder / "sessions.csv", index=False) + df_units.to_csv(results_folder / "units.csv", index=False) From 44162b34bcae12bf058f6f51c6f69cecf3d3ffc4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 16:38:52 +0200 Subject: [PATCH 10/84] Change spike sorting output flders --- pipeline/run_collect_results.py | 24 ++++++++++++++++-------- pipeline/run_spike_sorting.py | 14 +++++--------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index b4ff433..94e8efb 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -5,14 +5,11 @@ #### IMPORTS ####### -import os -import sys -import json +import shutil import numpy as np from pathlib import Path from numba import cuda import pandas as pd -import time # SpikeInterface @@ -39,10 +36,12 @@ df_session = None df_units = None - if (data_folder / "sortings").is_dir(): + probe_sortings_folders = [p for p in data_folder.iterdir() if "sortings_" in p.name and p.is_dir()] + + if len(probe_sortings_folders) > 0: data_base_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if (p / "sortings").is_dir()] + data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] data_base_folder = data_subfolders[0] session_csvs = [p for p in data_base_folder.iterdir() if "session" in p.name and p.suffix == ".csv"] @@ -60,8 +59,17 @@ else: df_units = pd.concat([df_units, pd.read_csv(unit_csv)]) - # copy sortings to results folder - # save concatenated dataframes df_session.to_csv(results_folder / "sessions.csv", index=False) df_units.to_csv(results_folder / "units.csv", index=False) + + # copy sortings to results folder + sortings_folders = [p for p in data_base_folder.iterdir() if "sortings_" in p.name and p.is_dir()] + sortings_output_base_folder = results_folder / "sortings" + sortings_folders.mkdir(exist_ok=True) + + for sorting_folder in sortings_folders: + _, dataset_name, session_name, filter_option = sorting_folder.name.split("_") + sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option + sorting_output_folder.mkdir(exist_ok=True, parents=True) + shutil.copytree(sorting_folder, sorting_output_folder) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index ef3b261..9370d30 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -45,6 +45,8 @@ DEBUG = False NUM_DEBUG_SESSIONS = 2 DEBUG_DURATION = 20 +OVERWRITE = False + # Define training and testing constants @@ -60,6 +62,7 @@ if len(sys.argv) == 2: if sys.argv[1] == "true": DEBUG = True + OVERWRITE = True else: DEBUG = False @@ -85,13 +88,6 @@ print(session_dict) - if DEBUG: - if len(sessions) > NUM_DEBUG_SESSIONS: - sessions = sessions[:NUM_DEBUG_SESSIONS] - OVERWRITE = True - else: - OVERWRITE = False - si.set_global_job_kwargs(**job_kwargs) #### START #### @@ -154,7 +150,7 @@ ) # run spike sorting - sorting_output_folder = results_folder / "sortings" / session / filter_option + sorting_output_folder = results_folder / f"sortings_{dataset_name}_{session_name}_{filter_option}" sorting_output_folder.mkdir(parents=True, exist_ok=True) if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: @@ -224,7 +220,7 @@ ) # waveforms - waveforms_folder = results_folder / "waveforms" / session / filter_option + waveforms_folder = results_folder / f"waveforms_{dataset_name}_{session_name}_{filter_option}" waveforms_folder.mkdir(exist_ok=True, parents=True) if (waveforms_folder / "waveforms").is_dir() and not OVERWRITE: From 555d9a9e48f7811cd7970f3f8226885f21354025 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 16:57:33 +0200 Subject: [PATCH 11/84] Remove unused imports --- pipeline/run_collect_results.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 94e8efb..ce32136 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -6,20 +6,10 @@ #### IMPORTS ####### import shutil -import numpy as np -from pathlib import Path -from numba import cuda -import pandas as pd -# SpikeInterface -import spikeinterface as si -import spikeinterface.extractors as se -import spikeinterface.preprocessing as spre -import spikeinterface.sorters as ss -import spikeinterface.postprocessing as spost -import spikeinterface.comparison as sc -import spikeinterface.qualitymetrics as sqm +from pathlib import Path +import pandas as pd base_path = Path("../../..") From 67d207e40128c6a0db44bed22577fa1d4d293455 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 17:03:23 +0200 Subject: [PATCH 12/84] Oups --- pipeline/run_collect_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index ce32136..735b570 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -56,7 +56,7 @@ # copy sortings to results folder sortings_folders = [p for p in data_base_folder.iterdir() if "sortings_" in p.name and p.is_dir()] sortings_output_base_folder = results_folder / "sortings" - sortings_folders.mkdir(exist_ok=True) + sortings_output_base_folder.mkdir(exist_ok=True) for sorting_folder in sortings_folders: _, dataset_name, session_name, filter_option = sorting_folder.name.split("_") From 3b150430118cb9260e2cb65afde378fa8db8ef86 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 17:07:05 +0200 Subject: [PATCH 13/84] fix sorting paths --- pipeline/run_collect_results.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 735b570..1b75813 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -59,7 +59,10 @@ sortings_output_base_folder.mkdir(exist_ok=True) for sorting_folder in sortings_folders: - _, dataset_name, session_name, filter_option = sorting_folder.name.split("_") + sorting_folder_split = sorting_folder.name.split("_") + dataset_name = sorting_folder_split[1] + session_name = "_".join(sorting_folder_split[2:-1]) + filter_option = sorting_folder_split[-1] sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option sorting_output_folder.mkdir(exist_ok=True, parents=True) shutil.copytree(sorting_folder, sorting_output_folder) From 4cd35c11bb77e794cc67e96cfce317558be56ebb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 17:08:40 +0200 Subject: [PATCH 14/84] fix sorting paths 1 --- pipeline/run_collect_results.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 1b75813..e019e2d 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -65,4 +65,5 @@ filter_option = sorting_folder_split[-1] sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option sorting_output_folder.mkdir(exist_ok=True, parents=True) - shutil.copytree(sorting_folder, sorting_output_folder) + for sorting_subfolder in sorting_folder.iterdir(): + shutil.copytree(sorting_subfolder, sorting_output_folder) From 0a8390aa3e5c8797e7138eca36242eba831a3b06 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 17:14:58 +0200 Subject: [PATCH 15/84] fix sorting paths 2 --- pipeline/run_collect_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index e019e2d..cac869b 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -66,4 +66,4 @@ sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option sorting_output_folder.mkdir(exist_ok=True, parents=True) for sorting_subfolder in sorting_folder.iterdir(): - shutil.copytree(sorting_subfolder, sorting_output_folder) + shutil.copytree(sorting_subfolder, sorting_output_folder / sorting_subfolder.name) From 4eb9e4522a4bdcf8b294cbdc5eca786753d6d703 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 18:40:06 +0200 Subject: [PATCH 16/84] Debug --- pipeline/run_spike_sorting.py | 6 +----- pipeline/run_training.py | 20 ++++++++------------ 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 9370d30..9baee5b 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -103,13 +103,9 @@ base_folder = data_subfolders[0] for probe, sessions in session_dict.items(): - if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: - sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] - else: - sessions_to_run = sessions print(f"Dataset {probe}") - for session in sessions_to_run: + for session in sessions: print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") diff --git a/pipeline/run_training.py b/pipeline/run_training.py index f6c776c..81e2d70 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -81,9 +81,8 @@ DEBUG = False json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] - + print(f"Found {len(json_files)} JSON config") if len(json_files) > 0: - print(f"Found {len(json_files)} JSON config") session_dict = {} # each json file contains a session to run for json_file in json_files: @@ -121,11 +120,8 @@ for probe, sessions in session_dict.items(): print(f"Dataset {probe}") - if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: - sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] - else: - sessions_to_run = sessions - for session in sessions_to_run: + + for session in sessions: print(f"\nAnalyzing session {session}\n") if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" @@ -143,11 +139,6 @@ raw_data_folder = DATASET_BUCKET dst_folder = raw_data_folder / session - if "np1" in dst_folder.name: - probe = "NP1" - else: - probe = "NP2" - recording_folder = dst_folder recording = si.load_extractor(recording_folder) if DEBUG: @@ -155,6 +146,7 @@ start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) + print(recording) results_dict = {} for filter_option in FILTER_OPTIONS: @@ -196,3 +188,7 @@ for json_file in json_files: shutil.copy(json_file, results_folder) + + print("Results folder content:") + for p in results_folder.iterdir(): + print(p.name) From 28ed0e7348d9933a2e454066bc7abdda33c1e98b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 18:41:16 +0200 Subject: [PATCH 17/84] Debug1 --- pipeline/run_inference.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 358c55d..7f2e689 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -135,12 +135,8 @@ data_model_folder = data_subfolders[0] / "models" for probe, sessions in session_dict.items(): - if DEBUG and len(sessions) > NUM_DEBUG_SESSIONS: - sessions_to_run = sessions[:NUM_DEBUG_SESSIONS] - else: - sessions_to_run = sessions print(f"Dataset {probe}") - for session in sessions_to_run: + for session in sessions: print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") From 4445737cb023b107dc8b4165cdc17fb0981b4966 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 18:48:39 +0200 Subject: [PATCH 18/84] Remove results dict --- pipeline/run_inference.py | 2 -- pipeline/run_training.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 7f2e689..a7e809d 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -165,10 +165,8 @@ end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) - results_dict = {} for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") - results_dict[filter_option] = {} # train DI models print(f"\t\tTraning DI") training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 81e2d70..6bf5a4e 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -148,10 +148,8 @@ ) print(recording) - results_dict = {} for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") - results_dict[filter_option] = {} # train DI models print(f"\t\tTraning DI") training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) From 26636f4573abc0f1ce57e1c3e1a2397c5f1ba939 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 10:39:14 +0200 Subject: [PATCH 19/84] Try to resolve base_path --- pipeline/run_collect_results.py | 2 +- pipeline/run_inference.py | 2 +- pipeline/run_spike_sorting.py | 2 +- pipeline/run_training.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index cac869b..7843829 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -12,7 +12,7 @@ import pandas as pd -base_path = Path("../../..") +base_path = Path("../../..").resolve() data_folder = base_path / "data" diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index a7e809d..3e9e67c 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -31,7 +31,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1" -base_path = Path("../../..") +base_path = Path("../../..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 9baee5b..a8fca63 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -25,7 +25,7 @@ import spikeinterface.qualitymetrics as sqm -base_path = Path("../../..") +base_path = Path("../../..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 6bf5a4e..4dcb2ae 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -28,7 +28,7 @@ import tensorflow as tf -base_path = Path("../../..") +base_path = Path("../../..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions @@ -185,6 +185,7 @@ print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") for json_file in json_files: + print(f"Copying JSON file: {json_file.name} to {results_folder}") shutil.copy(json_file, results_folder) print("Results folder content:") From 60b7f2115b6e7da524da65fe510138f225ca1e9c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 11:28:42 +0200 Subject: [PATCH 20/84] Add session to print --- pipeline/run_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 4dcb2ae..56ce772 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -182,7 +182,7 @@ ) t_stop_training = time.perf_counter() elapsed_time_training = np.round(t_stop_training - t_start_training, 2) - print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") + print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") for json_file in json_files: print(f"Copying JSON file: {json_file.name} to {results_folder}") From cdde356e7cec15bcc9b349d609133f3978f867e3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 13:11:38 +0200 Subject: [PATCH 21/84] Change output folders and add super training --- pipeline/run_inference.py | 20 ++-- pipeline/run_spike_sorting.py | 19 ++-- pipeline/run_super_training.py | 181 +++++++++++++++++++++++++++++++++ pipeline/run_training.py | 4 +- 4 files changed, 201 insertions(+), 23 deletions(-) create mode 100644 pipeline/run_super_training.py diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 3e9e67c..5310363 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -126,13 +126,13 @@ print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") #### START #### + probe_models_folders = [p for p in data_folder.iterdir() if "models_" in p.name and p.is_dir()] - if (data_folder / "models").is_dir(): - data_model_folder = data_folder / "models" + if len(probe_models_folders) > 0: + data_model_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if (p / "models").is_dir()] - assert len(data_subfolders) == 1 - data_model_folder = data_subfolders[0] / "models" + data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] + data_model_folder = data_subfolders[0] for probe, sessions in session_dict.items(): print(f"Dataset {probe}") @@ -183,12 +183,10 @@ recording_zscore = spre.zscore(recording_processed) # train model - model_folder = data_model_folder / session / filter_option - model_path = [ - p for p in model_folder.iterdir() if p.name.endswith("model.h5") and filter_option in p.name - ][0] + model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" + model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0] # full inference - output_folder = results_folder / "deepinterpolated" / session / filter_option + output_folder = results_folder / f"deepinterpolatedf_{dataset_name}_{session_name}_{filter_option}" if OVERWRITE and output_folder.is_dir(): shutil.rmtree(output_folder) @@ -222,7 +220,7 @@ recording_di = spre.scale(recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float") # save processed json - processed_folder = results_folder / "processed" / session / filter_option + processed_folder = results_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" processed_folder.mkdir(exist_ok=True, parents=True) recording_processed.dump_to_json(processed_folder / "processed.json", relative_to=results_folder) recording_di.dump_to_json(processed_folder / f"deepinterpolated.json", relative_to=results_folder) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index a8fca63..7a5f45f 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -91,16 +91,13 @@ si.set_global_job_kwargs(**job_kwargs) #### START #### - if (data_folder / "processed").is_dir(): - processed_folder = data_folder / "processed" - deepinterpolated_folder = data_folder / "deepinterpolated" - base_folder = data_folder + probe_processed_folders = [p for p in data_folder.iterdir() if "processed_" in p.name and p.is_dir()] + + if len(probe_processed_folders) > 0: + processed_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if (p / "processed").is_dir()] - assert len(data_subfolders) == 1 - processed_folder = data_subfolders[0] / "processed" - deepinterpolated_folder = data_subfolders[0] / "deepinterpolated" - base_folder = data_subfolders[0] + data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] + processed_folder = data_subfolders[0] for probe, sessions in session_dict.items(): @@ -139,10 +136,10 @@ # load recordings # save processed json - processed_json_folder = processed_folder / session / filter_option + processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) recording_di = si.load_extractor( - processed_json_folder / "deepinterpolated.json", base_folder=base_folder + processed_json_folder / "deepinterpolated.json", base_folder=data_folder ) # run spike sorting diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py new file mode 100644 index 0000000..45f2474 --- /dev/null +++ b/pipeline/run_super_training.py @@ -0,0 +1,181 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import shutil +import json +import numpy as np +from pathlib import Path +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +base_path = Path("../../..").resolve() + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +FULL_INFERENCE = True + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + + session_dict = all_sessions + + print(session_dict) + + if DEBUG: + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + TESTING_START_S = 70 + TESTING_END_S = 70.1 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + + model_path = None + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + if DEBUG: + sessions_to_use = sessions[:NUM_DEBUG_SESSIONS] + else: + sessions_to_use = sessions + print(f"Running super training with {sessions_to_use} sessions") + for session in sessions_to_use: + print(f"\nAnalyzing session {session}\n") + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) + else: + raw_data_folder = DATASET_BUCKET + dst_folder = raw_data_folder / session + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + print(recording) + + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = results_folder / f"model_{filter_option}" + model_folder.parent.mkdir(parents=True, exist_ok=True) + + if model_path is None: + print(f"\t\t\tFirst training, no model to load") + else: + print(f"\t\t\Refining training with new session") + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + existing_model_path=model_path, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 56ce772..7eb029e 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -123,6 +123,8 @@ for session in sessions: print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) @@ -166,7 +168,7 @@ recording_zscore = spre.zscore(recording_processed) # train model - model_folder = results_folder / "models" / session / filter_option + model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" model_folder.parent.mkdir(parents=True, exist_ok=True) # Use SI function t_start_training = time.perf_counter() From 576ee38c2d17ba98c34d237c448601d29f18bcab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 13:26:30 +0200 Subject: [PATCH 22/84] Improve prints --- pipeline/run_super_training.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 45f2474..2ac1bf9 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -46,7 +46,7 @@ DATASET_BUCKET = data_folder / "ephys-compression-benchmark" DEBUG = False -NUM_DEBUG_SESSIONS = 2 +NUM_DEBUG_SESSIONS = 4 DEBUG_DURATION = 20 ##### DEFINE PARAMS ##### @@ -104,17 +104,17 @@ model_path = None for filter_option in FILTER_OPTIONS: - print(f"\tFilter option: {filter_option}") + print(f"Filter option: {filter_option}") for probe, sessions in session_dict.items(): - print(f"Dataset {probe}") + print(f"\tDataset {probe}") if DEBUG: sessions_to_use = sessions[:NUM_DEBUG_SESSIONS] else: sessions_to_use = sessions - print(f"Running super training with {sessions_to_use} sessions") + print(f"\tRunning super training with {sessions_to_use} sessions") for session in sessions_to_use: - print(f"\nAnalyzing session {session}\n") + print(f"\t\tSession {session}\n") if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) @@ -141,7 +141,6 @@ print(recording) # train DI models - print(f"\t\tTraning DI") training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" @@ -162,7 +161,7 @@ if model_path is None: print(f"\t\t\tFirst training, no model to load") else: - print(f"\t\t\Refining training with new session") + print(f"\t\t\tRefining training with new session") # Use SI function t_start_training = time.perf_counter() model_path = spre.train_deepinterpolation( From e1e62e1af73dcab3fe6884c19106771c68399c2a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 14:01:24 +0200 Subject: [PATCH 23/84] finalize super training --- pipeline/run_spike_sorting.py | 1 - pipeline/run_super_training.py | 52 ++++++++++++++++++++++++++-------- pipeline/run_training.py | 2 +- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 7a5f45f..9da2db8 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -48,7 +48,6 @@ OVERWRITE = False - # Define training and testing constants FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 2ac1bf9..a5801b0 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -14,6 +14,7 @@ import pandas as pd import time +import matplotlib.pyplot as plt # SpikeInterface import spikeinterface as si @@ -92,7 +93,7 @@ OVERWRITE = True else: TRAINING_START_S = 0 - TRAINING_END_S = 20 + TRAINING_END_S = 10 TESTING_START_S = 70 TESTING_END_S = 70.1 OVERWRITE = False @@ -101,11 +102,10 @@ print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") - - model_path = None + pretrained_model_path = None for filter_option in FILTER_OPTIONS: print(f"Filter option: {filter_option}") - + for probe, sessions in session_dict.items(): print(f"\tDataset {probe}") if DEBUG: @@ -113,8 +113,8 @@ else: sessions_to_use = sessions print(f"\tRunning super training with {sessions_to_use} sessions") - for session in sessions_to_use: - print(f"\t\tSession {session}\n") + for i, session in enumerate(sessions_to_use): + print(f"\t\tSession {session} - Iteration {i}\n") if str(DATASET_BUCKET).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) @@ -155,26 +155,54 @@ recording_zscore = spre.zscore(recording_processed) # train model - model_folder = results_folder / f"model_{filter_option}" + model_folder = results_folder / f"model_{filter_option}-iter{i}" model_folder.parent.mkdir(parents=True, exist_ok=True) - if model_path is None: - print(f"\t\t\tFirst training, no model to load") - else: - print(f"\t\t\tRefining training with new session") # Use SI function t_start_training = time.perf_counter() model_path = spre.train_deepinterpolation( recording_zscore, model_folder=model_folder, model_name=model_name, - existing_model_path=model_path, + existing_model_path=pretrained_model_path, train_start_s=TRAINING_START_S, train_end_s=TRAINING_END_S, test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, **di_kwargs, ) + pretrained_model_path = model_path t_stop_training = time.perf_counter() elapsed_time_training = np.round(t_stop_training - t_start_training, 2) print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") + + # aggregate results + print(f"Aggregating results for {filter_option}") + final_model_folder = results_folder / f"model_{filter_option}" + shutil.copytree(model_folder, final_model_folder) + final_model_name = [p.name for p in final_model_folder.iterdir() if "_model" in p.name][0] + final_model_stem = final_model_name.split("_model")[0] + + # concatenate loss and val loss + loss_accuracies = np.array([]) + val_accuracies = np.array([]) + + for i in range(len(sessions_to_use)): + model_folder = results_folder / f"model_{filter_option}-iter{i}" + loss_file = [p for p in model_folder.iterdir() if "_loss.npy" in p.name and "val" not in p.name][0] + val_loss_file = [p for p in model_folder.iterdir() if "val_loss.npy" in p.name][0] + loss = np.load(loss_file) + val_loss = np.load(val_loss_file) + loss_accuracies = np.concatenate((loss_accuracies, loss)) + val_accuracies = np.concatenate((val_accuracies, val_loss)) + np.save(final_model_folder / f"{final_model_stem}_loss.npy", loss_accuracies) + np.save(final_model_folder / f"{final_model_stem}_val_loss.npy", val_accuracies) + + # plot losses + fig, ax = plt.subplots() + ax.plot(loss_accuracies, color="C0", label="loss") + ax.plot(val_accuracies, color="C0", label="loss") + ax.set_xlabel("number of epochs") + ax.set_ylabel("training loss") + ax.legend() + fig.savefig(final_model_folder / f"{final_model_stem}_losses.png", dpi=300) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 7eb029e..0ffc40a 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -192,4 +192,4 @@ print("Results folder content:") for p in results_folder.iterdir(): - print(p.name) + print(p.name) From 20016565472291eb486bae8fd71e7c4018e3debb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 15:03:41 +0200 Subject: [PATCH 24/84] Fixes --- pipeline/run_super_training.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index a5801b0..45aa2ed 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -112,7 +112,7 @@ sessions_to_use = sessions[:NUM_DEBUG_SESSIONS] else: sessions_to_use = sessions - print(f"\tRunning super training with {sessions_to_use} sessions") + print(f"\tRunning super training with {len(sessions_to_use)} sessions") for i, session in enumerate(sessions_to_use): print(f"\t\tSession {session} - Iteration {i}\n") if str(DATASET_BUCKET).startswith("s3"): @@ -138,7 +138,6 @@ start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) - print(recording) # train DI models training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) @@ -155,7 +154,7 @@ recording_zscore = spre.zscore(recording_processed) # train model - model_folder = results_folder / f"model_{filter_option}-iter{i}" + model_folder = results_folder / f"models_{filter_option}" / f"iter{i}" model_folder.parent.mkdir(parents=True, exist_ok=True) # Use SI function @@ -188,7 +187,7 @@ val_accuracies = np.array([]) for i in range(len(sessions_to_use)): - model_folder = results_folder / f"model_{filter_option}-iter{i}" + model_folder = results_folder / f"models_{filter_option}" / f"iter{i}" loss_file = [p for p in model_folder.iterdir() if "_loss.npy" in p.name and "val" not in p.name][0] val_loss_file = [p for p in model_folder.iterdir() if "val_loss.npy" in p.name][0] loss = np.load(loss_file) @@ -201,7 +200,7 @@ # plot losses fig, ax = plt.subplots() ax.plot(loss_accuracies, color="C0", label="loss") - ax.plot(val_accuracies, color="C0", label="loss") + ax.plot(val_accuracies, color="C1", label="loss") ax.set_xlabel("number of epochs") ax.set_ylabel("training loss") ax.legend() From ded56d121504a12e7ec910c0145a852cf083fb82 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 09:18:57 +0200 Subject: [PATCH 25/84] Change relative paths --- pipeline/run_collect_results.py | 3 +-- pipeline/run_inference.py | 2 +- pipeline/run_spike_sorting.py | 2 +- pipeline/run_super_training.py | 2 +- pipeline/run_training.py | 4 ++-- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 7843829..9b47e66 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -12,8 +12,7 @@ import pandas as pd -base_path = Path("../../..").resolve() - +base_path = Path("..").resolve() data_folder = base_path / "data" scratch_folder = base_path / "scratch" diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 5310363..2b0143a 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -31,7 +31,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1" -base_path = Path("../../..").resolve() +base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 9da2db8..61685fe 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -25,7 +25,7 @@ import spikeinterface.qualitymetrics as sqm -base_path = Path("../../..").resolve() +base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 45aa2ed..7e36047 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -29,7 +29,7 @@ import tensorflow as tf -base_path = Path("../../..").resolve() +base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 0ffc40a..e570ece 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -27,8 +27,8 @@ # Tensorflow import tensorflow as tf - -base_path = Path("../../..").resolve() +# runs from "codes" +base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions From 4abf9847dc862c3cf810e67accd0b9a43bbd678b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 10:26:13 +0200 Subject: [PATCH 26/84] Fix model_folder names --- pipeline/run_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 2b0143a..c1a7869 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -126,7 +126,7 @@ print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") #### START #### - probe_models_folders = [p for p in data_folder.iterdir() if "models_" in p.name and p.is_dir()] + probe_models_folders = [p for p in data_folder.iterdir() if "model_" in p.name and p.is_dir()] if len(probe_models_folders) > 0: data_model_folder = data_folder From ff149d57b02478f9f3a9b0dec4bb4850efc21dc9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 10:43:03 +0200 Subject: [PATCH 27/84] Fix inference sub folders --- pipeline/run_inference.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index c1a7869..add6b52 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -131,8 +131,11 @@ if len(probe_models_folders) > 0: data_model_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] - data_model_folder = data_subfolders[0] + data_model_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0: + data_model_subfolders.append(p) + data_model_folder = data_model_subfolders[0] for probe, sessions in session_dict.items(): print(f"Dataset {probe}") From 4c28a8f6119e9466e2df65eceaaf090f97597960 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 10:49:52 +0200 Subject: [PATCH 28/84] Propagate models to results --- pipeline/run_collect_results.py | 16 +++++++++++++++- pipeline/run_inference.py | 12 +++--------- pipeline/run_spike_sorting.py | 2 +- pipeline/run_super_training.py | 8 ++------ pipeline/run_training.py | 8 ++------ 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 9b47e66..e124673 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -53,7 +53,7 @@ df_units.to_csv(results_folder / "units.csv", index=False) # copy sortings to results folder - sortings_folders = [p for p in data_base_folder.iterdir() if "sortings_" in p.name and p.is_dir()] + sortings_folders = [p for p in data_base_folder.iterdir() if "sorting_" in p.name and p.is_dir()] sortings_output_base_folder = results_folder / "sortings" sortings_output_base_folder.mkdir(exist_ok=True) @@ -66,3 +66,17 @@ sorting_output_folder.mkdir(exist_ok=True, parents=True) for sorting_subfolder in sorting_folder.iterdir(): shutil.copytree(sorting_subfolder, sorting_output_folder / sorting_subfolder.name) + + # copy models to results folder + models_folders = [p for p in data_base_folder.iterdir() if "model_" in p.name and p.is_dir()] + models_output_base_folder = results_folder / "models" + models_output_base_folder.mkdir(exist_ok=True) + + for model_folder in models_folders: + model_folder_split = model_folder.name.split("_") + dataset_name = model_folder_split[1] + session_name = "_".join(model_folder_split[2:-1]) + filter_option = model_folder_split[-1] + model_output_folder = models_output_base_folder / dataset_name / session_name / filter_option + model_output_folder.parent.mkdir(exist_ok=True, parents=True) + shutil.copytree(model_folder, model_output_folder) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index add6b52..7d319d3 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -73,10 +73,7 @@ inference_memory_gpu = 2000 # MB di_kwargs = dict( - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - desired_shape=desired_shape, + pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, ) if __name__ == "__main__": @@ -164,8 +161,7 @@ recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( - start_frame=0, - end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) for filter_option in FILTER_OPTIONS: @@ -207,9 +203,7 @@ use_gpu=USE_GPU, ) recording_di = recording_di.save( - folder=output_folder, - n_jobs=inference_n_jobs, - chunk_duration=inference_chunk_duration, + folder=output_folder, n_jobs=inference_n_jobs, chunk_duration=inference_chunk_duration, ) t_stop_inference = time.perf_counter() elapsed_time_inference = np.round(t_stop_inference - t_start_inference, 2) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 61685fe..a6637d8 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -142,7 +142,7 @@ ) # run spike sorting - sorting_output_folder = results_folder / f"sortings_{dataset_name}_{session_name}_{filter_option}" + sorting_output_folder = results_folder / f"sorting_{dataset_name}_{session_name}_{filter_option}" sorting_output_folder.mkdir(parents=True, exist_ok=True) if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 7e36047..d6c3755 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -67,10 +67,7 @@ desired_shape = (192, 2) di_kwargs = dict( - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - desired_shape=desired_shape, + pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, ) @@ -135,8 +132,7 @@ recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( - start_frame=0, - end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) # train DI models diff --git a/pipeline/run_training.py b/pipeline/run_training.py index e570ece..3e5508c 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -66,10 +66,7 @@ desired_shape = (192, 2) di_kwargs = dict( - pre_frame=pre_frame, - post_frame=post_frame, - pre_post_omission=pre_post_omission, - desired_shape=desired_shape, + pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, ) @@ -145,8 +142,7 @@ recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( - start_frame=0, - end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) print(recording) From 64f42717006573e3531df4f0b2e8c7704da358f2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 11:25:48 +0200 Subject: [PATCH 29/84] Set training verbose to false --- pipeline/run_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 3e5508c..5bf60d5 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -176,6 +176,7 @@ train_end_s=TRAINING_END_S, test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, + verbose=False, **di_kwargs, ) t_stop_training = time.perf_counter() From e3e2232628581398d9bb697a9cc4dcd88f733de8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 11:31:43 +0200 Subject: [PATCH 30/84] Scale number of GPUs --- pipeline/run_training.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 5bf60d5..a9f94f9 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -113,7 +113,9 @@ si.set_global_job_kwargs(**job_kwargs) - print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + available_gpus = tf.config.list_physical_devices("GPU") + print(f"Tensorflow GPU status: {available_gpus}") + nb_gpus = len(available_gpus) for probe, sessions in session_dict.items(): print(f"Dataset {probe}") @@ -177,6 +179,7 @@ test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, verbose=False, + nb_gpus=nb_gpus, **di_kwargs, ) t_stop_training = time.perf_counter() From 4ee7afa4f3ab41c4c3701196bcb6c21e19bb59dc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 12:39:41 +0200 Subject: [PATCH 31/84] Adjust paths --- pipeline/run_inference.py | 2 +- pipeline/run_spike_sorting.py | 2 +- pipeline/run_super_training.py | 2 +- pipeline/run_training.py | 3 +++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 7d319d3..fbbe543 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -185,7 +185,7 @@ model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0] # full inference - output_folder = results_folder / f"deepinterpolatedf_{dataset_name}_{session_name}_{filter_option}" + output_folder = results_folder / f"deepinterpolated_{dataset_name}_{session_name}_{filter_option}" if OVERWRITE and output_folder.is_dir(): shutil.rmtree(output_folder) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index a6637d8..4629d79 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -138,7 +138,7 @@ processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) recording_di = si.load_extractor( - processed_json_folder / "deepinterpolated.json", base_folder=data_folder + processed_json_folder / "deepinterpolated.json", base_folder=processed_folder ) # run spike sorting diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index d6c3755..9f3e10b 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -196,7 +196,7 @@ # plot losses fig, ax = plt.subplots() ax.plot(loss_accuracies, color="C0", label="loss") - ax.plot(val_accuracies, color="C1", label="loss") + ax.plot(val_accuracies, color="C1", label="val_loss") ax.set_xlabel("number of epochs") ax.set_ylabel("training loss") ax.legend() diff --git a/pipeline/run_training.py b/pipeline/run_training.py index a9f94f9..7de0435 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -116,6 +116,9 @@ available_gpus = tf.config.list_physical_devices("GPU") print(f"Tensorflow GPU status: {available_gpus}") nb_gpus = len(available_gpus) + if len(nb_gpus) > 1: + print("Use 1 GPU only") + nb_gpus = 1 for probe, sessions in session_dict.items(): print(f"Dataset {probe}") From 203c3f78ec4b7ae94a8905a93959517b4298a0ea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 13:05:53 +0200 Subject: [PATCH 32/84] Fix collect capsule --- pipeline/run_collect_results.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index e124673..c6bb0c2 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -20,12 +20,16 @@ if __name__ == "__main__": + # list all data entries + print("Data folder content:") + for p in data_folder.iterdir(): + print(f"\t{p.name}") # concatenate dataframes df_session = None df_units = None - probe_sortings_folders = [p for p in data_folder.iterdir() if "sortings_" in p.name and p.is_dir()] + probe_sortings_folders = [p for p in data_folder.iterdir() if "sorting_" in p.name and p.is_dir()] if len(probe_sortings_folders) > 0: data_base_folder = data_folder @@ -33,8 +37,8 @@ data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] data_base_folder = data_subfolders[0] - session_csvs = [p for p in data_base_folder.iterdir() if "session" in p.name and p.suffix == ".csv"] - unit_csvs = [p for p in data_base_folder.iterdir() if "unit" in p.name and p.suffix == ".csv"] + session_csvs = [p for p in data_base_folder.iterdir() if p.name.endswith("sessions.csv")] + unit_csvs = [p for p in data_base_folder.iterdir() if p.name.endswith("units.csv")] for session_csv in session_csvs: if df_session is None: From 649a237e9a70248fe6c8a3bbe5f1038c55b69e9f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 13:15:14 +0200 Subject: [PATCH 33/84] Fix collect capsule 1 --- pipeline/run_collect_results.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index c6bb0c2..e9b2b6d 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -32,13 +32,23 @@ probe_sortings_folders = [p for p in data_folder.iterdir() if "sorting_" in p.name and p.is_dir()] if len(probe_sortings_folders) > 0: - data_base_folder = data_folder + data_models_folder = data_folder + data_sortings_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] - data_base_folder = data_subfolders[0] + data_model_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0: + data_model_subfolders.append(p) + data_models_folder = data_model_subfolders[0] - session_csvs = [p for p in data_base_folder.iterdir() if p.name.endswith("sessions.csv")] - unit_csvs = [p for p in data_base_folder.iterdir() if p.name.endswith("units.csv")] + data_sorting_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "sorting_" in pp.name and pp.is_dir()]) > 0: + data_sorting_subfolders.append(p) + data_sortings_folder = data_sorting_subfolders[0] + + session_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("sessions.csv")] + unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("units.csv")] for session_csv in session_csvs: if df_session is None: @@ -57,7 +67,7 @@ df_units.to_csv(results_folder / "units.csv", index=False) # copy sortings to results folder - sortings_folders = [p for p in data_base_folder.iterdir() if "sorting_" in p.name and p.is_dir()] + sortings_folders = [p for p in data_sortings_folder.iterdir() if "sorting_" in p.name and p.is_dir()] sortings_output_base_folder = results_folder / "sortings" sortings_output_base_folder.mkdir(exist_ok=True) @@ -72,7 +82,7 @@ shutil.copytree(sorting_subfolder, sorting_output_folder / sorting_subfolder.name) # copy models to results folder - models_folders = [p for p in data_base_folder.iterdir() if "model_" in p.name and p.is_dir()] + models_folders = [p for p in data_models_folder.iterdir() if "model_" in p.name and p.is_dir()] models_output_base_folder = results_folder / "models" models_output_base_folder.mkdir(exist_ok=True) From f71c1dc81a40393c34f7b6c0ec2c671e97e3ead6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 13:18:55 +0200 Subject: [PATCH 34/84] Fix collect capsule 2 --- pipeline/run_collect_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index e9b2b6d..4195e5f 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -29,7 +29,7 @@ df_session = None df_units = None - probe_sortings_folders = [p for p in data_folder.iterdir() if "sorting_" in p.name and p.is_dir()] + probe_sortings_folders = [p for p in data_folder.iterdir() if p.name.startswith("sorting_") and p.is_dir()] if len(probe_sortings_folders) > 0: data_models_folder = data_folder From 2ca3796c89cbce84b4d65ff999de8294af465960 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 13:37:06 +0200 Subject: [PATCH 35/84] Fix nb_gpus --- pipeline/run_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 7de0435..b901702 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -116,8 +116,8 @@ available_gpus = tf.config.list_physical_devices("GPU") print(f"Tensorflow GPU status: {available_gpus}") nb_gpus = len(available_gpus) - if len(nb_gpus) > 1: - print("Use 1 GPU only") + if nb_gpus > 1: + print("Use 1 GPU only!") nb_gpus = 1 for probe, sessions in session_dict.items(): From c88957ada8fd3227d0526c1024e434dd9a874619 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 19:37:40 +0200 Subject: [PATCH 36/84] Specify steps per epoch --- pipeline/run_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index b901702..c313c2d 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -52,7 +52,7 @@ ##### DEFINE PARAMS ##### OVERWRITE = False USE_GPU = True -FULL_INFERENCE = True +STEPS_PER_EPOCH = 100 # Define training and testing constants (@Jad you can gradually increase this) @@ -183,6 +183,7 @@ test_end_s=TESTING_END_S, verbose=False, nb_gpus=nb_gpus, + steps_per_epoch=STEPS_PER_EPOCH, **di_kwargs, ) t_stop_training = time.perf_counter() From 606ab188401b036abac7a28d2831e390a94418df Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 10:10:14 +0200 Subject: [PATCH 37/84] Add remove_excess_spikes curation --- pipeline/run_spike_sorting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 4629d79..f4c5a49 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -17,10 +17,8 @@ # SpikeInterface import spikeinterface as si -import spikeinterface.extractors as se -import spikeinterface.preprocessing as spre import spikeinterface.sorters as ss -import spikeinterface.postprocessing as spost +import spikeinterface.curation as scur import spikeinterface.comparison as sc import spikeinterface.qualitymetrics as sqm @@ -158,6 +156,7 @@ verbose=True, singularity_image=singularity_image, ) + sorting = scur.remove_excess_spikes(sorting, recording) sorting = sorting.save(folder=sorting_output_folder / "sorting") if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: @@ -173,6 +172,7 @@ verbose=True, singularity_image=singularity_image, ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") # compare outputs From c78168be9ebe509efc67afef6795712a940752ae Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 17:41:42 +0200 Subject: [PATCH 38/84] Extend to simulated data --- pipeline/run_inference.py | 54 ++++----- pipeline/run_spike_sorting.py | 14 --- pipeline/run_spike_sorting_GT.py | 202 +++++++++++++++++++++++++++++++ pipeline/run_super_training.py | 18 +-- pipeline/run_training.py | 64 +++++----- pipeline/sessions.py | 26 +++- 6 files changed, 292 insertions(+), 86 deletions(-) create mode 100644 pipeline/run_spike_sorting_GT.py diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index fbbe543..9982087 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -34,9 +34,10 @@ base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### -from sessions import all_sessions +from sessions import all_sessions_exp, all_sessions_sim n_jobs = 16 + job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") data_folder = base_path / "data" @@ -44,8 +45,19 @@ results_folder = base_path / "results" -# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" -DATASET_BUCKET = data_folder / "ephys-compression-benchmark" +if (data_folder / "ephys-compression-benchmark").is_dir(): + DATASET_FOLDER = data_folder / "ephys-compression-benchmark" + all_sessions = all_sessions_exp + data_type = "exp" +elif (data_folder / "MEArec-NP-recordings").is_dir(): + DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + all_sessions = all_sessions_sim + data_type = "sim" +else: + raise Exception("Could not find dataset folder") + + +DATASET_FOLDER = data_folder / "ephys-compression-benchmark" DEBUG = False NUM_DEBUG_SESSIONS = 2 @@ -73,7 +85,10 @@ inference_memory_gpu = 2000 # MB di_kwargs = dict( - pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, ) if __name__ == "__main__": @@ -140,29 +155,12 @@ print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") - if str(DATASET_BUCKET).startswith("s3"): - raw_data_folder = scratch_folder / "raw" - raw_data_folder.mkdir(exist_ok=True) - dst_folder = raw_data_folder / session - - # download dataset - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync --no-sign-request {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) + if data_type == "exp": + recording = si.load_extractor(DATASET_FOLDER / session) else: - raw_data_folder = DATASET_BUCKET - dst_folder = raw_data_folder / session - - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) - if DEBUG: - recording = recording.frame_slice( - start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), - ) + recording, _ = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + recording = spre.depth_order(recording) for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") @@ -203,7 +201,9 @@ use_gpu=USE_GPU, ) recording_di = recording_di.save( - folder=output_folder, n_jobs=inference_n_jobs, chunk_duration=inference_chunk_duration, + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, ) t_stop_inference = time.perf_counter() elapsed_time_inference = np.round(t_stop_inference - t_start_inference, 2) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index f4c5a49..23c7c1e 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -5,15 +5,10 @@ #### IMPORTS ####### -import os import sys import json -import numpy as np from pathlib import Path -from numba import cuda import pandas as pd -import time - # SpikeInterface import spikeinterface as si @@ -37,15 +32,6 @@ results_folder = base_path / "results" -# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" -DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" - -DEBUG = False -NUM_DEBUG_SESSIONS = 2 -DEBUG_DURATION = 20 -OVERWRITE = False - - # Define training and testing constants FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py new file mode 100644 index 0000000..90ee284 --- /dev/null +++ b/pipeline/run_spike_sorting_GT.py @@ -0,0 +1,202 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import sys +import json +from pathlib import Path +import pandas as pd + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.curation as scur +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + + +base_path = Path("..").resolve() + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions_sim as all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + +DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + +OVERWRITE = False + +# Define training and testing constants +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + + +sorter_name = "pykilosort" +singularity_image = False +match_score = 0.7 + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + OVERWRITE = True + else: + OVERWRITE = False + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + + if len(json_files) > 0: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + for json_file in json_files: + with open(json_file, "r") as f: + d = json.load(f) + probe = d["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = d["session"] + session_dict[probe].append(session) + else: + session_dict = all_sessions + + print(session_dict) + + si.set_global_job_kwargs(**job_kwargs) + + #### START #### + probe_processed_folders = [p for p in data_folder.iterdir() if "processed_" in p.name and p.is_dir()] + + if len(probe_processed_folders) > 0: + processed_folder = data_folder + else: + data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] + processed_folder = data_subfolders[0] + + for probe, sessions in session_dict.items(): + + print(f"Dataset {probe}") + for session in sessions: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + _, sorting_gt = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + + session_level_results = None + unit_level_results = None + + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + + # load recordings + # save processed json + processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" + recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) + recording_di = si.load_extractor( + processed_json_folder / "deepinterpolated.json", base_folder=processed_folder + ) + + # run spike sorting + sorting_output_folder = results_folder / f"sorting_{dataset_name}_{session_name}_{filter_option}" + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = scur.remove_excess_spikes(sorting, recording) + sorting = sorting.save(folder=sorting_output_folder / "sorting") + + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) + sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") + + # compare to GT + print("\tRunning comparison") + cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) + cmp_di = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_di, exhaustive_gt=True) + + perf_avg = cmp.get_performance(method="pooled_with_average", output="dict") + perf_avg_di = cmp_di.get_performance(method="pooled_with_average", output="dict") + counts = cmp.count_units_categories() + counts_di = cmp.count_units_categories() + + new_data = { + "probe": probe, + "session": session_name, + "num_units": len(sorting.unit_ids), + "num_units_di": len(sorting_di.unit_ids), + "filter_option": filter_option, + "deepinteprolated": False, + } + new_data_di = new_data.copy() + new_data_di["deepinteprolated"] = True + + new_data.update(perf_avg) + new_data.update(counts.to_dict()) + + new_data_di.update(perf_avg_di) + new_data_di.update(counts_di.to_dict()) + + new_df = pd.DataFrame(new_data) + new_df = pd.concat([new_df, pd.DataFrame(new_data_di)], ignore_index=True) + + if session_level_results is None: + session_level_results = new_df + else: + session_level_results = pd.concat([session_level_results, new_df], ignore_index=True) + + # by unit + perf_by_unit = cmp.get_performance(method="by_unit") + perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) + perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) + perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) + perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + + perf_by_unit_di = cmp_di.get_performance(method="by_unit") + perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + + new_unit_df = pd.concat([perf_by_unit, perf_by_unit_di], ignore_index=True) + + if unit_level_results is None: + unit_level_results = new_unit_df + else: + unit_level_results = pd.concat([unit_level_results, new_unit_df], ignore_index=True) + + session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) + unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 9f3e10b..647018b 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -43,8 +43,8 @@ results_folder = base_path / "results" -# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" -DATASET_BUCKET = data_folder / "ephys-compression-benchmark" +# DATASET_FOLDER = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_FOLDER = data_folder / "ephys-compression-benchmark" DEBUG = False NUM_DEBUG_SESSIONS = 4 @@ -67,7 +67,10 @@ desired_shape = (192, 2) di_kwargs = dict( - pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, ) @@ -112,27 +115,28 @@ print(f"\tRunning super training with {len(sessions_to_use)} sessions") for i, session in enumerate(sessions_to_use): print(f"\t\tSession {session} - Iteration {i}\n") - if str(DATASET_BUCKET).startswith("s3"): + if str(DATASET_FOLDER).startswith("s3"): raw_data_folder = scratch_folder / "raw" raw_data_folder.mkdir(exist_ok=True) # download dataset dst_folder.mkdir(exist_ok=True) - src_folder = f"{DATASET_BUCKET}{session}" + src_folder = f"{DATASET_FOLDER}{session}" cmd = f"aws s3 sync {src_folder} {dst_folder}" # aws command to download os.system(cmd) else: - raw_data_folder = DATASET_BUCKET + raw_data_folder = DATASET_FOLDER dst_folder = raw_data_folder / session recording_folder = dst_folder recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( - start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) # train DI models diff --git a/pipeline/run_training.py b/pipeline/run_training.py index c313c2d..f295f1a 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -5,13 +5,11 @@ #### IMPORTS ####### -import os import sys import shutil import json import numpy as np from pathlib import Path -import pandas as pd import time @@ -31,19 +29,27 @@ base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### -from sessions import all_sessions - -n_jobs = 16 - -job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") +from sessions import all_sessions_exp, all_sessions_sim data_folder = base_path / "data" scratch_folder = base_path / "scratch" results_folder = base_path / "results" -# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" -DATASET_BUCKET = data_folder / "ephys-compression-benchmark" +if (data_folder / "ephys-compression-benchmark").is_dir(): + DATASET_FOLDER = data_folder / "ephys-compression-benchmark" + all_sessions = all_sessions_exp + data_type = "exp" +elif (data_folder / "MEArec-NP-recordings").is_dir(): + DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + all_sessions = all_sessions_sim + data_type = "sim" +else: + raise Exception("Could not find dataset folder") + + +n_jobs = 16 +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") DEBUG = False NUM_DEBUG_SESSIONS = 2 @@ -54,9 +60,7 @@ USE_GPU = True STEPS_PER_EPOCH = 100 -# Define training and testing constants (@Jad you can gradually increase this) - - +# Define training and testing constants FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" # DI params @@ -66,7 +70,10 @@ desired_shape = (192, 2) di_kwargs = dict( - pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, ) @@ -127,34 +134,23 @@ print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") - if str(DATASET_BUCKET).startswith("s3"): - raw_data_folder = scratch_folder / "raw" - raw_data_folder.mkdir(exist_ok=True) - - # download dataset - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) + if data_type == "exp": + recording = si.load_extractor(DATASET_FOLDER / session) else: - raw_data_folder = DATASET_BUCKET - dst_folder = raw_data_folder / session + recording, _ = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + recording = spre.depth_order(recording) - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( - start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) - print(recording) + print(f"\t{recording}") for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") # train DI models - print(f"\t\tTraning DI") training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" @@ -193,7 +189,3 @@ for json_file in json_files: print(f"Copying JSON file: {json_file.name} to {results_folder}") shutil.copy(json_file, results_folder) - - print("Results folder content:") - for p in results_folder.iterdir(): - print(p.name) diff --git a/pipeline/sessions.py b/pipeline/sessions.py index 6797a51..6873445 100644 --- a/pipeline/sessions.py +++ b/pipeline/sessions.py @@ -1,7 +1,7 @@ from pathlib import Path import json -all_sessions = { +all_sessions_exp = { "NP1": [ "aind-np1/625749_2022-08-03_15-15-06_ProbeA", "aind-np1/634568_2022-08-05_15-59-46_ProbeA", @@ -24,11 +24,33 @@ ], } +all_sessions_sim = { + "NP1": [ + "NP1/recording-0.h5", + "NP1/recording-1.h5", + "NP1/recording-2.h5", + "NP1/recording-3.h5", + "NP1/recording-4.h5", + ], + "NP2": [ + "NP2/recording-0.h5", + "NP2/recording-1.h5", + "NP2/recording-2.h5", + "NP2/recording-3.h5", + "NP2/recording-4.h5", + ], +} + -def generate_job_config_list(output_folder, split_probes=True): +def generate_job_config_list(output_folder, split_probes=True, dataset="exp"): output_folder = Path(output_folder) output_folder.mkdir(exist_ok=True, parents=True) + if dataset == "exp": + all_sessions = all_sessions_exp + else: + all_sessions = all_sessions_sim + i = 0 for probe, sessions in all_sessions.items(): if split_probes: From a1e8f17dcf174c4c187fb0802cacb27b84763acc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 18:10:30 +0200 Subject: [PATCH 39/84] Steps per epoch 10 in debug mode --- pipeline/run_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index f295f1a..d42f713 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -81,6 +81,7 @@ if len(sys.argv) == 2: if sys.argv[1] == "true": DEBUG = True + STEPS_PER_EPOCH = 10 else: DEBUG = False From f3827c4779e9b2ef51e1368ebc439519815f110a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 18:22:32 +0200 Subject: [PATCH 40/84] Move depth order later --- pipeline/run_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index d42f713..4e6d0f9 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -165,6 +165,9 @@ recording_processed = recording recording_zscore = spre.zscore(recording_processed) + if data_type == "sim": + recording = spre.depth_order(recording) + # train model model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" model_folder.parent.mkdir(parents=True, exist_ok=True) From 54731b0019f7f3a2ed1241f0150274fa3292f99b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 18:23:22 +0200 Subject: [PATCH 41/84] Move depth order later 2 --- pipeline/run_inference.py | 4 +++- pipeline/run_training.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 9982087..9817c2a 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -160,7 +160,6 @@ else: recording, _ = se.read_mearec(DATASET_FOLDER / session) session_name = session_name.split(".")[0] - recording = spre.depth_order(recording) for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") @@ -179,6 +178,9 @@ recording_processed = recording recording_zscore = spre.zscore(recording_processed) + if data_type == "sim": + recording_zscore = spre.depth_order(recording_zscore) + # train model model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0] diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 4e6d0f9..6443265 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -166,7 +166,7 @@ recording_zscore = spre.zscore(recording_processed) if data_type == "sim": - recording = spre.depth_order(recording) + recording_zscore = spre.depth_order(recording_zscore) # train model model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" From a5336163f6e8fb4edd7905dc770db38491c33ba6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 18:30:12 +0200 Subject: [PATCH 42/84] Oups! --- pipeline/run_inference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 9817c2a..6d3fcb0 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -57,8 +57,6 @@ raise Exception("Could not find dataset folder") -DATASET_FOLDER = data_folder / "ephys-compression-benchmark" - DEBUG = False NUM_DEBUG_SESSIONS = 2 DEBUG_DURATION = 20 From e065b000e4c56e699b0885317e423bc518fb66bd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 18:55:01 +0200 Subject: [PATCH 43/84] Reintroduce debug in inference --- pipeline/run_inference.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 6d3fcb0..81e4719 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -93,8 +93,10 @@ if len(sys.argv) == 2: if sys.argv[1] == "true": DEBUG = True + OVERWRITE = True else: DEBUG = False + OVERWRITE = False json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] @@ -118,19 +120,6 @@ print(session_dict) - if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.2 - TESTING_START_S = 10 - TESTING_END_S = 10.05 - OVERWRITE = True - else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.5 - OVERWRITE = False - si.set_global_job_kwargs(**job_kwargs) print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") @@ -159,13 +148,15 @@ recording, _ = se.read_mearec(DATASET_FOLDER / session) session_name = session_name.split(".")[0] + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + print(f"\t{recording}") + for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") - # train DI models - print(f"\t\tTraning DI") - training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) - testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) - model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" # apply filter and zscore if filter_option == "hp": From 3eadac09524e9146de977e5c239cddbbf2a50b93 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jul 2023 19:04:02 +0200 Subject: [PATCH 44/84] Fix sorting import --- pipeline/run_inference.py | 1 + pipeline/run_spike_sorting.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 81e4719..d983e02 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -2,6 +2,7 @@ warnings.filterwarnings("ignore") warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) #### IMPORTS ####### diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 23c7c1e..3e0758e 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -21,7 +21,7 @@ base_path = Path("..").resolve() ##### DEFINE DATASETS AND FOLDERS ####### -from sessions import all_sessions +from sessions import all_sessions_exp as all_sessions n_jobs = 16 From 01b2b0bc88f64d1f975e276f807142f094255395 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 18:37:01 +0200 Subject: [PATCH 45/84] Optimize inference --- pipeline/run_inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index d983e02..5964bcc 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -78,9 +78,9 @@ pre_post_omission = 1 desired_shape = (192, 2) # play around with these -inference_n_jobs = 16 -inference_chunk_duration = "500ms" -inference_predict_workers = 8 +inference_n_jobs = -1 +inference_chunk_duration = "1s" +inference_predict_workers = 1 inference_memory_gpu = 2000 # MB di_kwargs = dict( From 1e496c1d727eaff1b47d0aeb01dd244f1f893c9c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 18:47:18 +0200 Subject: [PATCH 46/84] Save sim to binary --- pipeline/run_inference.py | 3 ++- pipeline/run_training.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 5964bcc..560e504 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -37,7 +37,7 @@ ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim -n_jobs = 16 +n_jobs = -1 job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") @@ -170,6 +170,7 @@ if data_type == "sim": recording_zscore = spre.depth_order(recording_zscore) + recording_zscore = recording_zscore.save(folder=scratch_folder / "recording_zscored") # train model model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 6443265..0f900eb 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -167,6 +167,7 @@ if data_type == "sim": recording_zscore = spre.depth_order(recording_zscore) + recording_zscore = recording_zscore.save(folder=scratch_folder / "recording_zscored") # train model model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" From 1de9909bcfeaab96ecfcd6b71f24a8d78e61ddd5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 18:53:04 +0200 Subject: [PATCH 47/84] Fix zscore binary --- pipeline/run_inference.py | 5 +++-- pipeline/run_training.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 560e504..de868e2 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -170,7 +170,8 @@ if data_type == "sim": recording_zscore = spre.depth_order(recording_zscore) - recording_zscore = recording_zscore.save(folder=scratch_folder / "recording_zscored") + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") # train model model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" @@ -184,7 +185,7 @@ t_start_inference = time.perf_counter() output_folder.parent.mkdir(exist_ok=True, parents=True) recording_di = spre.deepinterpolate( - recording_zscore, + recording_zscore_bin, model_path=model_path, pre_frame=pre_frame, post_frame=post_frame, diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 0f900eb..54554dd 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -167,7 +167,8 @@ if data_type == "sim": recording_zscore = spre.depth_order(recording_zscore) - recording_zscore = recording_zscore.save(folder=scratch_folder / "recording_zscored") + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") # train model model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" @@ -175,7 +176,7 @@ # Use SI function t_start_training = time.perf_counter() model_path = spre.train_deepinterpolation( - recording_zscore, + recording_zscore_bin, model_folder=model_folder, model_name=model_name, train_start_s=TRAINING_START_S, From 67f43d939e0a7ea00498acb704a3f0209c125025 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 19:05:29 +0200 Subject: [PATCH 48/84] Fix zscore binary 1 --- pipeline/run_inference.py | 5 +++-- pipeline/run_training.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index de868e2..b710655 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -166,10 +166,11 @@ recording_processed = spre.bandpass_filter(recording) else: recording_processed = recording - recording_zscore = spre.zscore(recording_processed) if data_type == "sim": - recording_zscore = spre.depth_order(recording_zscore) + recording_processed = spre.depth_order(recording_processed) + + recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 54554dd..a1dd905 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -163,10 +163,11 @@ recording_processed = spre.bandpass_filter(recording) else: recording_processed = recording - recording_zscore = spre.zscore(recording_processed) if data_type == "sim": - recording_zscore = spre.depth_order(recording_zscore) + recording_processed = spre.depth_order(recording_processed) + + recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") From 5d4481760ee60a8312d2e817667406aae0435c60 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 19:14:06 +0200 Subject: [PATCH 49/84] Fix zscore binary 2 --- pipeline/run_inference.py | 9 +++++---- pipeline/run_training.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index b710655..acdc826 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -158,6 +158,7 @@ for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" # apply filter and zscore if filter_option == "hp": @@ -172,13 +173,13 @@ recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot - recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") + recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") # train model - model_folder = data_model_folder / f"model_{dataset_name}_{session_name}_{filter_option}" + model_folder = data_model_folder / f"model_{recording_name}" model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0] # full inference - output_folder = results_folder / f"deepinterpolated_{dataset_name}_{session_name}_{filter_option}" + output_folder = results_folder / f"deepinterpolated_{recording_name}" if OVERWRITE and output_folder.is_dir(): shutil.rmtree(output_folder) @@ -212,7 +213,7 @@ recording_di = spre.scale(recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float") # save processed json - processed_folder = results_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" + processed_folder = results_folder / f"processed_{recording_name}" processed_folder.mkdir(exist_ok=True, parents=True) recording_processed.dump_to_json(processed_folder / "processed.json", relative_to=results_folder) recording_di.dump_to_json(processed_folder / f"deepinterpolated.json", relative_to=results_folder) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index a1dd905..1eedbb9 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -151,6 +151,7 @@ for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" # train DI models training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) @@ -169,10 +170,10 @@ recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot - recording_zscore_bin = recording_zscore.save(folder=scratch_folder / "recording_zscored") + recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") # train model - model_folder = results_folder / f"model_{dataset_name}_{session_name}_{filter_option}" + model_folder = results_folder / f"model_{recording_name}" model_folder.parent.mkdir(parents=True, exist_ok=True) # Use SI function t_start_training = time.perf_counter() From 6b4edb7b3607cdcdb3fc6f760f148497ed3406dd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jul 2023 21:08:23 +0200 Subject: [PATCH 50/84] Fix sorting eval sim --- pipeline/run_spike_sorting_GT.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 90ee284..bf63812 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -148,8 +148,8 @@ cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) cmp_di = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_di, exhaustive_gt=True) - perf_avg = cmp.get_performance(method="pooled_with_average", output="dict") - perf_avg_di = cmp_di.get_performance(method="pooled_with_average", output="dict") + perf_avg = cmp.get_performance(method="pooled_with_average") + perf_avg_di = cmp_di.get_performance(method="pooled_with_average") counts = cmp.count_units_categories() counts_di = cmp.count_units_categories() @@ -164,10 +164,10 @@ new_data_di = new_data.copy() new_data_di["deepinteprolated"] = True - new_data.update(perf_avg) + new_data.update(perf_avg.to_dict()) new_data.update(counts.to_dict()) - new_data_di.update(perf_avg_di) + new_data_di.update(perf_avg_di.to_dict()) new_data_di.update(counts_di.to_dict()) new_df = pd.DataFrame(new_data) From 3a41f681238e5d9f97f36bada01b86cd2996d04f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 10:02:10 +0200 Subject: [PATCH 51/84] Fix df concatenation --- pipeline/run_spike_sorting_GT.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index bf63812..cdae314 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -170,13 +170,14 @@ new_data_di.update(perf_avg_di.to_dict()) new_data_di.update(counts_di.to_dict()) - new_df = pd.DataFrame(new_data) - new_df = pd.concat([new_df, pd.DataFrame(new_data_di)], ignore_index=True) + new_df = pd.DataFrame([new_data]) + new_df_di = pd.DataFrame([new_data_di]) + new_df_session = pd.concat([new_df, new_df_di], ignore_index=True) if session_level_results is None: - session_level_results = new_df + session_level_results = new_df_session else: - session_level_results = pd.concat([session_level_results, new_df], ignore_index=True) + session_level_results = pd.concat([session_level_results, new_df_session], ignore_index=True) # by unit perf_by_unit = cmp.get_performance(method="by_unit") From 60a08f1b2a420810b91aedfd6588be1a0e3fa565 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 10:50:42 +0200 Subject: [PATCH 52/84] Fix debug mode for sorting GT and add inference parallel params --- pipeline/run_inference.py | 6 ++++-- pipeline/run_spike_sorting_GT.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index acdc826..075fb36 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -78,7 +78,7 @@ pre_post_omission = 1 desired_shape = (192, 2) # play around with these -inference_n_jobs = -1 +inference_n_jobs = os.cpu_count() - 4 inference_chunk_duration = "1s" inference_predict_workers = 1 inference_memory_gpu = 2000 # MB @@ -91,13 +91,15 @@ ) if __name__ == "__main__": - if len(sys.argv) == 2: + if len(sys.argv) == 4: if sys.argv[1] == "true": DEBUG = True OVERWRITE = True else: DEBUG = False OVERWRITE = False + inference_n_jobs = int(sys.argv[2]) + inference_predict_workers = int(sys.argv[3]) json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index cdae314..07c6ec4 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -90,7 +90,7 @@ print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") - _, sorting_gt = se.read_mearec(DATASET_FOLDER / session) + recording_gt, sorting_gt = se.read_mearec(DATASET_FOLDER / session) session_name = session_name.split(".")[0] session_level_results = None @@ -107,6 +107,10 @@ processed_json_folder / "deepinterpolated.json", base_folder=processed_folder ) + # DEBUG mode + if recording.get_num_samples() < recording_gt.get_num_samples(): + sorting_gt = sorting_gt.frame_slice(start_frame=0, end_frame=recording.get_num_samples()) + # run spike sorting sorting_output_folder = results_folder / f"sorting_{dataset_name}_{session_name}_{filter_option}" sorting_output_folder.mkdir(parents=True, exist_ok=True) @@ -157,12 +161,12 @@ "probe": probe, "session": session_name, "num_units": len(sorting.unit_ids), - "num_units_di": len(sorting_di.unit_ids), "filter_option": filter_option, - "deepinteprolated": False, + "deepinterpolated": False, } new_data_di = new_data.copy() - new_data_di["deepinteprolated"] = True + new_data_di["deepinterpolated"] = True + new_data_di["num_units"] = len(sorting_di.unit_ids), new_data.update(perf_avg.to_dict()) new_data.update(counts.to_dict()) From c89e7ab43d8d2d667d5f7a91375b82e0580d0af8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 10:52:39 +0200 Subject: [PATCH 53/84] Add debug print --- pipeline/run_spike_sorting_GT.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 07c6ec4..b38c35e 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -109,6 +109,7 @@ # DEBUG mode if recording.get_num_samples() < recording_gt.get_num_samples(): + print("DEBUG MODE: slicing GT") sorting_gt = sorting_gt.frame_slice(start_frame=0, end_frame=recording.get_num_samples()) # run spike sorting From c649d279a6886d0a20266d95eec525b761cf3d0c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 10:53:45 +0200 Subject: [PATCH 54/84] Final cmp fix --- pipeline/run_spike_sorting_GT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index b38c35e..a0c72a3 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -156,7 +156,7 @@ perf_avg = cmp.get_performance(method="pooled_with_average") perf_avg_di = cmp_di.get_performance(method="pooled_with_average") counts = cmp.count_units_categories() - counts_di = cmp.count_units_categories() + counts_di = cmp_di.count_units_categories() new_data = { "probe": probe, @@ -167,7 +167,7 @@ } new_data_di = new_data.copy() new_data_di["deepinterpolated"] = True - new_data_di["num_units"] = len(sorting_di.unit_ids), + new_data_di["num_units"] = len(sorting_di.unit_ids) new_data.update(perf_avg.to_dict()) new_data.update(counts.to_dict()) From 218a57a05f123719489203560bfbf21166c6fce5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:18:20 +0200 Subject: [PATCH 55/84] Add unit id column and sort columns GT --- pipeline/run_spike_sorting_GT.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index a0c72a3..5681011 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -190,12 +190,14 @@ perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids perf_by_unit_di = cmp_di.get_performance(method="by_unit") perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids new_unit_df = pd.concat([perf_by_unit, perf_by_unit_di], ignore_index=True) @@ -204,5 +206,10 @@ else: unit_level_results = pd.concat([unit_level_results, new_unit_df], ignore_index=True) + sorted_columns = ["probe", "session", "filter_option", "deepinterpolated", "unit_id"] + for col in perf_by_unit.columns: + sorted_columns.append(col) + unit_level_results = unit_level_results[sorted_columns] + session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) From 12d80a8a27f194377f7b104bf0070a60f463628a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:19:39 +0200 Subject: [PATCH 56/84] Add unit id column and sort columns GT 1 --- pipeline/run_spike_sorting_GT.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 5681011..72254ea 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -206,10 +206,10 @@ else: unit_level_results = pd.concat([unit_level_results, new_unit_df], ignore_index=True) - sorted_columns = ["probe", "session", "filter_option", "deepinterpolated", "unit_id"] - for col in perf_by_unit.columns: - sorted_columns.append(col) - unit_level_results = unit_level_results[sorted_columns] + sorted_columns = ["probe", "session", "filter_option", "deepinterpolated", "unit_id"] + for col in perf_by_unit.columns: + sorted_columns.append(col) + unit_level_results = unit_level_results[sorted_columns] session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) From fab66b71428c3bd61916e2de40ced6de53072955 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:21:01 +0200 Subject: [PATCH 57/84] Add unit id column and sort columns GT 2 --- pipeline/run_spike_sorting_GT.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 72254ea..34af3f5 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -186,6 +186,7 @@ # by unit perf_by_unit = cmp.get_performance(method="by_unit") + perf_columns = perf_by_unit.columns perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) @@ -207,7 +208,7 @@ unit_level_results = pd.concat([unit_level_results, new_unit_df], ignore_index=True) sorted_columns = ["probe", "session", "filter_option", "deepinterpolated", "unit_id"] - for col in perf_by_unit.columns: + for col in perf_columns: sorted_columns.append(col) unit_level_results = unit_level_results[sorted_columns] From e049bb322fce3ad85ee47221334acebcbfd29f6c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:50:32 +0200 Subject: [PATCH 58/84] Debug paths --- pipeline/run_spike_sorting_GT.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 34af3f5..6f788fd 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -5,6 +5,7 @@ #### IMPORTS ####### +import os import sys import json from pathlib import Path @@ -102,6 +103,11 @@ # load recordings # save processed json processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" + + print(os.getcwd()) + print("Processed JSON file: ", processed_json_folder / "processed.json") + print("DeepInterpolated JSON file: ", processed_json_folder / "deepinterpolated.json") + recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) recording_di = si.load_extractor( processed_json_folder / "deepinterpolated.json", base_folder=processed_folder From f4c8a0a7dde53d82b63b18e4e5ebdc282cc3012b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:53:15 +0200 Subject: [PATCH 59/84] Remove resolve --- pipeline/run_spike_sorting_GT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 6f788fd..7622d6f 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -21,7 +21,7 @@ import spikeinterface.qualitymetrics as sqm -base_path = Path("..").resolve() +base_path = Path("..") ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_sim as all_sessions From b3e2257bd304464d16fe6886d720a7427f868f36 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 11:58:49 +0200 Subject: [PATCH 60/84] Remove resolved and update super training --- pipeline/run_collect_results.py | 2 +- pipeline/run_inference.py | 2 +- pipeline/run_spike_sorting.py | 2 +- pipeline/run_spike_sorting_GT.py | 4 ---- pipeline/run_super_training.py | 30 +++++++++++------------------- pipeline/run_training.py | 2 +- 6 files changed, 15 insertions(+), 27 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 4195e5f..1f77d40 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -12,7 +12,7 @@ import pandas as pd -base_path = Path("..").resolve() +base_path = Path("..") data_folder = base_path / "data" scratch_folder = base_path / "scratch" diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 075fb36..e60d16a 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -32,7 +32,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1" -base_path = Path("..").resolve() +base_path = Path("..") ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 3e0758e..9ec18b6 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -18,7 +18,7 @@ import spikeinterface.qualitymetrics as sqm -base_path = Path("..").resolve() +base_path = Path("..") ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp as all_sessions diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 7622d6f..b99e9b8 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -104,10 +104,6 @@ # save processed json processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" - print(os.getcwd()) - print("Processed JSON file: ", processed_json_folder / "processed.json") - print("DeepInterpolated JSON file: ", processed_json_folder / "deepinterpolated.json") - recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) recording_di = si.load_extractor( processed_json_folder / "deepinterpolated.json", base_folder=processed_folder diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 647018b..4c1f816 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -29,10 +29,10 @@ import tensorflow as tf -base_path = Path("..").resolve() +base_path = Path("..") ##### DEFINE DATASETS AND FOLDERS ####### -from sessions import all_sessions +from sessions import all_sessions_exp as all_sessions n_jobs = 16 @@ -115,24 +115,11 @@ print(f"\tRunning super training with {len(sessions_to_use)} sessions") for i, session in enumerate(sessions_to_use): print(f"\t\tSession {session} - Iteration {i}\n") - if str(DATASET_FOLDER).startswith("s3"): - raw_data_folder = scratch_folder / "raw" - raw_data_folder.mkdir(exist_ok=True) + dataset_name, session_name = session.split("/") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" - # download dataset - dst_folder.mkdir(exist_ok=True) + recording = si.load_extractor(DATASET_FOLDER / session) - src_folder = f"{DATASET_FOLDER}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) - else: - raw_data_folder = DATASET_FOLDER - dst_folder = raw_data_folder / session - - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) if DEBUG: recording = recording.frame_slice( start_frame=0, @@ -153,14 +140,19 @@ recording_processed = recording recording_zscore = spre.zscore(recording_processed) + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") + # train model model_folder = results_folder / f"models_{filter_option}" / f"iter{i}" model_folder.parent.mkdir(parents=True, exist_ok=True) # Use SI function t_start_training = time.perf_counter() + if pretrained_model_path is not None: + print(f"\t\tUsing pretrained model: {pretrained_model_path}") model_path = spre.train_deepinterpolation( - recording_zscore, + recording_zscore_bin, model_folder=model_folder, model_name=model_name, existing_model_path=pretrained_model_path, diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 1eedbb9..b5e5a38 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -26,7 +26,7 @@ import tensorflow as tf # runs from "codes" -base_path = Path("..").resolve() +base_path = Path("..") ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim From 15a7489d7b1343c13e10d665aa9f7f623df225f5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 12:07:03 +0200 Subject: [PATCH 61/84] Super-training: add probe option --- pipeline/run_super_training.py | 75 +++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 4c1f816..09765ac 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -75,11 +75,12 @@ if __name__ == "__main__": - if len(sys.argv) == 2: + if len(sys.argv) == 3: if sys.argv[1] == "true": DEBUG = True else: DEBUG = False + PROBESET = sys.argv[2] session_dict = all_sessions @@ -100,19 +101,25 @@ si.set_global_job_kwargs(**job_kwargs) + assert PROBESET in ["NP1", "NP2", "NP1-NP2"] + + probes = PROBESET.split("-") + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") - pretrained_model_path = None for filter_option in FILTER_OPTIONS: print(f"Filter option: {filter_option}") - for probe, sessions in session_dict.items(): + for probe in probes: + sessions = all_sessions[probe] print(f"\tDataset {probe}") if DEBUG: sessions_to_use = sessions[:NUM_DEBUG_SESSIONS] else: sessions_to_use = sessions print(f"\tRunning super training with {len(sessions_to_use)} sessions") + + pretrained_model_path = None for i, session in enumerate(sessions_to_use): print(f"\t\tSession {session} - Iteration {i}\n") dataset_name, session_name = session.split("/") @@ -144,7 +151,7 @@ recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") # train model - model_folder = results_folder / f"models_{filter_option}" / f"iter{i}" + model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" model_folder.parent.mkdir(parents=True, exist_ok=True) # Use SI function @@ -167,33 +174,33 @@ elapsed_time_training = np.round(t_stop_training - t_start_training, 2) print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") - # aggregate results - print(f"Aggregating results for {filter_option}") - final_model_folder = results_folder / f"model_{filter_option}" - shutil.copytree(model_folder, final_model_folder) - final_model_name = [p.name for p in final_model_folder.iterdir() if "_model" in p.name][0] - final_model_stem = final_model_name.split("_model")[0] - - # concatenate loss and val loss - loss_accuracies = np.array([]) - val_accuracies = np.array([]) - - for i in range(len(sessions_to_use)): - model_folder = results_folder / f"models_{filter_option}" / f"iter{i}" - loss_file = [p for p in model_folder.iterdir() if "_loss.npy" in p.name and "val" not in p.name][0] - val_loss_file = [p for p in model_folder.iterdir() if "val_loss.npy" in p.name][0] - loss = np.load(loss_file) - val_loss = np.load(val_loss_file) - loss_accuracies = np.concatenate((loss_accuracies, loss)) - val_accuracies = np.concatenate((val_accuracies, val_loss)) - np.save(final_model_folder / f"{final_model_stem}_loss.npy", loss_accuracies) - np.save(final_model_folder / f"{final_model_stem}_val_loss.npy", val_accuracies) - - # plot losses - fig, ax = plt.subplots() - ax.plot(loss_accuracies, color="C0", label="loss") - ax.plot(val_accuracies, color="C1", label="val_loss") - ax.set_xlabel("number of epochs") - ax.set_ylabel("training loss") - ax.legend() - fig.savefig(final_model_folder / f"{final_model_stem}_losses.png", dpi=300) + # aggregate results + print(f"Aggregating results for {probe}-{filter_option}") + final_model_folder = results_folder / f"model_{probe}_{filter_option}" + shutil.copytree(model_folder, final_model_folder) + final_model_name = [p.name for p in final_model_folder.iterdir() if "_model" in p.name][0] + final_model_stem = final_model_name.split("_model")[0] + + # concatenate loss and val loss + loss_accuracies = np.array([]) + val_accuracies = np.array([]) + + for i in range(len(sessions_to_use)): + model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" + loss_file = [p for p in model_folder.iterdir() if "_loss.npy" in p.name and "val" not in p.name][0] + val_loss_file = [p for p in model_folder.iterdir() if "val_loss.npy" in p.name][0] + loss = np.load(loss_file) + val_loss = np.load(val_loss_file) + loss_accuracies = np.concatenate((loss_accuracies, loss)) + val_accuracies = np.concatenate((val_accuracies, val_loss)) + np.save(final_model_folder / f"{final_model_stem}_loss.npy", loss_accuracies) + np.save(final_model_folder / f"{final_model_stem}_val_loss.npy", val_accuracies) + + # plot losses + fig, ax = plt.subplots() + ax.plot(loss_accuracies, color="C0", label="loss") + ax.plot(val_accuracies, color="C1", label="val_loss") + ax.set_xlabel("number of epochs") + ax.set_ylabel("training loss") + ax.legend() + fig.savefig(final_model_folder / f"{final_model_stem}_losses.png", dpi=300) From 3acf580c6539a601b4c4dcddc43a7a2e76ee144d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 12:20:14 +0200 Subject: [PATCH 62/84] Don't max out CPU --- pipeline/run_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index e60d16a..b15cf4a 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -37,7 +37,7 @@ ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim -n_jobs = -1 +n_jobs = os.cpu_count() - 4 job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") From 9ce301bd885315ec8b72293f4d2ef85528e1cf20 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 13:25:10 +0200 Subject: [PATCH 63/84] Steps per epoch in super-training --- pipeline/run_inference.py | 2 +- pipeline/run_super_training.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index b15cf4a..06b2d1d 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -65,7 +65,7 @@ ##### DEFINE PARAMS ##### OVERWRITE = False USE_GPU = True -FULL_INFERENCE = True + # Define training and testing constants (@Jad you can gradually increase this) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 09765ac..49a0fa3 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -53,10 +53,7 @@ ##### DEFINE PARAMS ##### OVERWRITE = False USE_GPU = True -FULL_INFERENCE = True - -# Define training and testing constants (@Jad you can gradually increase this) - +STEPS_PER_EPOCH = 100 FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" @@ -91,6 +88,7 @@ TRAINING_END_S = 0.2 TESTING_START_S = 10 TESTING_END_S = 10.05 + STEPS_PER_EPOCH = 10 OVERWRITE = True else: TRAINING_START_S = 0 @@ -167,6 +165,9 @@ train_end_s=TRAINING_END_S, test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, + verbose=False, + nb_gpus=1, + steps_per_epoch=STEPS_PER_EPOCH, **di_kwargs, ) pretrained_model_path = model_path From 6693240ed085b3f0e6d054e7864f0fdfc6577f99 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 13:25:47 +0200 Subject: [PATCH 64/84] Add default probeset --- pipeline/run_super_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index 49a0fa3..abbacdb 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -47,6 +47,7 @@ DATASET_FOLDER = data_folder / "ephys-compression-benchmark" DEBUG = False +PROBESET = "NP2" NUM_DEBUG_SESSIONS = 4 DEBUG_DURATION = 20 From 1d57918dbde2ec7ee2532722dcbb8db79e608339 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 17:41:19 +0200 Subject: [PATCH 65/84] Limit n_njobs --- pipeline/run_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 06b2d1d..1ed45d0 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -37,7 +37,7 @@ ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim -n_jobs = os.cpu_count() - 4 +n_jobs = int(0.7 * (os.cpu_count())) job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") @@ -78,7 +78,7 @@ pre_post_omission = 1 desired_shape = (192, 2) # play around with these -inference_n_jobs = os.cpu_count() - 4 +inference_n_jobs = int(0.7 * (os.cpu_count())) inference_chunk_duration = "1s" inference_predict_workers = 1 inference_memory_gpu = 2000 # MB From f38240ded99af3db860de793bdb79f2b5ac9a745 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 25 Jul 2023 09:09:48 +0200 Subject: [PATCH 66/84] Set n_jobs with params --- pipeline/run_inference.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 1ed45d0..105e559 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -37,7 +37,7 @@ ##### DEFINE DATASETS AND FOLDERS ####### from sessions import all_sessions_exp, all_sessions_sim -n_jobs = int(0.7 * (os.cpu_count())) +n_jobs = 24 job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") @@ -78,7 +78,7 @@ pre_post_omission = 1 desired_shape = (192, 2) # play around with these -inference_n_jobs = int(0.7 * (os.cpu_count())) +inference_n_jobs = 24 inference_chunk_duration = "1s" inference_predict_workers = 1 inference_memory_gpu = 2000 # MB @@ -91,15 +91,16 @@ ) if __name__ == "__main__": - if len(sys.argv) == 4: + if len(sys.argv) == 5: if sys.argv[1] == "true": DEBUG = True OVERWRITE = True else: DEBUG = False OVERWRITE = False - inference_n_jobs = int(sys.argv[2]) - inference_predict_workers = int(sys.argv[3]) + n_jobs = int(sys.argv[2]) + inference_n_jobs = int(sys.argv[3]) + inference_predict_workers = int(sys.argv[4]) json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] From bf26914bd2386b6d1e1f715494161232a8efecb2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 31 Jul 2023 10:01:08 +0200 Subject: [PATCH 67/84] Handle sorting errors in run_sorting function --- pipeline/run_spike_sorting.py | 246 ++++++++++++++++--------------- pipeline/run_spike_sorting_GT.py | 177 ++++++++++++++-------- 2 files changed, 246 insertions(+), 177 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 9ec18b6..bdc2a42 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -134,138 +134,152 @@ sorting = si.load_extractor(sorting_output_folder / "sorting") else: print(f"\t\tSpike sorting NO DI with {sorter_name}") - sorting = ss.run_sorter( - sorter_name, - recording=recording, - output_folder=scratch_folder / session / filter_option / "no_di", - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, - ) - sorting = scur.remove_excess_spikes(sorting, recording) - sorting = sorting.save(folder=sorting_output_folder / "sorting") + try: + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = scur.remove_excess_spikes(sorting, recording) + sorting = sorting.save(folder=sorting_output_folder / "sorting") + except: + print(f"Error sorting {session} with {sorter_name} and {filter_option}") + sorting = None + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: print("\t\tLoading DI sorting") sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") else: print(f"\t\tSpike sorting DI with {sorter_name}") - sorting_di = ss.run_sorter( - sorter_name, - recording=recording_di, - output_folder=scratch_folder / session / filter_option / "di", - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, + try: + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) + sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") + except: + print(f"Error sorting DI {session} with {sorter_name} and {filter_option}") + sorting_di = None + + if sorting is not None and sorting_di is not None: + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, ) - sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) - sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") - - # compare outputs - print("\t\tComparing sortings") - comp = sc.compare_two_sorters( - sorting1=sorting, - sorting2=sorting_di, - sorting1_name="no_di", - sorting2_name="di", - match_score=match_score, - ) - matched_units = comp.get_matching()[0] - matched_unit_ids = matched_units.index.values.astype(int) - matched_unit_ids_di = matched_units.values.astype(int) - matched_units_valid = matched_unit_ids_di != -1 - matched_unit_ids = matched_unit_ids[matched_units_valid] - matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] - sorting_matched = sorting.select_units(unit_ids=matched_unit_ids) - sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) - - ## add entries to session-level results + matched_units = comp.get_matching()[0] + matched_unit_ids = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids = matched_unit_ids[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_matched = sorting.select_units(unit_ids=matched_unit_ids) + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + else: + sorting_matched = None + sorting_di_matched = None + new_row = { "dataset": dataset_name, "session": session_name, "filter_option": filter_option, "probe": probe, - "num_units": len(sorting.unit_ids), - "num_units_di": len(sorting_di.unit_ids), - "num_match": len(sorting_matched.unit_ids), - "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)), - "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)), + "num_units": len(sorting.unit_ids) if sorting is not None else 0, + "num_units_di": len(sorting_di.unit_ids) if sorting_di is not None else 0, + "num_match": len(sorting_matched.unit_ids) if sorting_matched is not None else 0, + "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)) if sorting is not None else None, + "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)) if sorting_di is not None else None, } - session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) print( f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" ) - # waveforms - waveforms_folder = results_folder / f"waveforms_{dataset_name}_{session_name}_{filter_option}" - waveforms_folder.mkdir(exist_ok=True, parents=True) - - if (waveforms_folder / "waveforms").is_dir() and not OVERWRITE: - print("\t\tLoad NO DI waveforms") - we = si.load_waveforms(waveforms_folder / "waveforms") - else: - print("\t\tCompute NO DI waveforms") - we = si.extract_waveforms( - recording, - sorting_matched, - folder=waveforms_folder / "waveforms", - n_jobs=n_jobs, - overwrite=True, - ) - - if (waveforms_folder / "waveforms_di").is_dir() and not OVERWRITE: - print("\t\tLoad DI waveforms") - we_di = si.load_waveforms(waveforms_folder / "waveforms_di") - else: - print("\t\tCompute DI waveforms") - we_di = si.extract_waveforms( - recording_di, - sorting_di_matched, - folder=waveforms_folder / "waveforms_di", - n_jobs=n_jobs, - overwrite=True, - ) - - # compute metrics - if we.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad NO DI metrics") - qm = we.load_extension("quality_metrics").get_data() - else: - print("\t\tCompute NO DI metrics") - qm = sqm.compute_quality_metrics(we) - - if we_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad DI metrics") - qm_di = we_di.load_extension("quality_metrics").get_data() - else: - print("\t\tCompute DI metrics") - qm_di = sqm.compute_quality_metrics(we_di) - - ## add entries to unit-level results - if unit_level_results is None: + if sorting_matched is not None: + # waveforms + waveforms_folder = results_folder / f"waveforms_{dataset_name}_{session_name}_{filter_option}" + waveforms_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms") + we = si.load_waveforms(waveforms_folder / "waveforms") + else: + print("\t\tCompute NO DI waveforms") + we = si.extract_waveforms( + recording, + sorting_matched, + folder=waveforms_folder / "waveforms", + n_jobs=n_jobs, + overwrite=True, + ) + + if (waveforms_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms") + we_di = si.load_waveforms(waveforms_folder / "waveforms_di") + else: + print("\t\tCompute DI waveforms") + we_di = si.extract_waveforms( + recording_di, + sorting_di_matched, + folder=waveforms_folder / "waveforms_di", + n_jobs=n_jobs, + overwrite=True, + ) + + # compute metrics + if we.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad NO DI metrics") + qm = we.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute NO DI metrics") + qm = sqm.compute_quality_metrics(we) + + if we_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad DI metrics") + qm_di = we_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute DI metrics") + qm_di = sqm.compute_quality_metrics(we_di) + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm.columns: + unit_level_results_columns.append(metric) + unit_level_results_columns.append(f"{metric}_di") + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "dataset": [dataset_name] * len(qm), + "session": [session_name] * len(qm), + "probe": [probe] * len(qm), + "filter_option": [filter_option] * len(qm), + "unit_id": we.unit_ids, + "unit_id_di": we_di.unit_ids, + } + agreement_scores = [] + for i in range(len(we.unit_ids)): + agreement_scores.append(comp.agreement_scores.at[we.unit_ids[i], we_di.unit_ids[i]]) + new_rows["agreement_score"] = agreement_scores for metric in qm.columns: - unit_level_results_columns.append(metric) - unit_level_results_columns.append(f"{metric}_di") - unit_level_results = pd.DataFrame(columns=unit_level_results_columns) - - new_rows = { - "dataset": [dataset_name] * len(qm), - "session": [session_name] * len(qm), - "probe": [probe] * len(qm), - "filter_option": [filter_option] * len(qm), - "unit_id": we.unit_ids, - "unit_id_di": we_di.unit_ids, - } - agreement_scores = [] - for i in range(len(we.unit_ids)): - agreement_scores.append(comp.agreement_scores.at[we.unit_ids[i], we_di.unit_ids[i]]) - new_rows["agreement_score"] = agreement_scores - for metric in qm.columns: - new_rows[metric] = qm[metric].values - new_rows[f"{metric}_di"] = qm_di[metric].values - # append new entries - unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) - - session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) - unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) + new_rows[metric] = qm[metric].values + new_rows[f"{metric}_di"] = qm_di[metric].values + # append new entries + unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) + + if session_level_results is None: + session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) + if unit_level_results is None: + unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index b99e9b8..a9fdb8a 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -9,6 +9,7 @@ import sys import json from pathlib import Path +import numpy as np import pandas as pd @@ -18,7 +19,6 @@ import spikeinterface.sorters as ss import spikeinterface.curation as scur import spikeinterface.comparison as sc -import spikeinterface.qualitymetrics as sqm base_path = Path("..") @@ -123,59 +123,93 @@ sorting = si.load_extractor(sorting_output_folder / "sorting") else: print(f"\t\tSpike sorting NO DI with {sorter_name}") - sorting = ss.run_sorter( - sorter_name, - recording=recording, - output_folder=scratch_folder / session / filter_option / "no_di", - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, - ) - sorting = scur.remove_excess_spikes(sorting, recording) - sorting = sorting.save(folder=sorting_output_folder / "sorting") - + try: + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = scur.remove_excess_spikes(sorting, recording) + sorting = sorting.save(folder=sorting_output_folder / "sorting") + except: + print(f"\t\t\t{sorter_name} failed on original") + sorting = None if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: print("\t\tLoading DI sorting") sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") else: print(f"\t\tSpike sorting DI with {sorter_name}") - sorting_di = ss.run_sorter( - sorter_name, - recording=recording_di, - output_folder=scratch_folder / session / filter_option / "di", - n_jobs=n_jobs, - verbose=True, - singularity_image=singularity_image, - ) - sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) - sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") - + try: + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) + sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") + except: + print(f"\t\t\t{sorter_name} failed on DI") + sorting_di = None + # compare to GT + perf_keys = ["precision", "false_discovery_rate", "miss_rate", + "num_gt", "num_sorter", "num_well_detected", "num_overmerged", + "num_redundant", "num_false_positivenum_bad"] print("\tRunning comparison") - cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) - cmp_di = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_di, exhaustive_gt=True) - - perf_avg = cmp.get_performance(method="pooled_with_average") - perf_avg_di = cmp_di.get_performance(method="pooled_with_average") - counts = cmp.count_units_categories() - counts_di = cmp_di.count_units_categories() - - new_data = { - "probe": probe, - "session": session_name, - "num_units": len(sorting.unit_ids), - "filter_option": filter_option, - "deepinterpolated": False, - } - new_data_di = new_data.copy() - new_data_di["deepinterpolated"] = True - new_data_di["num_units"] = len(sorting_di.unit_ids) - - new_data.update(perf_avg.to_dict()) - new_data.update(counts.to_dict()) - - new_data_di.update(perf_avg_di.to_dict()) - new_data_di.update(counts_di.to_dict()) + if sorting is not None: + cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) + perf_avg = cmp.get_performance(method="pooled_with_average") + counts = cmp.count_units_categories() + new_data = { + "probe": probe, + "session": session_name, + "num_units": len(sorting.unit_ids), + "filter_option": filter_option, + "deepinterpolated": False, + } + new_data.update(perf_avg.to_dict()) + new_data.update(counts.to_dict()) + else: + new_data = { + "probe": probe, + "session": session_name, + "num_units": np.nan, + "filter_option": filter_option, + "deepinterpolated": False, + } + for perf_key in perf_keys: + new_data[perf_key] = np.nan + + if sorting_di is not None: + cmp_di = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_di, exhaustive_gt=True) + perf_avg_di = cmp_di.get_performance(method="pooled_with_average") + counts_di = cmp_di.count_units_categories() + + new_data_di = { + "probe": probe, + "session": session_name, + "num_units": len(sorting_di.unit_ids), + "filter_option": filter_option, + "deepinterpolated": True, + } + new_data_di.update(perf_avg_di.to_dict()) + new_data_di.update(counts_di.to_dict()) + else: + new_data_di = { + "probe": probe, + "session": session_name, + "num_units": np.nan, + "filter_option": filter_option, + "deepinterpolated": True, + } + for perf_key in perf_keys: + new_data[perf_key] = np.nan new_df = pd.DataFrame([new_data]) new_df_di = pd.DataFrame([new_data_di]) @@ -187,20 +221,41 @@ session_level_results = pd.concat([session_level_results, new_df_session], ignore_index=True) # by unit - perf_by_unit = cmp.get_performance(method="by_unit") - perf_columns = perf_by_unit.columns - perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) - perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) - perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) - perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) - perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids - - perf_by_unit_di = cmp_di.get_performance(method="by_unit") - perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) - perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) - perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) - perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) - perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids + unit_perf_keys = ["accuracy", "recall", "precision", "false_discovery_rate", "miss_rate"] + if sorting is not None: + perf_by_unit = cmp.get_performance(method="by_unit") + perf_columns = perf_by_unit.columns + perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) + perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) + perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) + perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids + else: + perf_by_unit = pd.DataFrame({"unit_id": sorting_gt.unit_ids}) + perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) + perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) + perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) + perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids + for perf_key in unit_perf_keys: + perf_by_unit.loc[:, perf_key] = np.nan + + if sorting_di is not None: + perf_by_unit_di = cmp_di.get_performance(method="by_unit") + perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids + else: + perf_by_unit_di = pd.DataFrame({"unit_id": sorting_gt.unit_ids}) + perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids + for perf_key in unit_perf_keys: + perf_by_unit_di.loc[:, perf_key] = np.nan new_unit_df = pd.concat([perf_by_unit, perf_by_unit_di], ignore_index=True) From c2672ae903c00f04c06bef5888783b8144c2c79b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 31 Jul 2023 12:15:01 +0200 Subject: [PATCH 68/84] Oups --- pipeline/run_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index bdc2a42..5c0fb70 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -279,7 +279,7 @@ # append new entries unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) - if session_level_results is None: + if session_level_results is not None: session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) - if unit_level_results is None: + if unit_level_results is not None: unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) From 122707b03a066b215001669a8c80e0f80e1aca3a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 2 Aug 2023 09:46:10 +0200 Subject: [PATCH 69/84] Protect against small mismatches in sampling frequency --- pipeline/run_spike_sorting.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 5c0fb70..89aae19 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -218,6 +218,9 @@ we = si.load_waveforms(waveforms_folder / "waveforms") else: print("\t\tCompute NO DI waveforms") + if sorting_matched.sampling_frequency != recording.sampling_frequency: + print("\t\tSetting sorting sampling frequency to match recording") + sorting_matched._sampling_frequency = recording.sampling_frequency we = si.extract_waveforms( recording, sorting_matched, @@ -231,6 +234,9 @@ we_di = si.load_waveforms(waveforms_folder / "waveforms_di") else: print("\t\tCompute DI waveforms") + if sorting_di_matched.sampling_frequency != recording.sampling_frequency: + print("\t\tSetting sorting DI sampling frequency to match recording") + sorting_di_matched._sampling_frequency = recording.sampling_frequency we_di = si.extract_waveforms( recording_di, sorting_di_matched, From 96abe6a60bb8a5ee373a18b7a5ed5aea8cf42760 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:14:55 +0200 Subject: [PATCH 70/84] Correcly save session level results --- pipeline/run_spike_sorting.py | 30 +++++++++++++++++------------- pipeline/run_training.py | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 5c0fb70..0ef4996 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -89,19 +89,18 @@ print(f"\nAnalyzing session {session}\n") dataset_name, session_name = session.split("/") - session_level_results = pd.DataFrame( - columns=[ - "dataset", - "session", - "probe", - "filter_option", - "num_units", - "num_units_di", - "sorting_path", - "sorting_path_di", - "num_match", - ] - ) + session_level_results_columns = [ + "dataset", + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] + session_level_results = None unit_level_results_columns = [ "dataset", @@ -207,6 +206,11 @@ print( f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" ) + + if session_level_results is None: + session_level_results = pd.DataFrame(columns=session_level_results_columns) + session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) + if sorting_matched is not None: # waveforms diff --git a/pipeline/run_training.py b/pipeline/run_training.py index b5e5a38..afd67e5 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -86,7 +86,7 @@ DEBUG = False json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] - print(f"Found {len(json_files)} JSON config") + print(f"Found {len(json_files)} JSON config: {json_files}") if len(json_files) > 0: session_dict = {} # each json file contains a session to run From 8067f750c8138711fa9e0752f9f2a44f3e05fb60 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:16:08 +0200 Subject: [PATCH 71/84] A round of black --- pipeline/run_inference.py | 4 +++- pipeline/run_spike_sorting.py | 17 ++++++++++------- pipeline/run_spike_sorting_GT.py | 17 ++++++++++++----- pipeline/run_super_training.py | 4 +++- pipeline/run_training.py | 6 ++++-- 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 105e559..5327028 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -176,7 +176,9 @@ recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot - recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) # train model model_folder = data_model_folder / f"model_{recording_name}" diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 1fb43a3..a4d365f 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -83,7 +83,6 @@ processed_folder = data_subfolders[0] for probe, sessions in session_dict.items(): - print(f"Dataset {probe}") for session in sessions: print(f"\nAnalyzing session {session}\n") @@ -147,7 +146,6 @@ except: print(f"Error sorting {session} with {sorter_name} and {filter_option}") sorting = None - if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: print("\t\tLoading DI sorting") @@ -199,19 +197,22 @@ "num_units": len(sorting.unit_ids) if sorting is not None else 0, "num_units_di": len(sorting_di.unit_ids) if sorting_di is not None else 0, "num_match": len(sorting_matched.unit_ids) if sorting_matched is not None else 0, - "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)) if sorting is not None else None, - "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)) if sorting_di is not None else None, + "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)) + if sorting is not None + else None, + "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)) + if sorting_di is not None + else None, } print( f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" ) - + if session_level_results is None: session_level_results = pd.DataFrame(columns=session_level_results_columns) session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) - if sorting_matched is not None: # waveforms waveforms_folder = results_folder / f"waveforms_{dataset_name}_{session_name}_{filter_option}" @@ -290,6 +291,8 @@ unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) if session_level_results is not None: - session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) + session_level_results.to_csv( + results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False + ) if unit_level_results is not None: unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index a9fdb8a..040905f 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -85,7 +85,6 @@ processed_folder = data_subfolders[0] for probe, sessions in session_dict.items(): - print(f"Dataset {probe}") for session in sessions: print(f"\nAnalyzing session {session}\n") @@ -156,11 +155,19 @@ except: print(f"\t\t\t{sorter_name} failed on DI") sorting_di = None - + # compare to GT - perf_keys = ["precision", "false_discovery_rate", "miss_rate", - "num_gt", "num_sorter", "num_well_detected", "num_overmerged", - "num_redundant", "num_false_positivenum_bad"] + perf_keys = [ + "precision", + "false_discovery_rate", + "miss_rate", + "num_gt", + "num_sorter", + "num_well_detected", + "num_overmerged", + "num_redundant", + "num_false_positivenum_bad", + ] print("\tRunning comparison") if sorting is not None: cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index abbacdb..f126c15 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -147,7 +147,9 @@ recording_zscore = spre.zscore(recording_processed) # This speeds things up a lot - recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) # train model model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" diff --git a/pipeline/run_training.py b/pipeline/run_training.py index afd67e5..3a8834d 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -169,8 +169,10 @@ recording_processed = spre.depth_order(recording_processed) recording_zscore = spre.zscore(recording_processed) - # This speeds things up a lot - recording_zscore_bin = recording_zscore.save(folder=scratch_folder / f"recording_zscored_{recording_name}") + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) # train model model_folder = results_folder / f"model_{recording_name}" From e4d473dc74c8fb1b0bc4ce55ea2fe7ff03ee05cd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Sep 2023 19:24:47 +0200 Subject: [PATCH 72/84] Extend training --- pipeline/run_super_training.py | 26 +++++++++++++++++--------- pipeline/run_training.py | 24 ++++++++++++++++-------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/pipeline/run_super_training.py b/pipeline/run_super_training.py index f126c15..c4e07d7 100644 --- a/pipeline/run_super_training.py +++ b/pipeline/run_super_training.py @@ -85,17 +85,20 @@ print(session_dict) if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.2 - TESTING_START_S = 10 - TESTING_END_S = 10.05 - STEPS_PER_EPOCH = 10 + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 1 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.05 OVERWRITE = True else: - TRAINING_START_S = 0 - TRAINING_END_S = 10 - TESTING_START_S = 70 - TESTING_END_S = 70.1 + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 600 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.1 OVERWRITE = False si.set_global_job_kwargs(**job_kwargs) @@ -155,6 +158,9 @@ model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" model_folder.parent.mkdir(parents=True, exist_ok=True) + if TRAINING_END_S is None: + TRAINING_END_S = recording.get_total_duration() + # Use SI function t_start_training = time.perf_counter() if pretrained_model_path is not None: @@ -166,8 +172,10 @@ existing_model_path=pretrained_model_path, train_start_s=TRAINING_START_S, train_end_s=TRAINING_END_S, + train_duration_s=TRAINING_DURATION_S, test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, + test_duration_s=TESTING_DURATION_S, verbose=False, nb_gpus=1, steps_per_epoch=STEPS_PER_EPOCH, diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 3a8834d..74adbeb 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -107,16 +107,20 @@ print(session_dict) if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.2 - TESTING_START_S = 10 - TESTING_END_S = 10.05 + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 1 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.05 OVERWRITE = True else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.1 + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 600 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.1 OVERWRITE = False si.set_global_job_kwargs(**job_kwargs) @@ -148,6 +152,8 @@ end_frame=int(DEBUG_DURATION * recording.sampling_frequency), ) print(f"\t{recording}") + if TRAINING_END_S is None: + TRAINING_END_S = recording.get_total_duration() for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") @@ -185,8 +191,10 @@ model_name=model_name, train_start_s=TRAINING_START_S, train_end_s=TRAINING_END_S, + train_duration_s=TRAINING_DURATION_S, test_start_s=TESTING_START_S, test_end_s=TESTING_END_S, + test_duration_s=TESTING_DURATION_S, verbose=False, nb_gpus=nb_gpus, steps_per_epoch=STEPS_PER_EPOCH, From 83008fec16ddaf5df519f91472cbd3a513fecef4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 12:32:07 +0200 Subject: [PATCH 73/84] Extend unit level and matched unit level results --- pipeline/run_collect_results.py | 11 +- pipeline/run_spike_sorting.py | 187 ++++++++++++++++++++++++-------- 2 files changed, 152 insertions(+), 46 deletions(-) diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py index 1f77d40..b5f87aa 100644 --- a/pipeline/run_collect_results.py +++ b/pipeline/run_collect_results.py @@ -28,6 +28,7 @@ # concatenate dataframes df_session = None df_units = None + df_matched_units = None probe_sortings_folders = [p for p in data_folder.iterdir() if p.name.startswith("sorting_") and p.is_dir()] @@ -48,7 +49,8 @@ data_sortings_folder = data_sorting_subfolders[0] session_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("sessions.csv")] - unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("units.csv")] + unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("units.csv") and "matched" not in p.name] + matched_unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("matched-units.csv")] for session_csv in session_csvs: if df_session is None: @@ -62,9 +64,16 @@ else: df_units = pd.concat([df_units, pd.read_csv(unit_csv)]) + for matched_unit_csv in matched_unit_csvs: + if df_matched_units is None: + df_matched_units = pd.read_csv(matched_unit_csv) + else: + df_matched_units = pd.concat([df_matched_units, pd.read_csv(matched_unit_csv)]) + # save concatenated dataframes df_session.to_csv(results_folder / "sessions.csv", index=False) df_units.to_csv(results_folder / "units.csv", index=False) + df_matched_units.to_csv(results_folder / "matched-units.csv", index=False) # copy sortings to results folder sortings_folders = [p for p in data_sortings_folder.iterdir() if "sorting_" in p.name and p.is_dir()] diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index a4d365f..8829294 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -15,6 +15,7 @@ import spikeinterface.sorters as ss import spikeinterface.curation as scur import spikeinterface.comparison as sc +import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm @@ -40,6 +41,11 @@ singularity_image = False match_score = 0.7 +sparsity_kwargs = dict( + method="radius", + radius_um=200, +) + if __name__ == "__main__": if len(sys.argv) == 2: @@ -102,6 +108,16 @@ session_level_results = None unit_level_results_columns = [ + "dataset", + "session", + "probe", + "filter_option", + "unit_id", + "deepinterpolated", + ] + unit_level_results = None + + matched_unit_level_results_columns = [ "dataset", "session", "probe", @@ -110,7 +126,7 @@ "unit_id_di", "agreement_score", ] - unit_level_results = None + matched_unit_level_results = None for filter_option in FILTER_OPTIONS: print(f"\tFilter option: {filter_option}") @@ -214,81 +230,158 @@ session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) if sorting_matched is not None: - # waveforms - waveforms_folder = results_folder / f"waveforms_{dataset_name}_{session_name}_{filter_option}" - waveforms_folder.mkdir(exist_ok=True, parents=True) + # waveforms for all units + waveforms_all_folder = ( + results_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" + ) + waveforms_all_folder.mkdir(exist_ok=True, parents=True) - if (waveforms_folder / "waveforms").is_dir() and not OVERWRITE: - print("\t\tLoad NO DI waveforms") - we = si.load_waveforms(waveforms_folder / "waveforms") + if (waveforms_all_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms all") + we_all = si.load_waveforms(waveforms_all_folder / "waveforms") else: - print("\t\tCompute NO DI waveforms") - if sorting_matched.sampling_frequency != recording.sampling_frequency: + print("\t\tCompute NO DI waveforms all") + if sorting.sampling_frequency != recording.sampling_frequency: print("\t\tSetting sorting sampling frequency to match recording") - sorting_matched._sampling_frequency = recording.sampling_frequency - we = si.extract_waveforms( + sorting._sampling_frequency = recording.sampling_frequency + we_all = si.extract_waveforms( recording, - sorting_matched, - folder=waveforms_folder / "waveforms", + sorting, + folder=waveforms_all_folder / "waveforms", n_jobs=n_jobs, overwrite=True, + sparse=True, + **sparsity_kwargs, ) + print("\t\tCompute NO DI spike amplitudes") + _ = spost.compute_spike_amplitudes(we_all) + print("\t\tCompute NO DI spike locations") + _ = spost.compute_spike_locations(we_all) + print("\t\tCompute NO DI PCA scores") + _ = spost.compute_principal_components(we_all) + print("\t\tCompute NO DI template metrics") + _ = spost.compute_template_metrics(we_all) + + # finally, quality metrics + print("\t\tCompute DI metrics") + qm_all = sqm.compute_quality_metrics(we_all) - if (waveforms_folder / "waveforms_di").is_dir() and not OVERWRITE: - print("\t\tLoad DI waveforms") - we_di = si.load_waveforms(waveforms_folder / "waveforms_di") + if (waveforms_all_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms all") + we_all_di = si.load_waveforms(waveforms_all_folder / "waveforms_di") else: - print("\t\tCompute DI waveforms") + print("\t\tCompute DI waveforms all") if sorting_di_matched.sampling_frequency != recording.sampling_frequency: print("\t\tSetting sorting DI sampling frequency to match recording") sorting_di_matched._sampling_frequency = recording.sampling_frequency - we_di = si.extract_waveforms( + we_all_di = si.extract_waveforms( recording_di, sorting_di_matched, - folder=waveforms_folder / "waveforms_di", + folder=waveforms_all_folder / "waveforms_di", n_jobs=n_jobs, overwrite=True, + sparse=True, + **sparsity_kwargs, ) + print("\t\tCompute DI spike amplitudes") + _ = spost.compute_spike_amplitudes(we_all_di) + print("\t\tCompute DI spike locations") + _ = spost.compute_spike_locations(we_all_di) + print("\t\tCompute DI PCA scores") + _ = spost.compute_principal_components(we_all_di) + print("\t\tCompute DI template metrics") + _ = spost.compute_template_metrics(we_all_di) + + # finally, quality metrics + print("\t\tCompute DI metrics") + qm_all_di = sqm.compute_quality_metrics(we_all_di) - # compute metrics - if we.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad NO DI metrics") - qm = we.load_extension("quality_metrics").get_data() + waveforms_matched_folder = ( + results_folder / f"waveforms_matched_{dataset_name}_{session_name}_{filter_option}" + ) + waveforms_matched_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_matched_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms matched") + we_matched = si.load_waveforms(waveforms_matched_folder / "waveforms") + qm_matched = we_matched.load_extension("quality_metrics").get_data() else: - print("\t\tCompute NO DI metrics") - qm = sqm.compute_quality_metrics(we) + print("\t\tSelect NO DI waveforms matched") + we_matched = we_all.select_units( + unit_ids=matched_unit_ids, folder=waveforms_matched_folder / "waveforms" + ) + qm_matched = we_matched.load_extension("quality_metrics").get_data() - if we_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad DI metrics") - qm_di = we_di.load_extension("quality_metrics").get_data() + if (waveforms_matched_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms matched") + we_matched_di = si.load_waveforms(waveforms_matched_folder / "waveforms_di") + qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() else: - print("\t\tCompute DI metrics") - qm_di = sqm.compute_quality_metrics(we_di) + print("\t\tSelect DI waveforms matched") + we_matched_di = we_all_di.select_units( + unit_ids=matched_unit_ids_di, folder=waveforms_matched_folder / "waveforms_di" + ) + qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() ## add entries to unit-level results if unit_level_results is None: - for metric in qm.columns: + for metric in qm_all.columns: unit_level_results_columns.append(metric) - unit_level_results_columns.append(f"{metric}_di") unit_level_results = pd.DataFrame(columns=unit_level_results_columns) new_rows = { - "dataset": [dataset_name] * len(qm), - "session": [session_name] * len(qm), - "probe": [probe] * len(qm), - "filter_option": [filter_option] * len(qm), - "unit_id": we.unit_ids, - "unit_id_di": we_di.unit_ids, + "dataset": [dataset_name] * len(qm_all), + "session": [session_name] * len(qm_all), + "probe": [probe] * len(qm_all), + "filter_option": [filter_option] * len(qm_all), + "unit_id": we_all.unit_ids, + "deepinterpolated": [False] * len(qm_all), } + new_rows_di = { + "dataset": [dataset_name] * len(qm_all_di), + "session": [session_name] * len(qm_all_di), + "probe": [probe] * len(qm_all_di), + "filter_option": [filter_option] * len(qm_all_di), + "unit_id": we_all_di.unit_ids, + "deepinterpolated": [True] * len(qm_all_di), + } + for metric in qm_all.columns: + new_rows[metric] = qm_all[metric].values + new_rows_di[metric] = qm_all_di[metric].values + # append new entries + unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_rows), pd.DataFrame(new_rows_di)], ignore_index=True + ) + + ## add entries to matched unit-level results + if matched_unit_level_results is None: + for metric in qm_matched.columns: + matched_unit_level_results_columns.append(metric) + matched_unit_level_results_columns.append(f"{metric}_di") + matched_unit_level_results = pd.DataFrame(columns=matched_unit_level_results) + + new_matched_rows = { + "dataset": [dataset_name] * len(qm_matched), + "session": [session_name] * len(qm_matched), + "probe": [probe] * len(qm_matched), + "filter_option": [filter_option] * len(qm_matched), + "unit_id": we_matched.unit_ids, + "unit_id_di": we_matched_di.unit_ids, + } + agreement_scores = [] - for i in range(len(we.unit_ids)): - agreement_scores.append(comp.agreement_scores.at[we.unit_ids[i], we_di.unit_ids[i]]) - new_rows["agreement_score"] = agreement_scores - for metric in qm.columns: - new_rows[metric] = qm[metric].values - new_rows[f"{metric}_di"] = qm_di[metric].values + for i in range(len(we_matched.unit_ids)): + agreement_scores.append( + comp.agreement_scores.at[we_matched.unit_ids[i], we_matched_di.unit_ids[i]] + ) + new_matched_rows["agreement_score"] = agreement_scores + for metric in qm_matched.columns: + new_rows[metric] = qm_matched[metric].values + new_rows[f"{metric}_di"] = qm_matched_di[metric].values # append new entries - unit_level_results = pd.concat([unit_level_results, pd.DataFrame(new_rows)], ignore_index=True) + matched_unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True + ) if session_level_results is not None: session_level_results.to_csv( @@ -296,3 +389,7 @@ ) if unit_level_results is not None: unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) + if matched_unit_level_results is not None: + matched_unit_level_results.to_csv( + results_folder / f"{dataset_name}-{session_name}-matched-units.csv", index=False + ) From 05b74f7e3b0b20d5911915d56afd22a16c2cd4fe Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 12:43:56 +0200 Subject: [PATCH 74/84] wrong argument --- pipeline/run_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 8829294..82f77b6 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -308,7 +308,7 @@ else: print("\t\tSelect NO DI waveforms matched") we_matched = we_all.select_units( - unit_ids=matched_unit_ids, folder=waveforms_matched_folder / "waveforms" + unit_ids=matched_unit_ids, new_folder=waveforms_matched_folder / "waveforms" ) qm_matched = we_matched.load_extension("quality_metrics").get_data() @@ -319,7 +319,7 @@ else: print("\t\tSelect DI waveforms matched") we_matched_di = we_all_di.select_units( - unit_ids=matched_unit_ids_di, folder=waveforms_matched_folder / "waveforms_di" + unit_ids=matched_unit_ids_di, new_folder=waveforms_matched_folder / "waveforms_di" ) qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() From 339cc4a6bbb72549f4acc90475ab9023a31e6056 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 13:00:29 +0200 Subject: [PATCH 75/84] Fixes --- pipeline/run_spike_sorting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 82f77b6..d909f67 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -271,12 +271,12 @@ we_all_di = si.load_waveforms(waveforms_all_folder / "waveforms_di") else: print("\t\tCompute DI waveforms all") - if sorting_di_matched.sampling_frequency != recording.sampling_frequency: + if sorting_di.sampling_frequency != recording.sampling_frequency: print("\t\tSetting sorting DI sampling frequency to match recording") - sorting_di_matched._sampling_frequency = recording.sampling_frequency + sorting_di._sampling_frequency = recording.sampling_frequency we_all_di = si.extract_waveforms( recording_di, - sorting_di_matched, + sorting_di, folder=waveforms_all_folder / "waveforms_di", n_jobs=n_jobs, overwrite=True, @@ -380,7 +380,7 @@ new_rows[f"{metric}_di"] = qm_matched_di[metric].values # append new entries matched_unit_level_results = pd.concat( - [unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True + [matched_unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True ) if session_level_results is not None: From db18cd99bfabb5fb338dea7b749aad943f514250 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 15:39:19 +0200 Subject: [PATCH 76/84] Use scratch folder for waveforms and n_jobs=1 for QM --- pipeline/run_spike_sorting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index d909f67..14e5ed8 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -232,7 +232,7 @@ if sorting_matched is not None: # waveforms for all units waveforms_all_folder = ( - results_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" + scratch_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" ) waveforms_all_folder.mkdir(exist_ok=True, parents=True) @@ -264,7 +264,7 @@ # finally, quality metrics print("\t\tCompute DI metrics") - qm_all = sqm.compute_quality_metrics(we_all) + qm_all = sqm.compute_quality_metrics(we_all, n_jobs=1) if (waveforms_all_folder / "waveforms_di").is_dir() and not OVERWRITE: print("\t\tLoad DI waveforms all") @@ -294,10 +294,10 @@ # finally, quality metrics print("\t\tCompute DI metrics") - qm_all_di = sqm.compute_quality_metrics(we_all_di) + qm_all_di = sqm.compute_quality_metrics(we_all_di, n_jobs=1) waveforms_matched_folder = ( - results_folder / f"waveforms_matched_{dataset_name}_{session_name}_{filter_option}" + scratch_folder / f"waveforms_matched_{dataset_name}_{session_name}_{filter_option}" ) waveforms_matched_folder.mkdir(exist_ok=True, parents=True) From 8950ad7f1c3e1c78867b9cf52e22ad6d4aceb9c9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 15:50:24 +0200 Subject: [PATCH 77/84] Fix data loading in spike sorting pipeline --- pipeline/run_spike_sorting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 14e5ed8..f882971 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -85,8 +85,11 @@ if len(probe_processed_folders) > 0: processed_folder = data_folder else: - data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] - processed_folder = data_subfolders[0] + data_processed_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "processed_" in pp.name and pp.is_dir()]) > 0: + data_processed_subfolders.append(p) + processed_folder = data_processed_subfolders[0] for probe, sessions in session_dict.items(): print(f"Dataset {probe}") From 9fbb1d90a44954de3d1266ed7432a9ebe23a4e76 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 18 Sep 2023 11:41:05 +0200 Subject: [PATCH 78/84] pre-compute sparsity --- pipeline/run_spike_sorting.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index f882971..4c95962 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -247,14 +247,23 @@ if sorting.sampling_frequency != recording.sampling_frequency: print("\t\tSetting sorting sampling frequency to match recording") sorting._sampling_frequency = recording.sampling_frequency + # first full, then sparse + we_dense = si.extract_waveforms( + recording, + sorting, + folder=waveforms_all_folder / "waveforms_dense", + n_jobs=n_jobs, + overwrite=True, + max_spikes_per_unit=100 + ) + sparsity = si.compute_sparsity(we_dense, **sparsity_kwargs) we_all = si.extract_waveforms( recording, sorting, - folder=waveforms_all_folder / "waveforms", + folder=waveforms_all_folder / "waveforms_all", n_jobs=n_jobs, overwrite=True, - sparse=True, - **sparsity_kwargs, + sparsity=sparsity ) print("\t\tCompute NO DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all) @@ -277,15 +286,25 @@ if sorting_di.sampling_frequency != recording.sampling_frequency: print("\t\tSetting sorting DI sampling frequency to match recording") sorting_di._sampling_frequency = recording.sampling_frequency + # first full, then sparse + we_dense_di = si.extract_waveforms( + recording_di, + sorting_di, + folder=waveforms_all_folder / "waveforms_dense_di", + n_jobs=n_jobs, + overwrite=True, + max_spikes_per_unit=100 + ) + sparsity_di = si.compute_sparsity(we_dense_di, **sparsity_kwargs) we_all_di = si.extract_waveforms( recording_di, sorting_di, - folder=waveforms_all_folder / "waveforms_di", + folder=waveforms_all_folder / "waveforms_all_di", n_jobs=n_jobs, overwrite=True, - sparse=True, - **sparsity_kwargs, + sparsity=sparsity ) + print("\t\tCompute DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all_di) print("\t\tCompute DI spike locations") From 26c50b8f9d610c0116d2b6508457475f7780d0a3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 18 Sep 2023 16:42:36 +0200 Subject: [PATCH 79/84] Add filter_option in JSOn and skip NN metrics --- pipeline/run_inference.py | 12 +++-- pipeline/run_spike_sorting.py | 76 +++++++++++++++++++++++++------- pipeline/run_spike_sorting_GT.py | 12 +++-- pipeline/run_training.py | 12 +++-- pipeline/sessions.py | 4 +- 5 files changed, 86 insertions(+), 30 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 5327028..3524747 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -110,15 +110,19 @@ # each json file contains a session to run for json_file in json_files: with open(json_file, "r") as f: - d = json.load(f) - probe = d["probe"] + config = json.load(f) + probe = config["probe"] if probe not in session_dict: session_dict[probe] = [] - session = d["session"] + session = config["session"] assert ( session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS else: session_dict = all_sessions @@ -159,7 +163,7 @@ ) print(f"\t{recording}") - for filter_option in FILTER_OPTIONS: + for filter_option in filter_options: print(f"\tFilter option: {filter_option}") recording_name = f"{dataset_name}_{session_name}_{filter_option}" diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 4c95962..ca25f9e 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -7,6 +7,7 @@ #### IMPORTS ####### import sys import json +import shutil from pathlib import Path import pandas as pd @@ -26,7 +27,7 @@ n_jobs = 16 -job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=n_jobs, progress_bar=False, chunk_duration="1s") data_folder = base_path / "data" scratch_folder = base_path / "scratch" @@ -46,6 +47,22 @@ radius_um=200, ) +# skip NN because extremely slow +qm_metric_names = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violation", + "rp_violation", + "sliding_rp_violation", + "amplitude_cutoff", + "drift", + "isolation_distance", + "l_ratio", + "d_prime", +] + if __name__ == "__main__": if len(sys.argv) == 2: @@ -63,15 +80,19 @@ # each json file contains a session to run for json_file in json_files: with open(json_file, "r") as f: - d = json.load(f) - probe = d["probe"] + config = json.load(f) + probe = config["probe"] if probe not in session_dict: session_dict[probe] = [] - session = d["session"] + session = config["session"] assert ( session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS else: session_dict = all_sessions @@ -131,7 +152,7 @@ ] matched_unit_level_results = None - for filter_option in FILTER_OPTIONS: + for filter_option in filter_options: print(f"\tFilter option: {filter_option}") # load recordings @@ -254,7 +275,7 @@ folder=waveforms_all_folder / "waveforms_dense", n_jobs=n_jobs, overwrite=True, - max_spikes_per_unit=100 + max_spikes_per_unit=100, ) sparsity = si.compute_sparsity(we_dense, **sparsity_kwargs) we_all = si.extract_waveforms( @@ -263,20 +284,23 @@ folder=waveforms_all_folder / "waveforms_all", n_jobs=n_jobs, overwrite=True, - sparsity=sparsity + sparsity=sparsity, ) + # remove dense folder + shutil.rmtree(waveforms_all_folder / "waveforms_dense") + print("\t\tCompute NO DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all) print("\t\tCompute NO DI spike locations") _ = spost.compute_spike_locations(we_all) print("\t\tCompute NO DI PCA scores") _ = spost.compute_principal_components(we_all) - print("\t\tCompute NO DI template metrics") - _ = spost.compute_template_metrics(we_all) - # finally, quality metrics - print("\t\tCompute DI metrics") - qm_all = sqm.compute_quality_metrics(we_all, n_jobs=1) + # finally, template and quality metrics + print("\t\tCompute NO DI template metrics") + tm_all = spost.compute_template_metrics(we_all) + print("\t\tCompute NO DI metrics") + qm_all = sqm.compute_quality_metrics(we_all, n_jobs=1, metric_names=qm_metric_names) if (waveforms_all_folder / "waveforms_di").is_dir() and not OVERWRITE: print("\t\tLoad DI waveforms all") @@ -293,7 +317,7 @@ folder=waveforms_all_folder / "waveforms_dense_di", n_jobs=n_jobs, overwrite=True, - max_spikes_per_unit=100 + max_spikes_per_unit=100, ) sparsity_di = si.compute_sparsity(we_dense_di, **sparsity_kwargs) we_all_di = si.extract_waveforms( @@ -302,8 +326,10 @@ folder=waveforms_all_folder / "waveforms_all_di", n_jobs=n_jobs, overwrite=True, - sparsity=sparsity + sparsity=sparsity_di, ) + # remove dense folder + shutil.rmtree(waveforms_all_folder / "waveforms_dense_di") print("\t\tCompute DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all_di) @@ -311,12 +337,13 @@ _ = spost.compute_spike_locations(we_all_di) print("\t\tCompute DI PCA scores") _ = spost.compute_principal_components(we_all_di) + + # finally, template and quality metrics print("\t\tCompute DI template metrics") - _ = spost.compute_template_metrics(we_all_di) + tm_all_di = spost.compute_template_metrics(we_all_di) - # finally, quality metrics print("\t\tCompute DI metrics") - qm_all_di = sqm.compute_quality_metrics(we_all_di, n_jobs=1) + qm_all_di = sqm.compute_quality_metrics(we_all_di, n_jobs=1, metric_names=qm_metric_names) waveforms_matched_folder = ( scratch_folder / f"waveforms_matched_{dataset_name}_{session_name}_{filter_option}" @@ -327,28 +354,34 @@ print("\t\tLoad NO DI waveforms matched") we_matched = si.load_waveforms(waveforms_matched_folder / "waveforms") qm_matched = we_matched.load_extension("quality_metrics").get_data() + tm_matched = we_matched.load_extension("template_metrics").get_data() else: print("\t\tSelect NO DI waveforms matched") we_matched = we_all.select_units( unit_ids=matched_unit_ids, new_folder=waveforms_matched_folder / "waveforms" ) qm_matched = we_matched.load_extension("quality_metrics").get_data() + tm_matched = we_matched.load_extension("template_metrics").get_data() if (waveforms_matched_folder / "waveforms_di").is_dir() and not OVERWRITE: print("\t\tLoad DI waveforms matched") we_matched_di = si.load_waveforms(waveforms_matched_folder / "waveforms_di") qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() + tm_matched_di = we_matched_di.load_extension("template_metrics").get_data() else: print("\t\tSelect DI waveforms matched") we_matched_di = we_all_di.select_units( unit_ids=matched_unit_ids_di, new_folder=waveforms_matched_folder / "waveforms_di" ) qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() + tm_matched_di = we_matched_di.load_extension("template_metrics").get_data() ## add entries to unit-level results if unit_level_results is None: for metric in qm_all.columns: unit_level_results_columns.append(metric) + for metric in tm_all.columns: + unit_level_results_columns.append(metric) unit_level_results = pd.DataFrame(columns=unit_level_results_columns) new_rows = { @@ -370,6 +403,9 @@ for metric in qm_all.columns: new_rows[metric] = qm_all[metric].values new_rows_di[metric] = qm_all_di[metric].values + for metric in tm_all.columns: + new_rows[metric] = tm_all[metric].values + new_rows_di[metric] = tm_all_di[metric].values # append new entries unit_level_results = pd.concat( [unit_level_results, pd.DataFrame(new_rows), pd.DataFrame(new_rows_di)], ignore_index=True @@ -380,6 +416,9 @@ for metric in qm_matched.columns: matched_unit_level_results_columns.append(metric) matched_unit_level_results_columns.append(f"{metric}_di") + for metric in tm_matched.columns: + matched_unit_level_results_columns.append(metric) + matched_unit_level_results_columns.append(f"{metric}_di") matched_unit_level_results = pd.DataFrame(columns=matched_unit_level_results) new_matched_rows = { @@ -400,6 +439,9 @@ for metric in qm_matched.columns: new_rows[metric] = qm_matched[metric].values new_rows[f"{metric}_di"] = qm_matched_di[metric].values + for metric in tm_matched.columns: + new_rows[metric] = tm_matched[metric].values + new_rows[f"{metric}_di"] = tm_matched_di[metric].values # append new entries matched_unit_level_results = pd.concat( [matched_unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index 040905f..ca550c7 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -62,12 +62,16 @@ # each json file contains a session to run for json_file in json_files: with open(json_file, "r") as f: - d = json.load(f) - probe = d["probe"] + config = json.load(f) + probe = config["probe"] if probe not in session_dict: session_dict[probe] = [] - session = d["session"] + session = config["session"] session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS else: session_dict = all_sessions @@ -96,7 +100,7 @@ session_level_results = None unit_level_results = None - for filter_option in FILTER_OPTIONS: + for filter_option in filter_options: print(f"\tFilter option: {filter_option}") # load recordings diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 74adbeb..0c27080 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -92,15 +92,19 @@ # each json file contains a session to run for json_file in json_files: with open(json_file, "r") as f: - d = json.load(f) - probe = d["probe"] + config = json.load(f) + probe = config["probe"] if probe not in session_dict: session_dict[probe] = [] - session = d["session"] + session = config["session"] assert ( session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS else: session_dict = all_sessions @@ -155,7 +159,7 @@ if TRAINING_END_S is None: TRAINING_END_S = recording.get_total_duration() - for filter_option in FILTER_OPTIONS: + for filter_option in filter_options: print(f"\tFilter option: {filter_option}") recording_name = f"{dataset_name}_{session_name}_{filter_option}" # train DI models diff --git a/pipeline/sessions.py b/pipeline/sessions.py index 6873445..d9a48e1 100644 --- a/pipeline/sessions.py +++ b/pipeline/sessions.py @@ -41,8 +41,10 @@ ], } +FILTER_OPTIONS = ["hp", "bp"] -def generate_job_config_list(output_folder, split_probes=True, dataset="exp"): + +def generate_job_config_list(output_folder, split_probes=True, split_filters=True, dataset="exp"): output_folder = Path(output_folder) output_folder.mkdir(exist_ok=True, parents=True) From 313211a627f24b6326b214f9609422cfb35627f2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 18 Sep 2023 16:47:46 +0200 Subject: [PATCH 80/84] Cleanup --- pipeline/run_inference.py | 44 ++++++++++++++++--------------- pipeline/run_spike_sorting.py | 45 +++++++++++++++++--------------- pipeline/run_spike_sorting_GT.py | 42 ++++++++++++++++------------- pipeline/run_training.py | 44 +++++++++++++++++-------------- pipeline/sessions.py | 13 ++++++--- 5 files changed, 105 insertions(+), 83 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index 3524747..b598fdb 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -102,32 +102,34 @@ inference_n_jobs = int(sys.argv[3]) inference_predict_workers = int(sys.argv[4]) - json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + session_dict = all_sessions + filter_options = FILTER_OPTIONS - if len(json_files) > 0: + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: print(f"Found {len(json_files)} JSON config") session_dict = {} # each json file contains a session to run - for json_file in json_files: - with open(json_file, "r") as f: - config = json.load(f) - probe = config["probe"] - if probe not in session_dict: - session_dict[probe] = [] - session = config["session"] - assert ( - session in all_sessions[probe] - ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" - session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] - else: - filter_options = FILTER_OPTIONS - else: - session_dict = all_sessions - - print(session_dict) + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") si.set_global_job_kwargs(**job_kwargs) print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index ca25f9e..4044564 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -72,31 +72,34 @@ else: DEBUG = False - json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + session_dict = all_sessions + filter_options = FILTER_OPTIONS - if len(json_files) > 0: + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: print(f"Found {len(json_files)} JSON config") session_dict = {} # each json file contains a session to run - for json_file in json_files: - with open(json_file, "r") as f: - config = json.load(f) - probe = config["probe"] - if probe not in session_dict: - session_dict[probe] = [] - session = config["session"] - assert ( - session in all_sessions[probe] - ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" - session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] - else: - filter_options = FILTER_OPTIONS - else: - session_dict = all_sessions - - print(session_dict) + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") si.set_global_job_kwargs(**job_kwargs) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index ca550c7..bbbc8e5 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -54,28 +54,34 @@ else: OVERWRITE = False - json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + session_dict = all_sessions + filter_options = FILTER_OPTIONS - if len(json_files) > 0: + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: print(f"Found {len(json_files)} JSON config") session_dict = {} # each json file contains a session to run - for json_file in json_files: - with open(json_file, "r") as f: - config = json.load(f) - probe = config["probe"] - if probe not in session_dict: - session_dict[probe] = [] - session = config["session"] - session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] - else: - filter_options = FILTER_OPTIONS - else: - session_dict = all_sessions - - print(session_dict) + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") si.set_global_job_kwargs(**job_kwargs) diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 0c27080..45e52f8 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -85,30 +85,34 @@ else: DEBUG = False + session_dict = all_sessions + filter_options = FILTER_OPTIONS + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] - print(f"Found {len(json_files)} JSON config: {json_files}") - if len(json_files) > 0: + if len(json_files) == 1: + print(f"Found {len(json_files)} JSON config") session_dict = {} # each json file contains a session to run - for json_file in json_files: - with open(json_file, "r") as f: - config = json.load(f) - probe = config["probe"] - if probe not in session_dict: - session_dict[probe] = [] - session = config["session"] - assert ( - session in all_sessions[probe] - ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" - session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] - else: - filter_options = FILTER_OPTIONS - else: - session_dict = all_sessions + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_options" in config: + filter_options = [config["filter_options"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") - print(session_dict) + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") if DEBUG: TRAINING_START_S = 10 diff --git a/pipeline/sessions.py b/pipeline/sessions.py index d9a48e1..d0294f1 100644 --- a/pipeline/sessions.py +++ b/pipeline/sessions.py @@ -65,7 +65,14 @@ def generate_job_config_list(output_folder, split_probes=True, split_filters=Tru for session in sessions: d = dict(session=session, probe=probe) - with open(probe_folder / f"job{i}.json", "w") as f: - json.dump(d, f) + if split_filters: + for filter_option in FILTER_OPTIONS: + d["filter_options"] = filter_option + with open(probe_folder / f"job{i}.json", "w") as f: + json.dump(d, f) + i += 1 + else: + with open(probe_folder / f"job{i}.json", "w") as f: + json.dump(d, f) + i += 1 - i += 1 From 84dc31e16ee026297804779008a65d2f3d897692 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 18 Sep 2023 16:51:07 +0200 Subject: [PATCH 81/84] Save sparse waveforms in results --- pipeline/run_spike_sorting.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 4044564..01d3609 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -258,9 +258,12 @@ if sorting_matched is not None: # waveforms for all units - waveforms_all_folder = ( + waveforms_scratch_folder = ( scratch_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" ) + waveforms_all_folder = ( + results_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" + ) waveforms_all_folder.mkdir(exist_ok=True, parents=True) if (waveforms_all_folder / "waveforms").is_dir() and not OVERWRITE: @@ -275,7 +278,7 @@ we_dense = si.extract_waveforms( recording, sorting, - folder=waveforms_all_folder / "waveforms_dense", + folder=waveforms_scratch_folder / "waveforms_dense", n_jobs=n_jobs, overwrite=True, max_spikes_per_unit=100, @@ -317,7 +320,7 @@ we_dense_di = si.extract_waveforms( recording_di, sorting_di, - folder=waveforms_all_folder / "waveforms_dense_di", + folder=waveforms_scratch_folder / "waveforms_dense_di", n_jobs=n_jobs, overwrite=True, max_spikes_per_unit=100, From 81964272fdc679019f27503e72a2ccb586c967e5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 18 Sep 2023 17:18:44 +0200 Subject: [PATCH 82/84] filter_options -> filter_option --- pipeline/run_inference.py | 4 ++-- pipeline/run_spike_sorting.py | 4 ++-- pipeline/run_spike_sorting_GT.py | 4 ++-- pipeline/run_training.py | 4 ++-- pipeline/sessions.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py index b598fdb..e412619 100644 --- a/pipeline/run_inference.py +++ b/pipeline/run_inference.py @@ -121,8 +121,8 @@ session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] + if "filter_option" in config: + filter_options = [config["filter_option"]] else: filter_options = FILTER_OPTIONS elif len(json_files) > 1: diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index 01d3609..c55daa6 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -91,8 +91,8 @@ session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] + if "filter_option" in config: + filter_options = [config["filter_option"]] else: filter_options = FILTER_OPTIONS elif len(json_files) > 1: diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py index bbbc8e5..6173a79 100644 --- a/pipeline/run_spike_sorting_GT.py +++ b/pipeline/run_spike_sorting_GT.py @@ -73,8 +73,8 @@ session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] + if "filter_option" in config: + filter_options = [config["filter_option"]] else: filter_options = FILTER_OPTIONS elif len(json_files) > 1: diff --git a/pipeline/run_training.py b/pipeline/run_training.py index 45e52f8..451d8a7 100644 --- a/pipeline/run_training.py +++ b/pipeline/run_training.py @@ -104,8 +104,8 @@ session in all_sessions[probe] ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" session_dict[probe].append(session) - if "filter_options" in config: - filter_options = [config["filter_options"]] + if "filter_option" in config: + filter_options = [config["filter_option"]] else: filter_options = FILTER_OPTIONS elif len(json_files) > 1: diff --git a/pipeline/sessions.py b/pipeline/sessions.py index d0294f1..1627a5a 100644 --- a/pipeline/sessions.py +++ b/pipeline/sessions.py @@ -67,7 +67,7 @@ def generate_job_config_list(output_folder, split_probes=True, split_filters=Tru if split_filters: for filter_option in FILTER_OPTIONS: - d["filter_options"] = filter_option + d["filter_option"] = filter_option with open(probe_folder / f"job{i}.json", "w") as f: json.dump(d, f) i += 1 From 53754d2a91d395b96377f714678cd99914ab0909 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 10:33:28 +0200 Subject: [PATCH 83/84] Remove correct folder --- pipeline/run_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index c55daa6..e6163fd 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -293,7 +293,7 @@ sparsity=sparsity, ) # remove dense folder - shutil.rmtree(waveforms_all_folder / "waveforms_dense") + shutil.rmtree(waveforms_scratch_folder / "waveforms_dense") print("\t\tCompute NO DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all) @@ -335,7 +335,7 @@ sparsity=sparsity_di, ) # remove dense folder - shutil.rmtree(waveforms_all_folder / "waveforms_dense_di") + shutil.rmtree(waveforms_scratch_folder / "waveforms_dense_di") print("\t\tCompute DI spike amplitudes") _ = spost.compute_spike_amplitudes(we_all_di) From 2a4f63b437cb8965391f8e9f4fc6913d89086725 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 13:36:47 +0200 Subject: [PATCH 84/84] Add metrics to matched-units csv --- pipeline/run_spike_sorting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py index e6163fd..6da0a80 100644 --- a/pipeline/run_spike_sorting.py +++ b/pipeline/run_spike_sorting.py @@ -443,11 +443,11 @@ ) new_matched_rows["agreement_score"] = agreement_scores for metric in qm_matched.columns: - new_rows[metric] = qm_matched[metric].values - new_rows[f"{metric}_di"] = qm_matched_di[metric].values + new_matched_rows[metric] = qm_matched[metric].values + new_matched_rows[f"{metric}_di"] = qm_matched_di[metric].values for metric in tm_matched.columns: - new_rows[metric] = tm_matched[metric].values - new_rows[f"{metric}_di"] = tm_matched_di[metric].values + new_matched_rows[metric] = tm_matched[metric].values + new_matched_rows[f"{metric}_di"] = tm_matched_di[metric].values # append new entries matched_unit_level_results = pd.concat( [matched_unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True