diff --git a/generate_image_inference.py b/generate_image_inference.py deleted file mode 100644 index 2ba171a..0000000 --- a/generate_image_inference.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import matplotlib -import matplotlib.pyplot as plt -from scipy import ndimage -import scipy -import h5py -import numpy as np -os.listdir('.') -fn= 'ephys_tiny_continuous_deep_interpolation.h5' -f= h5py.File(fn) -tsurf = f['data'] -tsurf -di_traces = tsurf[:,:,0] -di_traces.shape -plot= plt.imshow(di_traces.T, - origin='lower', - vmin=-50, - vmax=50, - cmap='RdGy', - aspect='auto') -plt.xlabel('Sample Index') -plt.ylabel('Acquisition Channels') -#plt.ylim(0,386) -#plt.xlim(0,400) -#matplotlib.pyplot.annotate('alpha', xy= (1,1), xytext=(3,3), xycoords='data', textcoords=None, arrowprops=None, annotation_clip=None) -matplotlib.pyplot.colorbar(label= 'μV', shrink=0.25) -plt.title("1000samples_28Epoch") -matplotlib.pyplot.savefig('1000samples_28Epoch') diff --git a/pipeline/run_collect_results.py b/pipeline/run_collect_results.py new file mode 100644 index 0000000..b5f87aa --- /dev/null +++ b/pipeline/run_collect_results.py @@ -0,0 +1,105 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import shutil + + +from pathlib import Path +import pandas as pd + + +base_path = Path("..") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +if __name__ == "__main__": + # list all data entries + print("Data folder content:") + for p in data_folder.iterdir(): + print(f"\t{p.name}") + + # concatenate dataframes + df_session = None + df_units = None + df_matched_units = None + + probe_sortings_folders = [p for p in data_folder.iterdir() if p.name.startswith("sorting_") and p.is_dir()] + + if len(probe_sortings_folders) > 0: + data_models_folder = data_folder + data_sortings_folder = data_folder + else: + data_model_subfolders = 
[] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0: + data_model_subfolders.append(p) + data_models_folder = data_model_subfolders[0] + + data_sorting_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "sorting_" in pp.name and pp.is_dir()]) > 0: + data_sorting_subfolders.append(p) + data_sortings_folder = data_sorting_subfolders[0] + + session_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("sessions.csv")] + unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("units.csv") and "matched" not in p.name] + matched_unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("matched-units.csv")] + + for session_csv in session_csvs: + if df_session is None: + df_session = pd.read_csv(session_csv) + else: + df_session = pd.concat([df_session, pd.read_csv(session_csv)]) + + for unit_csv in unit_csvs: + if df_units is None: + df_units = pd.read_csv(unit_csv) + else: + df_units = pd.concat([df_units, pd.read_csv(unit_csv)]) + + for matched_unit_csv in matched_unit_csvs: + if df_matched_units is None: + df_matched_units = pd.read_csv(matched_unit_csv) + else: + df_matched_units = pd.concat([df_matched_units, pd.read_csv(matched_unit_csv)]) + + # save concatenated dataframes + df_session.to_csv(results_folder / "sessions.csv", index=False) + df_units.to_csv(results_folder / "units.csv", index=False) + df_matched_units.to_csv(results_folder / "matched-units.csv", index=False) + + # copy sortings to results folder + sortings_folders = [p for p in data_sortings_folder.iterdir() if "sorting_" in p.name and p.is_dir()] + sortings_output_base_folder = results_folder / "sortings" + sortings_output_base_folder.mkdir(exist_ok=True) + + for sorting_folder in sortings_folders: + sorting_folder_split = sorting_folder.name.split("_") + dataset_name = sorting_folder_split[1] + session_name = 
"_".join(sorting_folder_split[2:-1]) + filter_option = sorting_folder_split[-1] + sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option + sorting_output_folder.mkdir(exist_ok=True, parents=True) + for sorting_subfolder in sorting_folder.iterdir(): + shutil.copytree(sorting_subfolder, sorting_output_folder / sorting_subfolder.name) + + # copy models to results folder + models_folders = [p for p in data_models_folder.iterdir() if "model_" in p.name and p.is_dir()] + models_output_base_folder = results_folder / "models" + models_output_base_folder.mkdir(exist_ok=True) + + for model_folder in models_folders: + model_folder_split = model_folder.name.split("_") + dataset_name = model_folder_split[1] + session_name = "_".join(model_folder_split[2:-1]) + filter_option = model_folder_split[-1] + model_output_folder = models_output_base_folder / dataset_name / session_name / filter_option + model_output_folder.parent.mkdir(exist_ok=True, parents=True) + shutil.copytree(model_folder, model_output_folder) diff --git a/pipeline/run_inference.py b/pipeline/run_inference.py new file mode 100644 index 0000000..e412619 --- /dev/null +++ b/pipeline/run_inference.py @@ -0,0 +1,233 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) + + +#### IMPORTS ####### +import os +import sys +import shutil +import json +import numpy as np +from pathlib import Path +import pandas as pd +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +os.environ["OPENBLAS_NUM_THREADS"] = "1" + + +base_path = Path("..") + +##### DEFINE 
DATASETS AND FOLDERS ####### +from sessions import all_sessions_exp, all_sessions_sim + +n_jobs = 24 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +if (data_folder / "ephys-compression-benchmark").is_dir(): + DATASET_FOLDER = data_folder / "ephys-compression-benchmark" + all_sessions = all_sessions_exp + data_type = "exp" +elif (data_folder / "MEArec-NP-recordings").is_dir(): + DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + all_sessions = all_sessions_sim + data_type = "sim" +else: + raise Exception("Could not find dataset folder") + + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True + + +# Define training and testing constants (@Jad you can gradually increase this) + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) +# play around with these +inference_n_jobs = 24 +inference_chunk_duration = "1s" +inference_predict_workers = 1 +inference_memory_gpu = 2000 # MB + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + +if __name__ == "__main__": + if len(sys.argv) == 5: + if sys.argv[1] == "true": + DEBUG = True + OVERWRITE = True + else: + DEBUG = False + OVERWRITE = False + n_jobs = int(sys.argv[2]) + inference_n_jobs = int(sys.argv[3]) + inference_predict_workers = int(sys.argv[4]) + + session_dict = all_sessions + filter_options = FILTER_OPTIONS + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + 
probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_option" in config: + filter_options = [config["filter_option"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + #### START #### + probe_models_folders = [p for p in data_folder.iterdir() if "model_" in p.name and p.is_dir()] + + if len(probe_models_folders) > 0: + data_model_folder = data_folder + else: + data_model_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0: + data_model_subfolders.append(p) + data_model_folder = data_model_subfolders[0] + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + for session in sessions: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + if data_type == "exp": + recording = si.load_extractor(DATASET_FOLDER / session) + else: + recording, _ = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + print(f"\t{recording}") + + for filter_option in filter_options: + print(f"\tFilter option: {filter_option}") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == 
"bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + + if data_type == "sim": + recording_processed = spre.depth_order(recording_processed) + + recording_zscore = spre.zscore(recording_processed) + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) + + # train model + model_folder = data_model_folder / f"model_{recording_name}" + model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0] + # full inference + output_folder = results_folder / f"deepinterpolated_{recording_name}" + if OVERWRITE and output_folder.is_dir(): + shutil.rmtree(output_folder) + + if not output_folder.is_dir(): + t_start_inference = time.perf_counter() + output_folder.parent.mkdir(exist_ok=True, parents=True) + recording_di = spre.deepinterpolate( + recording_zscore_bin, + model_path=model_path, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + memory_gpu=inference_memory_gpu, + predict_workers=inference_predict_workers, + use_gpu=USE_GPU, + ) + recording_di = recording_di.save( + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, + ) + t_stop_inference = time.perf_counter() + elapsed_time_inference = np.round(t_stop_inference - t_start_inference, 2) + print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") + else: + print("\t\tLoading existing folder") + recording_di = si.load_extractor(output_folder) + # apply inverse z-scoring + inverse_gains = 1 / recording_zscore.gain + inverse_offset = -recording_zscore.offset * inverse_gains + recording_di = spre.scale(recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float") + + # save processed json + processed_folder = results_folder / f"processed_{recording_name}" + processed_folder.mkdir(exist_ok=True, parents=True) + recording_processed.dump_to_json(processed_folder / 
"processed.json", relative_to=results_folder) + recording_di.dump_to_json(processed_folder / f"deepinterpolated.json", relative_to=results_folder) + + for json_file in json_files: + shutil.copy(json_file, results_folder) diff --git a/pipeline/run_spike_sorting.py b/pipeline/run_spike_sorting.py new file mode 100644 index 0000000..6da0a80 --- /dev/null +++ b/pipeline/run_spike_sorting.py @@ -0,0 +1,465 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import sys +import json +import shutil +from pathlib import Path +import pandas as pd + +# SpikeInterface +import spikeinterface as si +import spikeinterface.sorters as ss +import spikeinterface.curation as scur +import spikeinterface.comparison as sc +import spikeinterface.postprocessing as spost +import spikeinterface.qualitymetrics as sqm + + +base_path = Path("..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions_exp as all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=False, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# Define training and testing constants +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + + +sorter_name = "pykilosort" +singularity_image = False +match_score = 0.7 + +sparsity_kwargs = dict( + method="radius", + radius_um=200, +) + +# skip NN because extremely slow +qm_metric_names = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violation", + "rp_violation", + "sliding_rp_violation", + "amplitude_cutoff", + "drift", + "isolation_distance", + "l_ratio", + "d_prime", +] + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True + OVERWRITE = True + else: + DEBUG = False + + session_dict = all_sessions + filter_options = FILTER_OPTIONS + + json_files = [p for p in data_folder.iterdir() 
if p.name.endswith(".json")] + if len(json_files) == 1: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_option" in config: + filter_options = [config["filter_option"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") + + si.set_global_job_kwargs(**job_kwargs) + + #### START #### + probe_processed_folders = [p for p in data_folder.iterdir() if "processed_" in p.name and p.is_dir()] + + if len(probe_processed_folders) > 0: + processed_folder = data_folder + else: + data_processed_subfolders = [] + for p in data_folder.iterdir(): + if p.is_dir() and len([pp for pp in p.iterdir() if "processed_" in pp.name and pp.is_dir()]) > 0: + data_processed_subfolders.append(p) + processed_folder = data_processed_subfolders[0] + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + for session in sessions: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + session_level_results_columns = [ + "dataset", + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] + session_level_results = None + + unit_level_results_columns = [ + "dataset", + "session", + "probe", + "filter_option", + "unit_id", + "deepinterpolated", + ] + unit_level_results = None + + matched_unit_level_results_columns = [ + "dataset", + "session", + 
"probe", + "filter_option", + "unit_id", + "unit_id_di", + "agreement_score", + ] + matched_unit_level_results = None + + for filter_option in filter_options: + print(f"\tFilter option: {filter_option}") + + # load recordings + # save processed json + processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" + recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) + recording_di = si.load_extractor( + processed_json_folder / "deepinterpolated.json", base_folder=processed_folder + ) + + # run spike sorting + sorting_output_folder = results_folder / f"sorting_{dataset_name}_{session_name}_{filter_option}" + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + if (sorting_output_folder / "sorting").is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + try: + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = scur.remove_excess_spikes(sorting, recording) + sorting = sorting.save(folder=sorting_output_folder / "sorting") + except: + print(f"Error sorting {session} with {sorter_name} and {filter_option}") + sorting = None + + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + try: + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) + sorting_di = 
sorting_di.save(folder=sorting_output_folder / "sorting_di") + except: + print(f"Error sorting DI {session} with {sorter_name} and {filter_option}") + sorting_di = None + + if sorting is not None and sorting_di is not None: + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, + ) + matched_units = comp.get_matching()[0] + matched_unit_ids = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids = matched_unit_ids[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_matched = sorting.select_units(unit_ids=matched_unit_ids) + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + else: + sorting_matched = None + sorting_di_matched = None + + new_row = { + "dataset": dataset_name, + "session": session_name, + "filter_option": filter_option, + "probe": probe, + "num_units": len(sorting.unit_ids) if sorting is not None else 0, + "num_units_di": len(sorting_di.unit_ids) if sorting_di is not None else 0, + "num_match": len(sorting_matched.unit_ids) if sorting_matched is not None else 0, + "sorting_path": str((sorting_output_folder / "sorting").relative_to(results_folder)) + if sorting is not None + else None, + "sorting_path_di": str((sorting_output_folder / "sorting_di_").relative_to(results_folder)) + if sorting_di is not None + else None, + } + + print( + f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" + ) + + if session_level_results is None: + session_level_results = pd.DataFrame(columns=session_level_results_columns) + session_level_results = pd.concat([session_level_results, pd.DataFrame([new_row])], ignore_index=True) + + if sorting_matched is not None: + # waveforms for 
all units + waveforms_scratch_folder = ( + scratch_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" + ) + waveforms_all_folder = ( + results_folder / f"waveforms_all_{dataset_name}_{session_name}_{filter_option}" + ) + waveforms_all_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_all_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms all") + we_all = si.load_waveforms(waveforms_all_folder / "waveforms") + else: + print("\t\tCompute NO DI waveforms all") + if sorting.sampling_frequency != recording.sampling_frequency: + print("\t\tSetting sorting sampling frequency to match recording") + sorting._sampling_frequency = recording.sampling_frequency + # first full, then sparse + we_dense = si.extract_waveforms( + recording, + sorting, + folder=waveforms_scratch_folder / "waveforms_dense", + n_jobs=n_jobs, + overwrite=True, + max_spikes_per_unit=100, + ) + sparsity = si.compute_sparsity(we_dense, **sparsity_kwargs) + we_all = si.extract_waveforms( + recording, + sorting, + folder=waveforms_all_folder / "waveforms_all", + n_jobs=n_jobs, + overwrite=True, + sparsity=sparsity, + ) + # remove dense folder + shutil.rmtree(waveforms_scratch_folder / "waveforms_dense") + + print("\t\tCompute NO DI spike amplitudes") + _ = spost.compute_spike_amplitudes(we_all) + print("\t\tCompute NO DI spike locations") + _ = spost.compute_spike_locations(we_all) + print("\t\tCompute NO DI PCA scores") + _ = spost.compute_principal_components(we_all) + + # finally, template and quality metrics + print("\t\tCompute NO DI template metrics") + tm_all = spost.compute_template_metrics(we_all) + print("\t\tCompute NO DI metrics") + qm_all = sqm.compute_quality_metrics(we_all, n_jobs=1, metric_names=qm_metric_names) + + if (waveforms_all_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms all") + we_all_di = si.load_waveforms(waveforms_all_folder / "waveforms_di") + else: + print("\t\tCompute DI 
waveforms all") + if sorting_di.sampling_frequency != recording.sampling_frequency: + print("\t\tSetting sorting DI sampling frequency to match recording") + sorting_di._sampling_frequency = recording.sampling_frequency + # first full, then sparse + we_dense_di = si.extract_waveforms( + recording_di, + sorting_di, + folder=waveforms_scratch_folder / "waveforms_dense_di", + n_jobs=n_jobs, + overwrite=True, + max_spikes_per_unit=100, + ) + sparsity_di = si.compute_sparsity(we_dense_di, **sparsity_kwargs) + we_all_di = si.extract_waveforms( + recording_di, + sorting_di, + folder=waveforms_all_folder / "waveforms_all_di", + n_jobs=n_jobs, + overwrite=True, + sparsity=sparsity_di, + ) + # remove dense folder + shutil.rmtree(waveforms_scratch_folder / "waveforms_dense_di") + + print("\t\tCompute DI spike amplitudes") + _ = spost.compute_spike_amplitudes(we_all_di) + print("\t\tCompute DI spike locations") + _ = spost.compute_spike_locations(we_all_di) + print("\t\tCompute DI PCA scores") + _ = spost.compute_principal_components(we_all_di) + + # finally, template and quality metrics + print("\t\tCompute DI template metrics") + tm_all_di = spost.compute_template_metrics(we_all_di) + + print("\t\tCompute DI metrics") + qm_all_di = sqm.compute_quality_metrics(we_all_di, n_jobs=1, metric_names=qm_metric_names) + + waveforms_matched_folder = ( + scratch_folder / f"waveforms_matched_{dataset_name}_{session_name}_{filter_option}" + ) + waveforms_matched_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_matched_folder / "waveforms").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms matched") + we_matched = si.load_waveforms(waveforms_matched_folder / "waveforms") + qm_matched = we_matched.load_extension("quality_metrics").get_data() + tm_matched = we_matched.load_extension("template_metrics").get_data() + else: + print("\t\tSelect NO DI waveforms matched") + we_matched = we_all.select_units( + unit_ids=matched_unit_ids, new_folder=waveforms_matched_folder / 
"waveforms" + ) + qm_matched = we_matched.load_extension("quality_metrics").get_data() + tm_matched = we_matched.load_extension("template_metrics").get_data() + + if (waveforms_matched_folder / "waveforms_di").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms matched") + we_matched_di = si.load_waveforms(waveforms_matched_folder / "waveforms_di") + qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() + tm_matched_di = we_matched_di.load_extension("template_metrics").get_data() + else: + print("\t\tSelect DI waveforms matched") + we_matched_di = we_all_di.select_units( + unit_ids=matched_unit_ids_di, new_folder=waveforms_matched_folder / "waveforms_di" + ) + qm_matched_di = we_matched_di.load_extension("quality_metrics").get_data() + tm_matched_di = we_matched_di.load_extension("template_metrics").get_data() + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm_all.columns: + unit_level_results_columns.append(metric) + for metric in tm_all.columns: + unit_level_results_columns.append(metric) + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "dataset": [dataset_name] * len(qm_all), + "session": [session_name] * len(qm_all), + "probe": [probe] * len(qm_all), + "filter_option": [filter_option] * len(qm_all), + "unit_id": we_all.unit_ids, + "deepinterpolated": [False] * len(qm_all), + } + new_rows_di = { + "dataset": [dataset_name] * len(qm_all_di), + "session": [session_name] * len(qm_all_di), + "probe": [probe] * len(qm_all_di), + "filter_option": [filter_option] * len(qm_all_di), + "unit_id": we_all_di.unit_ids, + "deepinterpolated": [True] * len(qm_all_di), + } + for metric in qm_all.columns: + new_rows[metric] = qm_all[metric].values + new_rows_di[metric] = qm_all_di[metric].values + for metric in tm_all.columns: + new_rows[metric] = tm_all[metric].values + new_rows_di[metric] = tm_all_di[metric].values + # append new entries + unit_level_results = 
pd.concat( + [unit_level_results, pd.DataFrame(new_rows), pd.DataFrame(new_rows_di)], ignore_index=True + ) + + ## add entries to matched unit-level results + if matched_unit_level_results is None: + for metric in qm_matched.columns: + matched_unit_level_results_columns.append(metric) + matched_unit_level_results_columns.append(f"{metric}_di") + for metric in tm_matched.columns: + matched_unit_level_results_columns.append(metric) + matched_unit_level_results_columns.append(f"{metric}_di") + matched_unit_level_results = pd.DataFrame(columns=matched_unit_level_results) + + new_matched_rows = { + "dataset": [dataset_name] * len(qm_matched), + "session": [session_name] * len(qm_matched), + "probe": [probe] * len(qm_matched), + "filter_option": [filter_option] * len(qm_matched), + "unit_id": we_matched.unit_ids, + "unit_id_di": we_matched_di.unit_ids, + } + + agreement_scores = [] + for i in range(len(we_matched.unit_ids)): + agreement_scores.append( + comp.agreement_scores.at[we_matched.unit_ids[i], we_matched_di.unit_ids[i]] + ) + new_matched_rows["agreement_score"] = agreement_scores + for metric in qm_matched.columns: + new_matched_rows[metric] = qm_matched[metric].values + new_matched_rows[f"{metric}_di"] = qm_matched_di[metric].values + for metric in tm_matched.columns: + new_matched_rows[metric] = tm_matched[metric].values + new_matched_rows[f"{metric}_di"] = tm_matched_di[metric].values + # append new entries + matched_unit_level_results = pd.concat( + [matched_unit_level_results, pd.DataFrame(new_matched_rows)], ignore_index=True + ) + + if session_level_results is not None: + session_level_results.to_csv( + results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False + ) + if unit_level_results is not None: + unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) + if matched_unit_level_results is not None: + matched_unit_level_results.to_csv( + results_folder / 
f"{dataset_name}-{session_name}-matched-units.csv", index=False + ) diff --git a/pipeline/run_spike_sorting_GT.py b/pipeline/run_spike_sorting_GT.py new file mode 100644 index 0000000..6173a79 --- /dev/null +++ b/pipeline/run_spike_sorting_GT.py @@ -0,0 +1,290 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import json +from pathlib import Path +import numpy as np +import pandas as pd + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.curation as scur +import spikeinterface.comparison as sc + + +base_path = Path("..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions_sim as all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + +DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + +OVERWRITE = False + +# Define training and testing constants +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + + +sorter_name = "pykilosort" +singularity_image = False +match_score = 0.7 + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + OVERWRITE = True + else: + OVERWRITE = False + + session_dict = all_sessions + filter_options = FILTER_OPTIONS + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid 
session. Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_option" in config: + filter_options = [config["filter_option"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") + + si.set_global_job_kwargs(**job_kwargs) + + #### START #### + probe_processed_folders = [p for p in data_folder.iterdir() if "processed_" in p.name and p.is_dir()] + + if len(probe_processed_folders) > 0: + processed_folder = data_folder + else: + data_subfolders = [p for p in data_folder.iterdir() if p.is_dir()] + processed_folder = data_subfolders[0] + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + for session in sessions: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + recording_gt, sorting_gt = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + + session_level_results = None + unit_level_results = None + + for filter_option in filter_options: + print(f"\tFilter option: {filter_option}") + + # load recordings + # save processed json + processed_json_folder = processed_folder / f"processed_{dataset_name}_{session_name}_{filter_option}" + + recording = si.load_extractor(processed_json_folder / "processed.json", base_folder=data_folder) + recording_di = si.load_extractor( + processed_json_folder / "deepinterpolated.json", base_folder=processed_folder + ) + + # DEBUG mode + if recording.get_num_samples() < recording_gt.get_num_samples(): + print("DEBUG MODE: slicing GT") + sorting_gt = sorting_gt.frame_slice(start_frame=0, end_frame=recording.get_num_samples()) + + # run spike sorting + sorting_output_folder = results_folder / f"sorting_{dataset_name}_{session_name}_{filter_option}" + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + if 
(sorting_output_folder / "sorting").is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting = si.load_extractor(sorting_output_folder / "sorting") + else: + print(f"\t\tSpike sorting NO DI with {sorter_name}") + try: + sorting = ss.run_sorter( + sorter_name, + recording=recording, + output_folder=scratch_folder / session / filter_option / "no_di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting = scur.remove_excess_spikes(sorting, recording) + sorting = sorting.save(folder=sorting_output_folder / "sorting") + except: + print(f"\t\t\t{sorter_name} failed on original") + sorting = None + if (sorting_output_folder / "sorting_di").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + try: + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + output_folder=scratch_folder / session / filter_option / "di", + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = scur.remove_excess_spikes(sorting_di, recording_di) + sorting_di = sorting_di.save(folder=sorting_output_folder / "sorting_di") + except: + print(f"\t\t\t{sorter_name} failed on DI") + sorting_di = None + + # compare to GT + perf_keys = [ + "precision", + "false_discovery_rate", + "miss_rate", + "num_gt", + "num_sorter", + "num_well_detected", + "num_overmerged", + "num_redundant", + "num_false_positivenum_bad", + ] + print("\tRunning comparison") + if sorting is not None: + cmp = sc.compare_sorter_to_ground_truth(sorting_gt, sorting, exhaustive_gt=True) + perf_avg = cmp.get_performance(method="pooled_with_average") + counts = cmp.count_units_categories() + new_data = { + "probe": probe, + "session": session_name, + "num_units": len(sorting.unit_ids), + "filter_option": filter_option, + "deepinterpolated": False, + } + new_data.update(perf_avg.to_dict()) + 
new_data.update(counts.to_dict()) + else: + new_data = { + "probe": probe, + "session": session_name, + "num_units": np.nan, + "filter_option": filter_option, + "deepinterpolated": False, + } + for perf_key in perf_keys: + new_data[perf_key] = np.nan + + if sorting_di is not None: + cmp_di = sc.compare_sorter_to_ground_truth(sorting_gt, sorting_di, exhaustive_gt=True) + perf_avg_di = cmp_di.get_performance(method="pooled_with_average") + counts_di = cmp_di.count_units_categories() + + new_data_di = { + "probe": probe, + "session": session_name, + "num_units": len(sorting_di.unit_ids), + "filter_option": filter_option, + "deepinterpolated": True, + } + new_data_di.update(perf_avg_di.to_dict()) + new_data_di.update(counts_di.to_dict()) + else: + new_data_di = { + "probe": probe, + "session": session_name, + "num_units": np.nan, + "filter_option": filter_option, + "deepinterpolated": True, + } + for perf_key in perf_keys: + new_data[perf_key] = np.nan + + new_df = pd.DataFrame([new_data]) + new_df_di = pd.DataFrame([new_data_di]) + new_df_session = pd.concat([new_df, new_df_di], ignore_index=True) + + if session_level_results is None: + session_level_results = new_df_session + else: + session_level_results = pd.concat([session_level_results, new_df_session], ignore_index=True) + + # by unit + unit_perf_keys = ["accuracy", "recall", "precision", "false_discovery_rate", "miss_rate"] + if sorting is not None: + perf_by_unit = cmp.get_performance(method="by_unit") + perf_columns = perf_by_unit.columns + perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) + perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) + perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) + perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids + else: + perf_by_unit = pd.DataFrame({"unit_id": sorting_gt.unit_ids}) + perf_by_unit.loc[:, "probe"] = [probe] * len(perf_by_unit) + 
perf_by_unit.loc[:, "session"] = [session_name] * len(perf_by_unit) + perf_by_unit.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit) + perf_by_unit.loc[:, "deepinterpolated"] = [False] * len(perf_by_unit) + perf_by_unit.loc[:, "unit_id"] = sorting_gt.unit_ids + for perf_key in unit_perf_keys: + perf_by_unit.loc[:, perf_key] = np.nan + + if sorting_di is not None: + perf_by_unit_di = cmp_di.get_performance(method="by_unit") + perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids + else: + perf_by_unit_di = pd.DataFrame({"unit_id": sorting_gt.unit_ids}) + perf_by_unit_di.loc[:, "probe"] = [probe] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "session"] = [session_name] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "filter_option"] = [filter_option] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "deepinterpolated"] = [True] * len(perf_by_unit_di) + perf_by_unit_di.loc[:, "unit_id"] = sorting_gt.unit_ids + for perf_key in unit_perf_keys: + perf_by_unit_di.loc[:, perf_key] = np.nan + + new_unit_df = pd.concat([perf_by_unit, perf_by_unit_di], ignore_index=True) + + if unit_level_results is None: + unit_level_results = new_unit_df + else: + unit_level_results = pd.concat([unit_level_results, new_unit_df], ignore_index=True) + + sorted_columns = ["probe", "session", "filter_option", "deepinterpolated", "unit_id"] + for col in perf_columns: + sorted_columns.append(col) + unit_level_results = unit_level_results[sorted_columns] + + session_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-sessions.csv", index=False) + unit_level_results.to_csv(results_folder / f"{dataset_name}-{session_name}-units.csv", index=False) diff --git 
a/pipeline/run_super_training.py b/pipeline/run_super_training.py new file mode 100644 index 0000000..c4e07d7 --- /dev/null +++ b/pipeline/run_super_training.py @@ -0,0 +1,218 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import os +import sys +import shutil +import json +import numpy as np +from pathlib import Path +import pandas as pd +import time + +import matplotlib.pyplot as plt + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + + +base_path = Path("..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions_exp as all_sessions + +n_jobs = 16 + +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +# DATASET_FOLDER = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_FOLDER = data_folder / "ephys-compression-benchmark" + +DEBUG = False +PROBESET = "NP2" +NUM_DEBUG_SESSIONS = 4 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +STEPS_PER_EPOCH = 100 + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + + +if __name__ == "__main__": + if len(sys.argv) == 3: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + PROBESET = sys.argv[2] + + session_dict = all_sessions + + print(session_dict) + + if DEBUG: + 
TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 1 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.05 + OVERWRITE = True + else: + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 600 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.1 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + assert PROBESET in ["NP1", "NP2", "NP1-NP2"] + + probes = PROBESET.split("-") + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + for filter_option in FILTER_OPTIONS: + print(f"Filter option: {filter_option}") + + for probe in probes: + sessions = all_sessions[probe] + print(f"\tDataset {probe}") + if DEBUG: + sessions_to_use = sessions[:NUM_DEBUG_SESSIONS] + else: + sessions_to_use = sessions + print(f"\tRunning super training with {len(sessions_to_use)} sessions") + + pretrained_model_path = None + for i, session in enumerate(sessions_to_use): + print(f"\t\tSession {session} - Iteration {i}\n") + dataset_name, session_name = session.split("/") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" + + recording = si.load_extractor(DATASET_FOLDER / session) + + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + # train DI models + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) + + 
# train model + model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" + model_folder.parent.mkdir(parents=True, exist_ok=True) + + if TRAINING_END_S is None: + TRAINING_END_S = recording.get_total_duration() + + # Use SI function + t_start_training = time.perf_counter() + if pretrained_model_path is not None: + print(f"\t\tUsing pretrained model: {pretrained_model_path}") + model_path = spre.train_deepinterpolation( + recording_zscore_bin, + model_folder=model_folder, + model_name=model_name, + existing_model_path=pretrained_model_path, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + train_duration_s=TRAINING_DURATION_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + test_duration_s=TESTING_DURATION_S, + verbose=False, + nb_gpus=1, + steps_per_epoch=STEPS_PER_EPOCH, + **di_kwargs, + ) + pretrained_model_path = model_path + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") + + # aggregate results + print(f"Aggregating results for {probe}-{filter_option}") + final_model_folder = results_folder / f"model_{probe}_{filter_option}" + shutil.copytree(model_folder, final_model_folder) + final_model_name = [p.name for p in final_model_folder.iterdir() if "_model" in p.name][0] + final_model_stem = final_model_name.split("_model")[0] + + # concatenate loss and val loss + loss_accuracies = np.array([]) + val_accuracies = np.array([]) + + for i in range(len(sessions_to_use)): + model_folder = results_folder / f"models_{probe}_{filter_option}" / f"iter{i}" + loss_file = [p for p in model_folder.iterdir() if "_loss.npy" in p.name and "val" not in p.name][0] + val_loss_file = [p for p in model_folder.iterdir() if "val_loss.npy" in p.name][0] + loss = np.load(loss_file) + val_loss = np.load(val_loss_file) + loss_accuracies = np.concatenate((loss_accuracies, loss)) + 
val_accuracies = np.concatenate((val_accuracies, val_loss)) + np.save(final_model_folder / f"{final_model_stem}_loss.npy", loss_accuracies) + np.save(final_model_folder / f"{final_model_stem}_val_loss.npy", val_accuracies) + + # plot losses + fig, ax = plt.subplots() + ax.plot(loss_accuracies, color="C0", label="loss") + ax.plot(val_accuracies, color="C1", label="val_loss") + ax.set_xlabel("number of epochs") + ax.set_ylabel("training loss") + ax.legend() + fig.savefig(final_model_folder / f"{final_model_stem}_losses.png", dpi=300) diff --git a/pipeline/run_training.py b/pipeline/run_training.py new file mode 100644 index 0000000..451d8a7 --- /dev/null +++ b/pipeline/run_training.py @@ -0,0 +1,217 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +#### IMPORTS ####### +import sys +import shutil +import json +import numpy as np +from pathlib import Path +import time + + +# SpikeInterface +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +import spikeinterface.sorters as ss +import spikeinterface.postprocessing as spost +import spikeinterface.comparison as sc +import spikeinterface.qualitymetrics as sqm + +# Tensorflow +import tensorflow as tf + +# runs from "codes" +base_path = Path("..") + +##### DEFINE DATASETS AND FOLDERS ####### +from sessions import all_sessions_exp, all_sessions_sim + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" + + +if (data_folder / "ephys-compression-benchmark").is_dir(): + DATASET_FOLDER = data_folder / "ephys-compression-benchmark" + all_sessions = all_sessions_exp + data_type = "exp" +elif (data_folder / "MEArec-NP-recordings").is_dir(): + DATASET_FOLDER = data_folder / "MEArec-NP-recordings" + all_sessions = all_sessions_sim + data_type = "sim" +else: + raise Exception("Could not find dataset folder") + + +n_jobs = 16 +job_kwargs = 
dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 +DEBUG_DURATION = 20 + +##### DEFINE PARAMS ##### +OVERWRITE = False +USE_GPU = True +STEPS_PER_EPOCH = 100 + +# Define training and testing constants +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + +# DI params +pre_frame = 30 +post_frame = 30 +pre_post_omission = 1 +desired_shape = (192, 2) + +di_kwargs = dict( + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + desired_shape=desired_shape, +) + + +if __name__ == "__main__": + if len(sys.argv) == 2: + if sys.argv[1] == "true": + DEBUG = True + STEPS_PER_EPOCH = 10 + else: + DEBUG = False + + session_dict = all_sessions + filter_options = FILTER_OPTIONS + + json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")] + if len(json_files) == 1: + print(f"Found {len(json_files)} JSON config") + session_dict = {} + # each json file contains a session to run + json_file = json_files[0] + with open(json_file, "r") as f: + config = json.load(f) + probe = config["probe"] + if probe not in session_dict: + session_dict[probe] = [] + session = config["session"] + assert ( + session in all_sessions[probe] + ), f"{session} is not a valid session. 
Valid sessions for {probe} are:\n{all_sessions[probe]}" + session_dict[probe].append(session) + if "filter_option" in config: + filter_options = [config["filter_option"]] + else: + filter_options = FILTER_OPTIONS + elif len(json_files) > 1: + print("Only 1 JSON config file allowed, using default sessions") + + print(f"Sessions:\n{session_dict}") + print(f"Filter options:\n{filter_options}") + + if DEBUG: + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 1 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.05 + OVERWRITE = True + else: + TRAINING_START_S = 10 + TRAINING_END_S = None + TRAINING_DURATION_S = 600 + TESTING_START_S = 0 + TESTING_END_S = 10 + TESTING_DURATION_S = 0.1 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + available_gpus = tf.config.list_physical_devices("GPU") + print(f"Tensorflow GPU status: {available_gpus}") + nb_gpus = len(available_gpus) + if nb_gpus > 1: + print("Use 1 GPU only!") + nb_gpus = 1 + + for probe, sessions in session_dict.items(): + print(f"Dataset {probe}") + + for session in sessions: + print(f"\nAnalyzing session {session}\n") + dataset_name, session_name = session.split("/") + + if data_type == "exp": + recording = si.load_extractor(DATASET_FOLDER / session) + else: + recording, _ = se.read_mearec(DATASET_FOLDER / session) + session_name = session_name.split(".")[0] + recording = spre.depth_order(recording) + + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + print(f"\t{recording}") + if TRAINING_END_S is None: + TRAINING_END_S = recording.get_total_duration() + + for filter_option in filter_options: + print(f"\tFilter option: {filter_option}") + recording_name = f"{dataset_name}_{session_name}_{filter_option}" + # train DI models + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = 
f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + + if data_type == "sim": + recording_processed = spre.depth_order(recording_processed) + + recording_zscore = spre.zscore(recording_processed) + # This speeds things up a lot + recording_zscore_bin = recording_zscore.save( + folder=scratch_folder / f"recording_zscored_{recording_name}" + ) + + # train model + model_folder = results_folder / f"model_{recording_name}" + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore_bin, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + train_duration_s=TRAINING_DURATION_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + test_duration_s=TESTING_DURATION_S, + verbose=False, + nb_gpus=nb_gpus, + steps_per_epoch=STEPS_PER_EPOCH, + **di_kwargs, + ) + t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING {session}-{filter_option}: {elapsed_time_training}s") + + for json_file in json_files: + print(f"Copying JSON file: {json_file.name} to {results_folder}") + shutil.copy(json_file, results_folder) diff --git a/pipeline/sessions.py b/pipeline/sessions.py new file mode 100644 index 0000000..1627a5a --- /dev/null +++ b/pipeline/sessions.py @@ -0,0 +1,78 @@ +from pathlib import Path +import json + +all_sessions_exp = { + "NP1": [ + "aind-np1/625749_2022-08-03_15-15-06_ProbeA", + "aind-np1/634568_2022-08-05_15-59-46_ProbeA", + "aind-np1/634569_2022-08-09_16-14-38_ProbeA", + "aind-np1/634571_2022-08-04_14-27-05_ProbeA", + 
"ibl-np1/CSHZAD026_2020-09-04_probe00", + "ibl-np1/CSHZAD029_2020-09-09_probe00", + "ibl-np1/SWC054_2020-10-05_probe00", + "ibl-np1/SWC054_2020-10-05_probe01", + ], + "NP2": [ + "aind-np2/595262_2022-02-21_15-18-07_ProbeA", + "aind-np2/602454_2022-03-22_16-30-03_ProbeB", + "aind-np2/612962_2022-04-13_19-18-04_ProbeB", + "aind-np2/612962_2022-04-14_17-17-10_ProbeC", + "aind-np2/618197_2022-06-21_14-08-06_ProbeC", + "aind-np2/618318_2022-04-13_14-59-07_ProbeB", + "aind-np2/618384_2022-04-14_15-11-00_ProbeB", + "aind-np2/621362_2022-07-14_11-19-36_ProbeA", + ], +} + +all_sessions_sim = { + "NP1": [ + "NP1/recording-0.h5", + "NP1/recording-1.h5", + "NP1/recording-2.h5", + "NP1/recording-3.h5", + "NP1/recording-4.h5", + ], + "NP2": [ + "NP2/recording-0.h5", + "NP2/recording-1.h5", + "NP2/recording-2.h5", + "NP2/recording-3.h5", + "NP2/recording-4.h5", + ], +} + +FILTER_OPTIONS = ["hp", "bp"] + + +def generate_job_config_list(output_folder, split_probes=True, split_filters=True, dataset="exp"): + output_folder = Path(output_folder) + output_folder.mkdir(exist_ok=True, parents=True) + + if dataset == "exp": + all_sessions = all_sessions_exp + else: + all_sessions = all_sessions_sim + + i = 0 + for probe, sessions in all_sessions.items(): + if split_probes: + i = 0 + probe_folder = output_folder / probe + probe_folder.mkdir(exist_ok=True) + else: + probe_folder = output_folder + + for session in sessions: + d = dict(session=session, probe=probe) + + if split_filters: + for filter_option in FILTER_OPTIONS: + d["filter_option"] = filter_option + with open(probe_folder / f"job{i}.json", "w") as f: + json.dump(d, f) + i += 1 + else: + with open(probe_folder / f"job{i}.json", "w") as f: + json.dump(d, f) + i += 1 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4a63580 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy==1.24 +protobuf==3.20.* +tensorflow==2.7.0 +pip install 
git+https://github.com/AllenInstitute/deepinterpolation.git +# we need a specific branch for SpikeInterface +pip install git+https://github.com/alejoe91/spikeinterface.git@deepinterp#egg=spikeinterface \ No newline at end of file diff --git a/scripts/run_full_analysis.py b/scripts/run_full_analysis.py index ceaab85..37d7ab2 100644 --- a/scripts/run_full_analysis.py +++ b/scripts/run_full_analysis.py @@ -1,9 +1,17 @@ +import warnings + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) + + #### IMPORTS ####### import os +import sys import numpy as np from pathlib import Path -from numba import cuda +from numba import cuda import pandas as pd +import time # SpikeInterface @@ -15,12 +23,13 @@ import spikeinterface.comparison as sc import spikeinterface.qualitymetrics as sqm -from utils import train_di_model +# Tensorflow +import tensorflow as tf -##### DEFINE DATASETS AND FOLDERS ####### +base_path = Path("../../..") -DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +##### DEFINE DATASETS AND FOLDERS ####### sessions = [ "595262_2022-02-21_15-18-07_ProbeA", @@ -32,197 +41,378 @@ "618384_2022-04-14_15-11-00_ProbeB", "621362_2022-07-14_11-19-36_ProbeA", ] -sessions = sessions[:1] n_jobs = 16 -data_folder = Path("../data") -results_folder = Path("../results") +job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + +data_folder = base_path / "data" +scratch_folder = base_path / "scratch" +results_folder = base_path / "results" -DEBUG = True + +# DATASET_BUCKET = "s3://aind-benchmark-data/ephys-compression/aind-np2/" +DATASET_BUCKET = data_folder / "ephys-compression-benchmark" / "aind-np2" + +DEBUG = False +NUM_DEBUG_SESSIONS = 2 DEBUG_DURATION = 20 ##### DEFINE PARAMS ##### OVERWRITE = False +USE_GPU = True FULL_INFERENCE = True # Define training and testing constants (@Jad you can gradually increase this) -if DEBUG: - TRAINING_START_S = 0 - TRAINING_END_S = 0.5 - TESTING_START_S = 
10 - TESTING_END_S = 10.05 -else: - TRAINING_START_S = 0 - TRAINING_END_S = 20 - TESTING_START_S = 70 - TESTING_END_S = 70.5 - -FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" + + +FILTER_OPTIONS = ["bp", "hp"] # "hp", "bp", "no" # DI params pre_frame = 30 post_frame = 30 pre_post_omission = 1 desired_shape = (192, 2) -inference_n_jobs = 1 # TODO: Jad - try more jobs -inference_chunk_duration = "50ms" +# play around with these +inference_n_jobs = 4 +inference_chunk_duration = "100ms" +inference_memory_gpu = 2000 #MB di_kwargs = dict( pre_frame=pre_frame, post_frame=post_frame, pre_post_omission=pre_post_omission, desired_shape=desired_shape, - inference_n_jobs=inference_n_jobs, - inference_chunk_duration=inference_chunk_duration ) -sorter_name = "kilosort2_5" +sorter_name = "pykilosort" singularity_image = False match_score = 0.7 -#### START #### -raw_data_folder = data_folder / "raw" -raw_data_folder.mkdir(exist_ok=True) - -session_level_results = pd.DataFrame(columns=['session', 'filter_option', 'di', "num_units", "sorting_path"]) -unit_level_results = pd.DataFrame(columns=['session', 'filter_option', 'di', "unit_index", - "unit_id_no_di", "unit_id_di"]) - -for session in sessions: - print(f"Analyzing session {session}") - - # download dataset - dst_folder = (raw_data_folder / session) - dst_folder.mkdir(exist_ok=True) - - src_folder = f"{DATASET_BUCKET}{session}" - - cmd = f"aws s3 sync {src_folder} {dst_folder}" - # aws command to download - os.system(cmd) +if __name__ == "__main__": + if len(sys.argv) == 3: + if sys.argv[1] == "true": + DEBUG = True + else: + DEBUG = False + if sys.argv[2] != "all": + sessions = [sys.argv[2]] - recording_folder = dst_folder - recording = si.load_extractor(recording_folder) if DEBUG: - recording = recording.frame_slice(start_frame=0, end_frame=int(DEBUG_DURATION * recording.sampling_frequency)) - - results_dict = {} - for filter_option in FILTER_OPTIONS: - print(f"\tFilter option: {filter_option}") - 
results_dict[filter_option] = {} - # train DI models - print(f"\t\tTraning DI") - training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) - testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) - model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" - recording_no_di, recording_di = train_di_model(recording, session, filter_option, - TRAINING_START_S, TRAINING_END_S, - TESTING_START_S, TESTING_END_S, - data_folder, FULL_INFERENCE, model_name, - di_kwargs, overwrite=OVERWRITE) - results_dict[filter_option]["recording_no_di"] = recording_no_di - results_dict[filter_option]["recording_di"] = recording_di - - # release GPU memory - device = cuda.get_current_device() - device.reset() - - # run spike sorting - sorting_output_folder = data_folder / "sortings" / session - sorting_output_folder.mkdir(parents=True, exist_ok=True) - - recording_no_di = results_dict[filter_option]["recording_no_di"] - if (sorting_output_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoading NO DI sorting") - sorting_no_di = si.load_extractor(sorting_output_folder / f"no_di_{model_name}") - else: - print(f"\t\tSpike sorting NO DI with {sorter_name}") - sorting_no_di = ss.run_sorter(sorter_name, recording=recording_no_di, - n_jobs=n_jobs, verbose=True, singularity_image=singularity_image) - sorting_no_di = sorting_no_di.save(folder=sorting_output_folder / f"no_di_{model_name}") - results_dict[filter_option]["sorting_no_di"] = sorting_no_di - - recording_di = results_dict[filter_option]["recording_di"] - if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoading DI sorting") - sorting_di = si.load_extractor(sorting_output_folder / f"di_{model_name}") + TRAINING_START_S = 0 + TRAINING_END_S = 0.2 + TESTING_START_S = 10 + TESTING_END_S = 10.05 + if len(sessions) > NUM_DEBUG_SESSIONS: + sessions = sessions[:NUM_DEBUG_SESSIONS] + OVERWRITE = True + else: + TRAINING_START_S = 0 + TRAINING_END_S = 20 + 
TESTING_START_S = 70 + TESTING_END_S = 70.5 + OVERWRITE = False + + si.set_global_job_kwargs(**job_kwargs) + + print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}") + + #### START #### + session_level_results = pd.DataFrame( + columns=[ + "session", + "probe", + "filter_option", + "num_units", + "num_units_di", + "sorting_path", + "sorting_path_di", + "num_match", + ] + ) + + unit_level_results_columns = [ + "session", + "probe", + "filter_option", + "unit_id", + "unit_id_di", + ] + unit_level_results = None + + sessions = sessions[:2] + + for session in sessions: + print(f"\nAnalyzing session {session}\n") + if str(DATASET_BUCKET).startswith("s3"): + raw_data_folder = scratch_folder / "raw" + raw_data_folder.mkdir(exist_ok=True) + + # download dataset + dst_folder.mkdir(exist_ok=True) + + src_folder = f"{DATASET_BUCKET}{session}" + + cmd = f"aws s3 sync {src_folder} {dst_folder}" + # aws command to download + os.system(cmd) else: - print(f"\t\tSpike sorting DI with {sorter_name}") - sorting_di = ss.run_sorter(sorter_name, recording=recording_di, - n_jobs=n_jobs, verbose=True, singularity_image=singularity_image) - sorting_di = sorting_di.save(folder=sorting_output_folder / f"di_{model_name}") - results_dict[filter_option]["sorting_di"] = sorting_di - - # TODO: Jad - compute waveforms and quality metrics (https://spikeinterface.readthedocs.io/en/latest/how_to/get_started.html) - - ## add entries to session-level results - session_level_results.append({"session": session, "filter_option": filter_option, - "di": False, "num_units": len(sorting_no_di.unit_ids), - "sorting_path": str((sorting_output_folder / f"no_di_{model_name}").absolute())}, - ignore_index=True) - session_level_results.append({"session": session, "filter_option": filter_option, - "di": True, "num_units": len(sorting_di.unit_ids), - "sorting_path": str((sorting_output_folder / f"di_{model_name}").absolute())}, - ignore_index=True) - - # compare outputs - print("\t\tComparing 
sortings") - cmp = sc.compare_two_sorters(sorting1=sorting_no_di, sorting2=sorting_di, - sorting1_name="no_di", sorting2_name="di", - match_score=match_score) - matched_units = cmp.get_matching()[0] - matched_unit_ids_no_di = matched_units.index.values.astype(int) - matched_unit_ids_di = matched_units.values.astype(int) - matched_units_valid = matched_unit_ids_di != -1 - matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] - matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] - sorting_no_di_matched = sorting_no_di.select_units(unit_ids=matched_unit_ids_no_di) - sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) - - waveforms_folder = data_folder / "waveforms" / session - waveforms_folder.mkdir(exist_ok=True, parents=True) - - if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad NO DI waveforms") - we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") - else: - print("\t\tCompute NO DI waveforms") - we_no_di = si.extract_waveforms(recording_no_di, sorting_no_di_matched, - folder=waveforms_folder / f"no_di_{model_name}", - n_jobs=n_jobs, overwrite=True) - results_dict[filter_option]["we_no_di"] = we_no_di - - if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: - print("\t\tLoad DI waveforms") - we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") - else: - print("\t\tCompute DI waveforms") - we_di = si.extract_waveforms(recording_di, sorting_di_matched, - folder=waveforms_folder / f"di_{model_name}", - n_jobs=n_jobs, overwrite=True) - results_dict[filter_option]["we_di"] = we_di - - # compute metrics - if we_no_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad NO DI metrics") - qm_no_di = we_no_di.load_extension("quality_metrics").get_data() - else: - print("\t\tCompute NO DI metrics") - qm_no_di = sqm.compute_quality_metrics(we_no_di) - results_dict[filter_option]["qm_no_di"] = qm_no_di + raw_data_folder = 
DATASET_BUCKET + dst_folder = raw_data_folder / session - if we_di.is_extension("quality_metrics") and not OVERWRITE: - print("\t\tLoad DI metrics") - qm_di = we_di.load_extension("quality_metrics").get_data() + if "np1" in dst_folder.name: + probe = "NP1" else: - print("\t\tCompute DI metrics") - qm_di = sqm.compute_quality_metrics(we_di) - results_dict[filter_option]["qm_di"] = qm_di - - ## add entries to unit-level results - -results_folder.mkdir(exist_ok=True) -session_level_results.to_csv(results_folder / "session-results.csv") -unit_level_results.to_csv(results_folder / "unit-results.csv") + probe = "NP2" + + recording_folder = dst_folder + recording = si.load_extractor(recording_folder) + if DEBUG: + recording = recording.frame_slice( + start_frame=0, + end_frame=int(DEBUG_DURATION * recording.sampling_frequency), + ) + + results_dict = {} + for filter_option in FILTER_OPTIONS: + print(f"\tFilter option: {filter_option}") + results_dict[filter_option] = {} + # train DI models + print(f"\t\tTraning DI") + training_time = np.round(TRAINING_END_S - TRAINING_START_S, 3) + testing_time = np.round(TESTING_END_S - TESTING_START_S, 3) + model_name = f"{filter_option}_t{training_time}s_v{testing_time}s" + + # apply filter and zscore + if filter_option == "hp": + recording_processed = spre.highpass_filter(recording) + elif filter_option == "bp": + recording_processed = spre.bandpass_filter(recording) + else: + recording_processed = recording + recording_zscore = spre.zscore(recording_processed) + + # train model + model_folder = results_folder / "models" / session / filter_option + model_folder.parent.mkdir(parents=True, exist_ok=True) + # Use SI function + t_start_training = time.perf_counter() + model_path = spre.train_deepinterpolation( + recording_zscore, + model_folder=model_folder, + model_name=model_name, + train_start_s=TRAINING_START_S, + train_end_s=TRAINING_END_S, + test_start_s=TESTING_START_S, + test_end_s=TESTING_END_S, + **di_kwargs, + ) + 
t_stop_training = time.perf_counter() + elapsed_time_training = np.round(t_stop_training - t_start_training, 2) + print(f"\t\tElapsed time TRAINING: {elapsed_time_training}s") + # full inference + output_folder = ( + results_folder / "deepinterpolated" / session / filter_option + ) + if OVERWRITE and output_folder.is_dir(): + shutil.rmtree(output_folder) + + if not output_folder.is_dir(): + t_start_inference = time.perf_counter() + output_folder.parent.mkdir(exist_ok=True, parents=True) + recording_di = spre.deepinterpolate( + recording_zscore, + model_path=model_path, + pre_frame=pre_frame, + post_frame=post_frame, + pre_post_omission=pre_post_omission, + memory_gpu=inference_memory_gpu, + use_gpu=USE_GPU + ) + recording_di = recording_di.save( + folder=output_folder, + n_jobs=inference_n_jobs, + chunk_duration=inference_chunk_duration, + ) + t_stop_inference = time.perf_counter() + elapsed_time_inference = np.round( + t_stop_inference - t_start_inference, 2 + ) + print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s") + else: + print("\t\tLoading existing folder") + recording_di = si.load_extractor(output_folder) + # apply inverse z-scoring + inverse_gains = 1 / recording_zscore.gain + inverse_offset = -recording_zscore.offset * inverse_gains + recording_di_inverse_zscore = spre.scale( + recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float" + ) + + results_dict[filter_option]["recording_no_di"] = recording_processed + results_dict[filter_option]["recording_di"] = recording_di_inverse_zscore + + # run spike sorting + sorting_output_folder = ( + results_folder / "sortings" / session / filter_option + ) + sorting_output_folder.mkdir(parents=True, exist_ok=True) + + recording_no_di = results_dict[filter_option]["recording_no_di"] + if ( + sorting_output_folder / f"no_di_{model_name}" + ).is_dir() and not OVERWRITE: + print("\t\tLoading NO DI sorting") + sorting_no_di = si.load_extractor(sorting_output_folder / "sorting") + else: + 
print(f"\t\tSpike sorting NO DI with {sorter_name}") + sorting_no_di = ss.run_sorter( + sorter_name, + recording=recording_no_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_no_di = sorting_no_di.save( + folder=sorting_output_folder / "sorting" + ) + results_dict[filter_option]["sorting_no_di"] = sorting_no_di + + recording_di = results_dict[filter_option]["recording_di"] + if (sorting_output_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoading DI sorting") + sorting_di = si.load_extractor(sorting_output_folder / "sorting_di") + else: + print(f"\t\tSpike sorting DI with {sorter_name}") + sorting_di = ss.run_sorter( + sorter_name, + recording=recording_di, + n_jobs=n_jobs, + verbose=True, + singularity_image=singularity_image, + ) + sorting_di = sorting_di.save( + folder=sorting_output_folder / "sorting_di" + ) + results_dict[filter_option]["sorting_di"] = sorting_di + + # compare outputs + print("\t\tComparing sortings") + comp = sc.compare_two_sorters( + sorting1=sorting_no_di, + sorting2=sorting_di, + sorting1_name="no_di", + sorting2_name="di", + match_score=match_score, + ) + matched_units = comp.get_matching()[0] + matched_unit_ids_no_di = matched_units.index.values.astype(int) + matched_unit_ids_di = matched_units.values.astype(int) + matched_units_valid = matched_unit_ids_di != -1 + matched_unit_ids_no_di = matched_unit_ids_no_di[matched_units_valid] + matched_unit_ids_di = matched_unit_ids_di[matched_units_valid] + sorting_no_di_matched = sorting_no_di.select_units( + unit_ids=matched_unit_ids_no_di + ) + sorting_di_matched = sorting_di.select_units(unit_ids=matched_unit_ids_di) + + ## add entries to session-level results + new_row = { + "session": session, + "filter_option": filter_option, + "probe": probe, + "num_units": len(sorting_no_di.unit_ids), + "num_units_di": len(sorting_di.unit_ids), + "num_match": len(sorting_no_di_matched.unit_ids), + "sorting_path": str( + 
(sorting_output_folder / "sorting").relative_to(results_folder) + ), + "sorting_path_di": str( + (sorting_output_folder / "sorting_di").relative_to( + results_folder + ) + ), + } + session_level_results = pd.concat( + [session_level_results, pd.DataFrame([new_row])], ignore_index=True + ) + + print( + f"\n\t\tNum units: {new_row['num_units']} - Num units DI: {new_row['num_units_di']} - Num match: {new_row['num_match']}" + ) + + # waveforms + waveforms_folder = results_folder / "waveforms" / session / filter_option + waveforms_folder.mkdir(exist_ok=True, parents=True) + + if (waveforms_folder / f"no_di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad NO DI waveforms") + we_no_di = si.load_waveforms(waveforms_folder / f"no_di_{model_name}") + else: + print("\t\tCompute NO DI waveforms") + we_no_di = si.extract_waveforms( + recording_no_di, + sorting_no_di_matched, + folder=waveforms_folder / f"no_di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_no_di"] = we_no_di + + if (waveforms_folder / f"di_{model_name}").is_dir() and not OVERWRITE: + print("\t\tLoad DI waveforms") + we_di = si.load_waveforms(waveforms_folder / f"di_{model_name}") + else: + print("\t\tCompute DI waveforms") + we_di = si.extract_waveforms( + recording_di, + sorting_di_matched, + folder=waveforms_folder / f"di_{model_name}", + n_jobs=n_jobs, + overwrite=True, + ) + results_dict[filter_option]["we_di"] = we_di + + # compute metrics + if we_no_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad NO DI metrics") + qm_no_di = we_no_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute NO DI metrics") + qm_no_di = sqm.compute_quality_metrics(we_no_di) + results_dict[filter_option]["qm_no_di"] = qm_no_di + + if we_di.is_extension("quality_metrics") and not OVERWRITE: + print("\t\tLoad DI metrics") + qm_di = we_di.load_extension("quality_metrics").get_data() + else: + print("\t\tCompute DI 
metrics") + qm_di = sqm.compute_quality_metrics(we_di) + results_dict[filter_option]["qm_di"] = qm_di + + ## add entries to unit-level results + if unit_level_results is None: + for metric in qm_no_di.columns: + unit_level_results_columns.append(metric) + unit_level_results_columns.append(f"{metric}_di") + unit_level_results = pd.DataFrame(columns=unit_level_results_columns) + + new_rows = { + "session": [session] * len(qm_no_di), + "probe": [probe] * len(qm_no_di), + "filter_option": [filter_option] * len(qm_no_di), + "unit_id": we_no_di.unit_ids, + "unit_id_di": we_di.unit_ids, + } + for metric in qm_no_di.columns: + new_rows[metric] = qm_no_di[metric].values + new_rows[f"{metric}_di"] = qm_di[metric].values + # append new entries + unit_level_results = pd.concat( + [unit_level_results, pd.DataFrame(new_rows)], ignore_index=True + ) + + results_folder.mkdir(exist_ok=True) + session_level_results.to_csv(results_folder / "session-results.csv") + unit_level_results.to_csv(results_folder / "unit-results.csv") diff --git a/scripts/utils.py b/scripts/utils.py deleted file mode 100644 index 4e97a04..0000000 --- a/scripts/utils.py +++ /dev/null @@ -1,181 +0,0 @@ - -import json -import sys -import shutil - -import spikeinterface as si -import spikeinterface.preprocessing as spre - -# DeepInterpolation -from deepinterpolation.trainor_collection import core_trainer -from deepinterpolation.network_collection import unet_single_ephys_1024 -from deepinterpolation.generic import ClassLoader - -# Import local classes for DI+SI -sys.path.append("../src") - -# the generator is in the "spikeinterface_generator.py" -from spikeinterface_generator import SpikeInterfaceGenerator -from deepinterpolation_recording import DeepInterpolatedRecording - - - - -def train_di_model(recording, session, filter_option, train_start_s, train_end_s, - test_start_s, test_end_s, data_folder, full_inference, model_name, di_kwargs, - overwrite=False, use_gpu=True): - """_summary_ - - Parameters - 
---------- - recording : _type_ - _description_ - session : _type_ - _description_ - filter_option : _type_ - _description_ - train_start_s : _type_ - _description_ - train_end_s : _type_ - _description_ - test_start_s : _type_ - _description_ - test_end_s : _type_ - _description_ - data_folder : _type_ - _description_ - full_inference : _type_ - _description_ - di_kwargs : _type_ - _description_ - """ - model_folder = data_folder / "models" / session - model_folder.mkdir(exist_ok=True, parents=True) - trained_model_folder = model_folder / model_name - - pre_frame = di_kwargs["pre_frame"] - post_frame = di_kwargs["post_frame"] - pre_post_omission = di_kwargs["pre_post_omission"] - desired_shape = di_kwargs["desired_shape"] - inference_n_jobs = di_kwargs["inference_n_jobs"] - inference_chunk_duration = di_kwargs["inference_chunk_duration"] - - # pre-process - assert filter_option in ("no", "hp", "bp"), "Wrong filter option!" - if filter_option == "hp": - rec_f = spre.highpass_filter(recording) - elif filter_option == "bp": - rec_f = spre.bandpass_filter(recording) - else: - rec_f = recording - - rec_processed = rec_f - rec_norm = spre.zscore(rec_processed) - - ### Define params - start_frame_training = int(train_start_s * rec_norm.sampling_frequency) - end_frame_training = int(train_end_s * rec_norm.sampling_frequency) - start_frame_test = int(test_start_s * rec_norm.sampling_frequency) - end_frame_test = int(test_end_s * rec_norm.sampling_frequency) - - # Those are parameters used for the network topology - network_params = dict() - network_params["type"] = "network" - # Name of network topology in the collection - network_params["name"] = "unet_single_ephys_1024" - training_params = dict() - training_params["output_dir"] = str(trained_model_folder) - # We pass on the uid - training_params["run_uid"] = "first_test" - - # We convert to old schema - training_params["nb_gpus"] = 1 - training_params["type"] = "trainer" - training_params["steps_per_epoch"] = 10 - 
training_params["period_save"] = 100 - training_params["apply_learning_decay"] = 0 - training_params["nb_times_through_data"] = 1 - training_params["learning_rate"] = 0.0001 - training_params["pre_post_frame"] = 1 - training_params["loss"] = "mean_absolute_error" - training_params["nb_workers"] = 2 - training_params["caching_validation"] = False - training_params["model_string"] = f"{network_params['name']}_{training_params['loss']}" - - - if not trained_model_folder.is_dir() or overwrite: - trained_model_folder.mkdir(exist_ok=True) - - # Training (from core_trainor class) - training_data_generator = SpikeInterfaceGenerator(rec_norm, zscore=False, - pre_frame=pre_frame, post_frame=post_frame, - pre_post_omission=pre_post_omission, - start_frame=start_frame_training, - end_frame=end_frame_training, - desired_shape=desired_shape) - test_data_generator = SpikeInterfaceGenerator(rec_norm, zscore=False, - pre_frame=pre_frame, post_frame=post_frame, - pre_post_omission=pre_post_omission, - start_frame=start_frame_test, - end_frame=end_frame_test, - steps_per_epoch=-1, - desired_shape=desired_shape) - - - network_json_path = trained_model_folder / "network_params.json" - with open(network_json_path, "w") as f: - json.dump(network_params, f) - - network_obj = ClassLoader(network_json_path) - data_network = network_obj.find_and_build()(network_json_path) - - training_json_path = trained_model_folder / "training_params.json" - with open(training_json_path, "w") as f: - json.dump(training_params, f) - - - training_class = core_trainer( - training_data_generator, test_data_generator, data_network, - training_json_path - ) - - print("created objects for training") - training_class.run() - - print("training job finished - finalizing output model") - training_class.finalize() - else: - print("Loading pre-trained model") - - ### Test inference - - # Re-load model from output folder - model_path = trained_model_folder / 
f"{training_params['run_uid']}_{training_params['model_string']}_model.h5" - - rec_di = DeepInterpolatedRecording(rec_norm, model_path=model_path, pre_frames=pre_frame, - post_frames=post_frame, pre_post_omission=pre_post_omission, - disable_tf_logger=True, - use_gpu=use_gpu) - - if full_inference: - deepinterpolated_folder = data_folder / "deepinterpolated" / session - deepinterpolated_folder.mkdir(exist_ok=True, parents=True) - output_folder = deepinterpolated_folder / model_name - if output_folder.is_dir() and overwrite: - shutil.rmtree(output_folder) - rec_di = DeepInterpolatedRecording(rec_norm, model_path=model_path, pre_frames=pre_frame, - post_frames=post_frame, pre_post_omission=pre_post_omission, - disable_tf_logger=True, - use_gpu=use_gpu) - rec_di_saved = rec_di.save(folder=output_folder, n_jobs=inference_n_jobs, - chunk_duration=inference_chunk_duration) - else: - rec_di_saved = si.load_extractor(output_folder) - else: - rec_di_saved = rec_di - - # apply inverse z-scoring - inverse_gains = 1 / rec_norm.gain - inverse_offset = - rec_norm.offset * inverse_gains - rec_di_inverse_zscore = spre.scale(rec_di_saved, gain=inverse_gains, offset=inverse_offset, dtype='float') - return rec_processed, rec_di_inverse_zscore diff --git a/src/deepinterpolation_recording.py b/src/deepinterpolation_recording.py index 3d85158..8657bea 100644 --- a/src/deepinterpolation_recording.py +++ b/src/deepinterpolation_recording.py @@ -21,19 +21,20 @@ def has_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): return True except ImportError: return False - + + def import_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): import tensorflow as tf if not use_gpu: - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if disable_tf_logger: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - tf.get_logger().setLevel('ERROR') + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + tf.get_logger().setLevel("ERROR") tf.compat.v1.disable_eager_execution() - 
gpus = tf.config.list_physical_devices('GPU') + gpus = tf.config.list_physical_devices("GPU") if gpus: if memory_gpu is None: try: @@ -47,24 +48,32 @@ def import_tf(use_gpu=True, disable_tf_logger=True, memory_gpu=None): else: for gpu in gpus: tf.config.set_logical_device_configuration( - gpus[0], - [tf.config.LogicalDeviceConfiguration(memory_limit=memory_gpu)]) + gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=memory_gpu)] + ) return tf - class DeepInterpolatedRecording(BasePreprocessor): - name = 'deepinterpolate' - - def __init__(self, recording, model_path: str, - pre_frames: int = 30, post_frames: int = 30, pre_post_omission: int = 1, - batch_size=128, use_gpu: bool = True, disable_tf_logger: bool = True, - memory_gpu=None): + name = "deepinterpolate" + + def __init__( + self, + recording, + model_path: str, + pre_frames: int = 30, + post_frames: int = 30, + pre_post_omission: int = 1, + batch_size=128, + use_gpu: bool = True, + disable_tf_logger: bool = True, + memory_gpu=None, + ): assert has_tf( - use_gpu, disable_tf_logger, memory_gpu), "To use DeepInterpolation, you first need to install `tensorflow`." - + use_gpu, disable_tf_logger, memory_gpu + ), "To use DeepInterpolation, you first need to install `tensorflow`." 
+ self.tf = import_tf(use_gpu, disable_tf_logger, memory_gpu=memory_gpu) - + # try move model load here with spawn BasePreprocessor.__init__(self, recording) @@ -75,46 +84,71 @@ def __init__(self, recording, model_path: str, # check shape (this will need to be done at inference) network_input_shape = model.get_config()["layers"][0]["config"]["batch_input_shape"] desired_shape = network_input_shape[1:3] - assert desired_shape[0]*desired_shape[1] == recording.get_num_channels(), "text" + assert desired_shape[0] * desired_shape[1] == recording.get_num_channels(), "text" assert network_input_shape[-1] == pre_frames + post_frames - self.model = model # add segment for segment in recording._recording_segments: - recording_segment = DeepInterpolatedRecordingSegment(segment, self.model, - pre_frames, post_frames, pre_post_omission, - desired_shape, batch_size, use_gpu, - disable_tf_logger, memory_gpu) + recording_segment = DeepInterpolatedRecordingSegment( + segment, + self.model, + pre_frames, + post_frames, + pre_post_omission, + desired_shape, + batch_size, + use_gpu, + disable_tf_logger, + memory_gpu, + ) self.add_recording_segment(recording_segment) self._preferred_mp_context = "spawn" - self._kwargs = dict(recording=recording.to_dict(), model_path=str(model_path), - pre_frames=pre_frames, post_frames=post_frames, pre_post_omission=pre_post_omission, - batch_size=batch_size, use_gpu=use_gpu, disable_tf_logger=disable_tf_logger, - memory_gpu=memory_gpu) - self.extra_requirements.extend(['tensorflow']) + self._kwargs = dict( + recording=recording.to_dict(), + model_path=str(model_path), + pre_frames=pre_frames, + post_frames=post_frames, + pre_post_omission=pre_post_omission, + batch_size=batch_size, + use_gpu=use_gpu, + disable_tf_logger=disable_tf_logger, + memory_gpu=memory_gpu, + ) + self.extra_requirements.extend(["tensorflow"]) class DeepInterpolatedRecordingSegment(BasePreprocessorSegment): - - def __init__(self, recording_segment, model, - pre_frames, 
post_frames, pre_post_omission, - desired_shape, batch_size, use_gpu, disable_tf_logger, memory_gpu): + def __init__( + self, + recording_segment, + model, + pre_frames, + post_frames, + pre_post_omission, + desired_shape, + batch_size, + use_gpu, + disable_tf_logger, + memory_gpu, + ): from spikeinterface_generator import SpikeInterfaceRecordingSegmentGenerator BasePreprocessorSegment.__init__(self, recording_segment) - + self.model = model self.pre_frames = pre_frames self.post_frames = post_frames self.pre_post_omission = pre_post_omission self.batch_size = batch_size self.use_gpu = use_gpu - self.desired_shape=desired_shape + self.desired_shape = desired_shape # creating class dynamically to use the imported TF with GPU enabled/disabled based on the use_gpu flag - self.SpikeInterfaceGenerator = SpikeInterfaceRecordingSegmentGenerator #define_input_generator_class( use_gpu, disable_tf_logger) + self.SpikeInterfaceGenerator = ( + SpikeInterfaceRecordingSegmentGenerator # define_input_generator_class( use_gpu, disable_tf_logger) + ) def get_traces(self, start_frame, end_frame, channel_indices): n_frames = self.parent_recording_segment.get_num_samples() @@ -127,42 +161,44 @@ def get_traces(self, start_frame, end_frame, channel_indices): # for frames that lack full training data (i.e. 
pre and post frames including omissinos), # just return uninterpolated - if start_frame < self.pre_frames+self.pre_post_omission: - true_start_frame = self.pre_frames+self.pre_post_omission - array_to_append_front = self.parent_recording_segment.get_traces(start_frame=0, - end_frame=true_start_frame, - channel_indices=channel_indices) + if start_frame < self.pre_frames + self.pre_post_omission: + true_start_frame = self.pre_frames + self.pre_post_omission + array_to_append_front = self.parent_recording_segment.get_traces( + start_frame=0, end_frame=true_start_frame, channel_indices=channel_indices + ) else: true_start_frame = start_frame - if end_frame > n_frames-self.post_frames-self.pre_post_omission: - true_end_frame = n_frames-self.post_frames-self.pre_post_omission - array_to_append_back = self.parent_recording_segment.get_traces(start_frame=true_end_frame, - end_frame=n_frames, - channel_indices=channel_indices) + if end_frame > n_frames - self.post_frames - self.pre_post_omission: + true_end_frame = n_frames - self.post_frames - self.pre_post_omission + array_to_append_back = self.parent_recording_segment.get_traces( + start_frame=true_end_frame, end_frame=n_frames, channel_indices=channel_indices + ) else: true_end_frame = end_frame # instantiate an input generator that can be passed directly to model.predict - input_generator = self.SpikeInterfaceGenerator(recording_segment=self.parent_recording_segment, - start_frame=true_start_frame, - end_frame=true_end_frame, - pre_frame=self.pre_frames, - post_frame=self.post_frames, - pre_post_omission=self.pre_post_omission, - batch_size=self.batch_size) + input_generator = self.SpikeInterfaceGenerator( + recording_segment=self.parent_recording_segment, + start_frame=true_start_frame, + end_frame=true_end_frame, + pre_frame=self.pre_frames, + post_frame=self.post_frames, + pre_post_omission=self.pre_post_omission, + batch_size=self.batch_size, + ) input_generator.randomize = False 
input_generator._calculate_list_samples(input_generator.total_samples) di_output = self.model.predict(input_generator, verbose=2) out_traces = input_generator.reshape_output(di_output) - if true_start_frame != start_frame: # related to the restriction to be applied from the start and end frames around 0 and end - out_traces = np.concatenate( - (array_to_append_front, out_traces), axis=0) + if ( + true_start_frame != start_frame + ): # related to the restriction to be applied from the start and end frames around 0 and end + out_traces = np.concatenate((array_to_append_front, out_traces), axis=0) if true_end_frame != end_frame: - out_traces = np.concatenate( - (out_traces, array_to_append_back), axis=0) + out_traces = np.concatenate((out_traces, array_to_append_back), axis=0) return out_traces[:, channel_indices] diff --git a/src/spikeinterface_generator.py b/src/spikeinterface_generator.py index 48be09b..126e2e8 100644 --- a/src/spikeinterface_generator.py +++ b/src/spikeinterface_generator.py @@ -7,16 +7,28 @@ from deepinterpolation.generator_collection import SequentialGenerator + # TODO: rename to SpikeInterfaceRecordingGenerator class SpikeInterfaceGenerator(SequentialGenerator): """This generator is used when dealing with a SpikeInterface recording. 
The desired shape controls the reshaping of the input data before convolutions.""" - def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, desired_shape=(192, 2), - batch_size=100, steps_per_epoch=10, zscore=True, start_frame=None, end_frame=None): + def __init__( + self, + recording, + pre_frame=30, + post_frame=30, + pre_post_omission=1, + desired_shape=(192, 2), + batch_size=100, + steps_per_epoch=10, + zscore=True, + start_frame=None, + end_frame=None, + ): "Initialization" assert recording.get_num_segments() == 1, "Only supported for mon-segment recordings" - + if zscore: recording_z = spre.zscore(recording) else: @@ -25,15 +37,16 @@ def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, self.recording = recording_z self.total_samples = recording.get_num_samples() assert len(desired_shape) == 2, "desired_shape should be 2D" - assert desired_shape[0] * desired_shape [1] == recording.get_num_channels(), \ - f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}" + assert ( + desired_shape[0] * desired_shape[1] == recording.get_num_channels() + ), f"The product of desired_shape dimensions should be the number of channels: {recording.get_num_channels()}" self.desired_shape = desired_shape - + start_frame = start_frame if start_frame is not None else 0 end_frame = end_frame if end_frame is not None else self.total_samples - + assert end_frame > start_frame, "end_frame must be greater than start_frame" - + sequential_generator_params = dict() sequential_generator_params["steps_per_epoch"] = steps_per_epoch sequential_generator_params["pre_frame"] = pre_frame @@ -53,44 +66,39 @@ def __init__(self, recording, pre_frame=30, post_frame=30, pre_post_omission=1, self._calculate_list_samples(self.total_samples) self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) - def __len__(self): "Denotes the total number of batches" - return 
int(len(self.list_samples) / self.batch_size) + 1 + return int(np.floor(len(self.list_samples) / self.batch_size) + 1) def generate_batch_indexes(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch 0: index = index + self.steps_per_epoch * self.epoch_index # Generate indexes of the batch - indexes = np.arange(index * self.batch_size, - (index + 1) * self.batch_size) - + indexes = np.arange(index * self.batch_size, (index + 1) * self.batch_size) + if max(indexes) > len(self.list_samples): + print(index, min(indexes), max(indexes), len(self.list_samples)) shuffle_indexes = self.list_samples[indexes] elif index == len(self) - 1: - shuffle_indexes = self.list_samples[-self.last_batch_size:] + shuffle_indexes = self.list_samples[-self.last_batch_size :] else: raise Exception(f"Exceeding number of available batches: {len(self)}") return shuffle_indexes - def __getitem__(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch start_frame, "end_frame must be greater than start_frame" - + sequential_generator_params = dict() sequential_generator_params["steps_per_epoch"] = steps_per_epoch sequential_generator_params["pre_frame"] = pre_frame @@ -177,7 +192,6 @@ def __init__(self, recording_segment, start_frame, end_frame, pre_frame=30, post self._calculate_list_samples(self.total_samples) self.last_batch_size = np.mod(self.end_frame - self.start_frame, self.batch_size) - def __len__(self): "Denotes the total number of batches" return int(len(self.list_samples) / self.batch_size) + 1 @@ -185,36 +199,31 @@ def __len__(self): def generate_batch_indexes(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch 0: index = index + self.steps_per_epoch * self.epoch_index # Generate indexes of the batch - indexes = np.arange(index * self.batch_size, - (index + 1) * self.batch_size) + indexes = np.arange(index * self.batch_size, (index + 1) * self.batch_size) 
shuffle_indexes = self.list_samples[indexes] elif index == len(self) - 1: - shuffle_indexes = self.list_samples[-self.last_batch_size:] + shuffle_indexes = self.list_samples[-self.last_batch_size :] else: raise Exception(f"Exceeding number of available batches: {len(self)}") return shuffle_indexes - def __getitem__(self, index): # This is to ensure we are going through # the entire data when steps_per_epoch